kumilog.net

データ分析やプログラミングの話などを書いています。

KerasのGeneratorを自作する

Keras Advent Calendar 2017 の 25日目 の記事です。

Kerasでモデルを学習するmodel.fit_generator()でつかうgeneratorを自作してみます。なお、使用したKerasのバージョンは2.1.2です。

はじめに

Kerasでモデルを学習するには、

  • model.fit()
  • model.fit_generator()
  • model.test_on_batch()

のいずれかを使うと思います。

fit()は、データがメモリにのるくらいの規模のデータに向いており、fit_generator()は、generatorを使ってバッチ毎にデータを読み込みます。メモリ以上のデータを扱うときやバッチ毎に処理をさせるときに使います。test_on_batch()を使うと学習ループをカスタマイズすることができます。

ちなみに、こちらの比較記事によると、fit()fit_generator()の処理速度はほぼ同じとのことです。

今回はmodel.fit_generator()で使うgeneratorを作ってみます。画像を扱う際には、ImageDataGeneratorという便利なgeneratorがKerasでは用意されていますが、画像以外やちょっと複雑なことをやりたいときは、自分で準備したgeneratorを使うこともできます。

Generatorをつくる

Kerasのドキュメントには、fit_generator()の引数のgeneratorは以下のような説明があります。

  • generator: ジェネレータかマルチプロセッシング時にデータの重複を防ぐためのSequence(keras.utils.Sequence)オブジェクトのインスタンス.本ジェネレータの出力は,以下のいずれかです.
    • (inputs, targets)のタプル.
    • (inputs, targets, sample_weights)のタプル.すべての配列は同じ数のサンプルを含む必要があります.本ジェネレータは無期限にそのデータをループさせるようになっています.steps_per_epoch数のサンプルがモデルに与えられると1度の試行が終了します.

generatorかkeras.utils.Sequenceを引数とすればいいようです。今回はcifar10の画像を読み込むgeneratorをつくってみます。Pythonのgeneratorについては、こちらの記事が参考になります。

cifar10のデータはpickle形式で配布されていますが、以下よりpng形式のものをダウンロードできるので、こちらを使います。

$ wget http://pjreddie.com/media/files/cifar.tgz

trainとtestのディレクトリにpng形式の画像ファイルが格納されており、ファイル名は<id>_<label>.pngとなっています。

cifar
├── labels.txt
├── test
│   ├── 0_cat.png
│   ├── 1000_dog.png
│   ├── 1001_airplane.png
│   ├── ...
├── train
│   ├── 0_cat.png
│   ├── 1000_dog.png
│   ├── 1001_airplane.png
│   ├── ...

ディレクトリにある画像を読み込むgeneratorは以下のように書けます。

class ImageDataGenerator(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.images = []
        self.labels = []

    def flow_from_directory(self, directory, classes, batch_size=32):
        # LabelEncode(classをint型に変換)するためのdict
        classes = {v: i for i, v in enumerate(sorted(classes))}
        while True:
            # ディレクトリから画像のパスを取り出す
            for path in pathlib.Path(directory).iterdir():
                # 画像を読み込みRGBへの変換、Numpyへの変換を行い、配列(self.iamges)に格納
                with Image.open(path) as f:
                    self.images.append(np.asarray(f.convert('RGB'), dtype=np.float32))
                # ファイル名からラベルを取り出し、配列(self.labels)に格納
                _, y = path.stem.split('_')
                self.labels.append(to_categorical(classes[y], len(classes)))

                # ここまでを繰り返し行い、batch_sizeの数だけ配列(self.iamges, self.labels)に格納
                # batch_sizeの数だけ格納されたら、戻り値として返し、配列(self.iamges, self.labels)を空にする
                if len(self.images) == batch_size:
                    inputs = np.asarray(self.images, dtype=np.float32)
                    targets = np.asarray(self.labels, dtype=np.float32)
                    self.reset()
                    yield inputs, targets

Generatorをつかう

作ったgeneratorを使ってみます。fit_generator()の引数には、generatorのほか steps_per_epochが必要です。steps_per_epochはその名のとおり、1epochのstep数で、基本はデータ数 / バッチサイズになるかと思います。

train_dir = pathlib.Path('/path/to/train/')
train_datagen = ImageDataGenerator()
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

model.fit_generator(
    generator=train_datagen.flow_from_directory(train_dir, classes),
    steps_per_epoch=int(np.ceil(len(list(train_dir.iterdir())) / batchsize)),
    epochs=args.epoch,
    verbose=1)

testデータを評価するときは、validation_dataにgeneratorを、validation_stepsにvalidationでつかうstep数を指定します。

test_dir = pathlib.Path('/path/to/test/')
test_datagen = ImageDataGenerator()

model.fit_generator(
    generator=train_datagen.flow_from_directory(train_dir, classes),
    steps_per_epoch=int(np.ceil(len(list(train_dir.iterdir())) / batchsize)),
    epochs=args.epoch,
    verbose=1,
    validation_data=test_datagen.flow_from_directory(test_dir, classes),
    validation_steps=int(np.ceil(len(list(test_dir.iterdir())) / batchsize)))

おわりに

Kerasのmodel.fit_generator()でつかうgeneratorの作り方について説明しました。画像ファイルを読み込むだけであれば、自作する必要はないですが、KerasのImageDataGeneratorで用意されていない加工を行うときには使う機会がでてくると思います。

また、画像を加工する方法については以下の記事にまとめてあります。興味があれば一緒にどうぞ。

xkumiyu.hatenablog.com

今回、使用したコードは以下にあります。

github.com