kumilog.net

データ分析やプログラミングの話などを書いています。

ChainerでFineTuning その1

機械学習がうまくいくケースにおいて、教師あり学習の次に、転移学習が成功のカギになってくると言われているそうです。転移学習は、あるドメインで学習させたモデルを別のドメインに適用させる学習のことです。少ないデータでもうまく学習ができたりします。

転移学習の詳細は以下の記事が参考になります。

qiita.com

今回は、Chiner自身に実装されている関数を用いてFineTuningをやってみます。

学習済みモデルを使う

VGGで最後の全結合層(fc8)のみを変更した例が以下になります。初回起動時に、学習済みのモデルがダウンロード(と変換が)されデフォルトでは$HOME/.chainer/dataset/pfnet/chainer/models/に格納されます。

ダウンロードにネットワーク環境にもよりますが、そこそこ時間がかかります。私の環境では10分以上かかった気がします。また、モデルファイルは約500MBとそれなりの大きさです。

import chainer
import chainer.links as L


class VGG(chainer.Chain):
    def __init__(self, class_labels=1000):
        super(VGG, self).__init__()

        with self.init_scope():
            self.base = L.VGG16Layers()
            self.fc8 = L.Linear(4096, class_labels)

    def __call__(self, x):
        h = self.base(x, layers=['fc7'])['fc7']
        return self.fc8(h)

L.VGG16Layers()で学習済みのVGG16のモデルを呼び出しており、self.base(x, layers=['fc7'])['fc7']で1層目の畳込み層から15層目(fc7)まで流したときの出力が得られます。そして、その後に自分で作ったfc8を追加しています。

重みを固定する

この例では、学習済みの重みが初期値として用いますが、重みが固定されているわけではないので、このまま学習を回すと重みが更新されます。

そこで、重みを固定するためにdisable_update()を用います。disable_update()は、optimizer.setup()を呼び出したあとに設定します。

disable_update()はChain単位やLink単位で用いることができるので、baseの重みをまるごと固定します。

vgg = VGG()
optimizer = chainer.optimizers.Adam()
optimizer.setup(vgg)
vgg.base.disable_update()

ちゃんと固定されているか調べてみます。

def check_weight(freeze_params=False):
    vgg = VGG()
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(vgg)

    if freeze_params:
        vgg.base.disable_update()

    x = np.random.rand(3*224*224).reshape(1, 3, 224, 224).astype(np.float32)
    t = np.array([0], np.int32)
    y = vgg(x)

    print('[before] conv1_1.W mean: {}, min: {}, max: {}'.format(
        vgg.base.conv1_1.W.data.mean(),
        vgg.base.conv1_1.W.data.min(),
        vgg.base.conv1_1.W.data.max()))

    loss = F.softmax_cross_entropy(y, t)
    vgg.cleargrads()
    loss.backward()
    optimizer.update()

    print('[after] conv1_1.W mean: {}, min: {}, max: {}'.format(
        vgg.base.conv1_1.W.data.mean(),
        vgg.base.conv1_1.W.data.min(),
        vgg.base.conv1_1.W.data.max()))


print('# 重みを固定しない')
check_weight(False)

print('# 重みを固定する')
check_weight(True)
# 重みを固定しない
[before] conv1_1.W mean: -0.0024379089009016752, min: -0.6714000701904297, max: 0.6085159182548523
[after] conv1_1.W mean: -0.0025848273653537035, min: -0.670400083065033, max: 0.609515905380249
# 重みを固定する
[before] conv1_1.W mean: -0.0024379089009016752, min: -0.6714000701904297, max: 0.6085159182548523
[after] conv1_1.W mean: -0.0024379089009016752, min: -0.6714000701904297, max: 0.6085159182548523

最初の層であるconv1_1のみ調べましたが、最小値/最大値/平均値が全く同じ値になっているので、正しく重みが固定されていることが分かります。

大きなデータセットでFineTuning

スタンフォード大のオンライン講座によると、大きなデータセットでは、固定しない層(学習する層)を最後の層だけでなく、最後の層に近い順にいくつか学習させる層を増やすと精度が向上すると述べられています。

f:id:xkumiyu:20170928031154p:plain 出典:CS231n: Convolutional Neural Networks for Visual Recognition Lecture 7

学習させる層を増やすために、今回、用いているVGGの層を調べています。

>>> L.VGG16Layers().available_layers
['conv1_1', 'conv1_2', 'pool1', 'conv2_1', 'conv2_2', 'pool2', 'conv3_1', 'conv3_2', 'conv3_3', 'pool3', 'conv4_1', 'conv4_2', 'conv4_3', 'pool4', 'conv5_1', 'conv5_2', 'conv5_3', 'pool5', 'fc6', 'fc7', 'fc8', 'prob']

先程は、fc8のみ学習させる層としたので、今回はfc6, fc7, fc8も学習させてみます。

class VGG(chainer.Chain):
    def __init__(self, class_labels=1000):
        super(VGG, self).__init__()

        with self.init_scope():
            self.base = L.VGG16Layers()
            self.fc6 = L.Linear(512 * 7 * 7, 4096)
            self.fc7 = L.Linear(4096, 4096)
            self.fc8 = L.Linear(4096, class_labels)

    def __call__(self, x):
        h = self.base(x, layers=['pool5'])['pool5']
        h = F.dropout(F.relu(self.fc6(h)))
        h = F.dropout(F.relu(self.fc7(h)))
        return self.fc8(h)

用意されている学習済みモデル

ChainerのLink関数として用意されているモデルは、VGG以外に、GoogLeNet、ResNetがあります。詳細は公式ドキュメントを確認ください。

ただし、いずれのモデルも入力画像のサイズは224 x 224にする必要があります。リサイズや平均値を引くといった前処理を行う関数が用意されているので、利用すると良いでしょう。

import chainer.links as L
image = L.model.vision.vgg.prepare(image, size=(224, 224))

用意されていないモデルは、Caffeモデルから変換することで利用できます。Caffeモデルからの変換についてはまた今度、記事にしようと思います。書きました。

xkumiyu.hatenablog.com