Kerasのmodel.fit_generator()にSequenceをつかってみます。
はじめに
Kerasのfit_generator()の引数にはGeneratorかSequenceをつかうことができます。
今回はSequenceを使ってみます。SequenceはChainerのDatasetMixinと同じような感じで書けます。また、Generatorは、以下の記事を参照ください。
Sequenceをつくる
Sequenceは、keras.utils.Sequenceを継承してつくります。__getitem__メソッドと__len__が必須です。
__len__メソッドは1epochあたりのイテレーション回数です。通常は、サンプル数をバッチサイズで割った値(の切り上げ)になります。__getitem__メソッドは、0から__len__メソッドが返す値までの値(idx)を受け取り、バッチを返す必要があります。
必須ではありませんが、epoch毎に処理したい場合on_epoch_endメソッドを実装することができます。
画像をファイルから読み込むSequenceをChainerのDatasetっぽく作ってみました。pairsは画像ファイル名とラベルのタプルのリストです。
pairs = [
('image1.png', 0),
('image2.png', 1),
...
]
from pathlib import Path import math from skimage.io import imread from keras.utils import Sequence from keras.utils import np_utils import numpy as np class ImageSequence(Sequence): def __init__(self, pairs, num_classes, root='.', batch_size=1): self.x = [str(root / Path(x)) for x in pairs[0]] self.y = np_utils.to_categorical(pairs[1], num_classes) self.batch_size = batch_size def __getitem__(self, idx): # バッチサイズ分取り出す batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size] batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size] # 画像を1枚づつ読み込んで、前処理をする batch_x = np.array([self.preprocess(imread(file_name)) for file_name in batch_x]) return batch_x, np.array(batch_y) def __len__(self): return math.ceil(len(self.x) / self.batch_size) def preprocess(self, image): # いろいろ前処理 return image
前処理では、画像をCropしたりリサイズしたり加工する処理を自由に加えることができます。基本的な画像処理であれば、わざわざSequenceをつくらずKerasが用意するImageDataGeneratorを使うと良いと思います。
Scale Augmentationのように単純なリサイズではない処理の場合、ImageDataGeneratorでは対応できなかったので、Sequenceを自作したりしています。
Scale Augmentationやその他の画像の前処理については以下の記事も合わせてどうぞ。
ChainerのDatasetMixinとの違い
ChainerのDatasetMixinは、get_exampleメソッドと__len__メソッドを用意します。それぞれKerasの__getitem__メソッドと__len__メソッドに対応しますが、少し扱いが異なるので注意が必要です。
| Keras | Chainer | |
|---|---|---|
__getitem__ / get_example の引数 |
バッチを表す添字 | サンプルを表す添字 |
__getitem__ / get_example が返す値 |
バッチ | サンプル |
__len__が返す値 |
1epochのバッチ数 | サンプル数 |
ChainerのDatasetMixinは以下の記事にまとめています。
Sequenceをつかう
使い方は、model.fit_generatorの引数につくったSequenceを与えるだけです。学習データ用だけでなく、検証データ用を用意する場合も、そのまま使えます。
train_gen = ImageSequence(train_pairs, num_classes, batchsize)
valid_gen = ImageSequence(valid_pairs, num_classes, batchsize)
model.fit_generator(
generator=train_gen,
epochs=epoch,
steps_per_epoch=len(train_gen),
verbose=1,
validation_data=valid_gen,
validation_steps=len(valid_gen))