神野さんに言われました。

神野さんに言われました。

AIの勉強をしています @sesenosannko

TensorflowのBatch Normalizationについてのメモ

TensorflowのBatch Normalizationでハマったのでメモしておきます。

あるモデルで訓練時はうまく学習されているのにテスト時にはひどい出力を出されれるということがありました。重み共有とかいろいろとしていたので原因を探すのに時間がかかったのですが、問題はBatch Normalizationでした。

TensorflowでBatch Normalizationを使用する方法は複数ありますが、もっとも簡単な方法はtf.layers.batch_normalizationを使う方法らしいです。

https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

is_trainingフラグによって訓練時(Trueのとき)には平均と分散の計算を行うのに対して、テスト時(Falseのとき)には学習された値を使うという制御を行っています。これがFalseになっているときに正しく動いていなかったようです。

原因は単純で、Batch Normalizationの変数が学習されていなかったようです。上のページに書いてある通り、Batch Normalizationの更新は明示的に指定しないといけないことになっています。面倒です。

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

僕はここに書いてある方法をそのまま書いていたのですが、なぜか学習されていませんでした。今も原因は分かっていません・・・

結局いろいろと調べている中で下の記事にある通りupdates_collections=Noneとしたら正しく動きました。このように指定すると上のような書き方をしなくてもその場で更新されるようになるらしいです。学習が遅くなると書いてある記事がありましたが、とりあえず僕のモデルは動いたので良いことにします。

結局原因は分かっていないのですが、Batch Normalizationは良く使うのでメモとして残しておきます。