TensorFlowのtf.train.Saverを使うときに気を付けたこと

TensorFlowには学習の状況を保存するSaverという機能があります。この機能を使うことによって学習の一時停止を行ったり、学習済みの判別器を再利用したりできます。使い方をよく忘れるのでまとめておきます。

1. セッションを保存する

仮想的なコードを使用して説明を行います。大きな流れは次の通りです

  1. placeholderを生成する(placeholder : TensorFlow用の変数)
  2. セッションを生成する(session : TensorFlowの演算を実行する)
  3. 計算グラフを作成する
    • infer : 推論を行う計算グラフを生成
    • calc_loss : 推論結果と正解データの誤差を計算
    • train : 誤差をもとに最適化を行う
  4. 学習用データを読み込む
  5. epochの回数だけ学習を行う
  6. 学習後のセッションをckpt_pathのディレクトリに保存する

2.保存したセッションを呼び出す

保存したセッションの呼び出しを行うときは保存した時と同じplacaholderと計算グラフを定義しておく必要があります。

3.ダメなパターン

保存時に定義していたplaceholderと計算グラフが宣言されていない

保存したセッションの呼び出しといっても、呼び出せるのはあくまでセッションだけで、セッションを構成するplaceholderや計算グラフは宣言しておく必要があります。

4.使用例

保存したセッションを再利用するかどうかで条件分岐をさせて、再利用するときはセッションの読み込み/再利用しない場合は変数の初期化を行ったりします。

まとめ

tf.train.saverを初めて使ったときのことを思い出しながら書きました。当時、このsave/restoreの仕様を勘違いしたままコードを組んでいて、よくわからんエラーに3日くらい悩んでました。

よく使う機能の割には日本語の情報が少ないので誰かの手助けになれば幸いです。

Leave a Reply

Your email address will not be published. Required fields are marked *

This site uses Akismet to reduce spam. Learn how your comment data is processed.

Scroll to top