kumilog.net

データ分析やプログラミングの話などを書いています。

KerasのGeneratorにSequenceをつかう

Kerasのmodel.fit_generator()にSequenceをつかってみます。

はじめに

Kerasのfit_generator()の引数にはGeneratorかSequenceをつかうことができます。

今回はSequenceを使ってみます。SequenceはChainerのDatasetMixinと同じような感じで書けます。また、Generatorは、以下の記事を参照ください。

www.kumilog.net

Sequenceをつくる

Sequenceは、keras.utils.Sequenceを継承してつくります。__getitem__メソッドと__len__が必須です。

__len__メソッドは1epochあたりのイテレーション回数です。通常は、サンプル数をバッチサイズで割った値(の切り上げ)になります。__getitem__メソッドは、0から__len__メソッドが返す値までの値(idx)を受け取り、バッチを返す必要があります。

必須ではありませんが、epoch毎に処理したい場合on_epoch_endメソッドを実装することができます。

画像をファイルから読み込むSequenceをChainerのDatasetっぽく作ってみました。pairsは画像ファイル名とラベルのタプルのリストです。

pairs = [
    ('image1.png', 0),
    ('image2.png', 1),
    ...
]
from pathlib import Path
import math

from skimage.io import imread
from keras.utils import Sequence
from keras.utils import np_utils
import numpy as np

class ImageSequence(Sequence):
    def __init__(self, pairs, num_classes, root='.', batch_size=1):
        self.x = [str(root / Path(x)) for x in pairs[0]]
        self.y = np_utils.to_categorical(pairs[1], num_classes)
        self.batch_size = batch_size

    def __getitem__(self, idx):
        # バッチサイズ分取り出す
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        # 画像を1枚づつ読み込んで、前処理をする
        batch_x = np.array([self.preprocess(imread(file_name)) for file_name in batch_x])
        return batch_x, np.array(batch_y)

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def preprocess(self, image):
        # いろいろ前処理
        return image

前処理では、画像をCropしたりリサイズしたり加工する処理を自由に加えることができます。基本的な画像処理であれば、わざわざSequenceをつくらずKerasが用意するImageDataGeneratorを使うと良いと思います。

Scale Augmentationのように単純なリサイズではない処理の場合、ImageDataGeneratorでは対応できなかったので、Sequenceを自作したりしています。

Scale Augmentationやその他の画像の前処理については以下の記事も合わせてどうぞ。

www.kumilog.net

ChainerのDatasetMixinとの違い

ChainerのDatasetMixinは、get_exampleメソッドと__len__メソッドを用意します。それぞれKerasの__getitem__メソッドと__len__メソッドに対応しますが、少し扱いが異なるので注意が必要です。

Keras Chainer
__getitem__ / get_example の引数 バッチを表す添字 サンプルを表す添字
__getitem__ / get_example が返す値 バッチ サンプル
__len__が返す値 1epochのバッチ数 サンプル数

ChainerのDatasetMixinは以下の記事にまとめています。

www.kumilog.net

Sequenceをつかう

使い方は、model.fit_generatorの引数につくったSequenceを与えるだけです。学習データ用だけでなく、検証データ用を用意する場合も、そのまま使えます。

train_gen = ImageSequence(train_pairs, num_classes, batchsize)
valid_gen = ImageSequence(valid_pairs, num_classes, batchsize)
model.fit_generator(
    generator=train_gen,
    epochs=epoch,
    steps_per_epoch=len(train_gen),
    verbose=1,
    validation_data=valid_gen,
    validation_steps=len(valid_gen))