kumilog.net

データ分析やプログラミングの話などを書いています。

Conditional DCGANで画像生成

GANの一種であるDCGANとConditional GANを使って画像を生成してみます。

GANは、Generative Adversarial Networks(敵性的生成ネットワーク)の略で、Generator(生成器)とDiscriminator(判別器)の2つネットワークの学習によって、ノイズから画像を生成するアルゴリズムです。

生成器Gは、判別器Dに本物と誤認識させるような画像を生成し、判別器Dは、本物か偽物かを見分ける役割があります。

f:id:xkumiyu:20170909021901p:plain
GAN

GANの仕組みについては、こちらの記事がとても参考になります。

また、GANは2014年にIan Goodfellow氏に提案してから多数の関連論文が発表されています。

f:id:xkumiyu:20170909115252j:plain 出典:GitHub - hindupuravinash/the-gan-zoo: A list of all named GANs!

DCGAN

GANの学習は難しいのですが、それを改善したのがDeep Convolutional GAN(DCGAN)です。DCGANでは、GANの2つのネットワークにCNN(Convolutional Neural Network)を用いることにより、学習の精度が上がり高解像度の画像が生成できるようになりました。

f:id:xkumiyu:20170909121521p:plain 出典:[1511.06434] Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

Conditional GAN

GANは、ランダムノイズから画像を生成するため、生成される画像を制御することができません。そこで、ラベルyを付与することにより条件付きモデルに拡張したのが、Conditional GAN(cGAN)です。最近、流行っているpix2pixなどもcGANのアイディアをベースにしています。

f:id:xkumiyu:20170909021908p:plain
Conditional GAN

Chainerによる実装

ChainerのExampleにあるDCGANをベースに、cGANのアイディアを組み込んでみました。実装についてはこちらの記事を参考にしました。

github.com

Generator

通常のGANのGeneratorの入力はn次元のノイズです。cGANではこれにラベルを加えるので、ノイズとラベルのベクトルを結合します。ノイズが100次元、ラベルが10次元であれば、110次元のベクトルをcGANのGeneratorの入力とします。

また、ラベルは、one-hot表現に変換して用います。例えば、0~9までの10個のラベルの場合、3というラベルは[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]のような配列に変換され、10次元のベクトルとなります。

Discriminator

通常のGANのDiscriminatorの入力は、width x heigh次元の画像です。同じようにラベルを結合できないので、ラベルを同じwidth x heigh次元の画像を変換します。変換はラベルの数だけ画像を作り、該当するラベルの画像のみ値を1とし、それ以外の画像の値を0とします。

例えば、0~9までの10個のラベルの場合、3のというラベルは、10枚の画像のうち4枚目の画像の値が1となります。正規化しているので画像の値は[0, 1]ですが、4枚目の画像が真っ白で、それ以外の9枚の画像が真っ黒になるイメージです。

作成されたラベル画像と、元々の入力画像を結合させます。入力画像が32 x 32のカラー画像(3ch)の場合、最終的な入力の次元(shape)は(3 + 10, 32, 32)となります。

MNISTデータの結果

60,000枚の手書き文字のデータ(MNIST)対してミニバッチサイズ100で20エポックまわしました。1GPUで2時間くらいでした。

$ python train.py -g 0 -e 20 -b 100

f:id:xkumiyu:20170910182802p:plain
MNIST 20epoch

今回は、(汎用的なコードを目指し)3チャンネルを使うものになっているので、若干緑がかっている気がします。

また、学習したモデルを利用して任意のラベルの画像を生成することができます。Githubには学習済みモデルもあげてあるので、試してみてください。CPUでも動きます。

$ python generate.py -m model/mnist-gen.npz -l 3

f:id:xkumiyu:20170910183219p:plain

ちなみに、GeneratorとDiscriminatorの誤差(loss)は以下のようになりました。

f:id:xkumiyu:20170910182805p:plain
loss

Discriminatorのlossは下がっていますが、Generatorのlossとむしろ上がっています。以下の記事のよると、うまくいっているときはDiscriminatorのlossが小さな分散で下がっているらしいので、うまくいっていると思われます。

qiita.com

※ 2017/12/12 追記 GeneratorとDiscriminatorのlossは均衡する方が良いらしいです。