再帰的ニューラルネットワーク(RNN, LSTM)をTensorFlowを用いて実装しました#1

TensorFlowを用いた再帰的ニューラルネットワークのクラスを実装したので晒していきます。数回に分けてクラスの概要、機能、実装の説明をしようと思います。バグや実装の悪い部分などがあればTwitterやコメント欄にメッセージをくださると助かります。

 

#再帰的ニューラルネットワークとは

そもそもですが再帰的ニューラルネットワークとは時系列データの解析などに有用と考えられているニューラルネットワークのモデルの一つです。株価や気象データなどの解析へ応用できると考えられています。

  • 畳み込みニューラルネットワーク→空間的な解析に強い
  • 再帰的ニューラルネットワーク→時間的な解析に強い

僕は個人的にこんな感じで解釈しています。ただ専門ではないのでポエムだと思ってください。

再帰的ニューラルネットワークのクラス

RNNのためのクラス作成を通してやりたかったことは次の3つ

  1. 再帰的ニューラルネットワークのモデルの一つであるLSTMをTensorFlowを用いて実装すること
  2. TensorFlowの仕様の一部であるSaverを用いて、学習済みの判別器を再利用できるようにすること
  3. 判別器を学習させるトレーニングモード、判別器の評価を行う評価モード、未知のデータに対して予測を行う予測モードを実装すること

#構成

このクラスは本体となるRNNクラスとその他2つのクラスから構成されています。

  • RNN Class : 流し込まれたデータに対して学習や予測を担うクラス
  • Loss Class : 使用する損失関数をまとめたクラス
  • Layer Class : 使用する活性化関数をまとめたクラス
  • (学習用データ整形用の関数群)

#機能

  • Loss Class、Layer Classで定義した損失関数、活性化関数を選択可能
  • セッションの保存、再利用(学習の一時停止が可能)
  • 計算グラフを設定するためのパラメータセット(辞書型データ)を読み込むことで手軽にモデルの生成が可能

これ以外にも、モデルを変更する際に変更箇所がなるべく局所化するように構築しています。

#コード

#メンバ変数

  • weights : 隠れ層の重みのリスト
  • biases : 隠れ層のバイアスのリスト
  • input_dim : 入力の次元数
  • output_dim : 出力の次元数
  • hidden_dims : 多層化した際のそれぞれの隠れ層の次元数を格納するリスト、今回の隠れ層は1層のLSTMなので要素数は1
  • actvs : 多層化した際のそれぞれの隠れ層の活性化関数を格納するリスト、今回は1層のLSTMなので使用していない
  • tau : 時系列データの長さ(時間的な長さ)
  • loss : 呼び出したい損失関数の名前
  • _x : 時系列データ
  • _t : 正解データ
  • _y : 外部から読み込んだ計算グラフを格納
  • _keep_prob : ドロップアウトの際に何%のニューロンを保持するか決めるパラメータ、今回は未使用
  • _log : 損失の値や正解率などを格納するディクショナリ
  • _sess : TensorFlowのセッションのインスタンス
  • ckpt_dir : セッションを一時保存しているチェックポイントが格納されているディレクトリ
  • ckpt_path : チェックポイントのパス
  • is_load : 一時保存したセッションを読み込むか(True or False)
  • is_save : 学習結果をを一時保存するか(True or False)

#関数

計算グラフを生成する関数

  • infer : 推論を行う計算グラフを生成する。
  • calc_loss : 損失を計算する
  • train : 損失をもとに計算グラフをトレーニングする
  • define_placeholder : 計算グラフに使用するplaceholder(TF専用の変数)を定義する。

判別器(計算グラフ)をトレーニングする関数

  • fit : 計算グラフを呼び出し、設定したエポック数分学習を行う

学習済みの判別器(計算グラフ)で実際に予測を行う関数

  • preload_model : 学習済みの計算グラフをloadする
  • predict : 学習済みの計算グラフを用いて予測を行う

セッションを一時保存(読み込み)を行う関数

  • save_session : セッションを保存する
  • restore_session : セッションを読み込む

セッションの再読み込み時には、セッションを保存したときに使用していたplaceholderを同じ名前であらかじめ宣言しておく必要がありますが、このクラスはその点を意識せず使用することができます。

 

その他の関数は詳しい機能の紹介がてら少しずつ実装を説明しようと思います。

 

実用例

#呼び出し(main関数)

実際にmain関数での呼び出しをもとに順に追って説明をしていきます。今回は判別器を学習させるためのモードであるtrainモードを例にします。

#計算グラフと学習データの設定

始めにlayer_dic, preproc_dic, candle_dicの3つの辞書型データについて説明します。

  • layer_dic : 層の設定に必要な入出力データの次元数や隠れ層の設定をする
  • preproc_dic : 学習用データに対して、正規化や標準化を行うかどうか設定する
  • candle_dic : 学習用の時系列データに対して、学習させる時間幅(τ)やどれだけ先の未来を予測するか(shift)などの設定を行います。今回は為替データを使って実験をするため、解析対象のローソク足の情報も入っています

#RNNクラスのインスタンスの生成

次にmodelというインスタンス名でRNNクラスを呼び出します。RNNクラスのコンストラクタは2つの引数を取ります

  • load : 保存済みのモデルをロードするか
  • save : 今回の実行結果を保存するか

関係ないですが、モデルの一時保存をする実装は情報量が少なくて大変だったので褒めてほしいです…。

#計算グラフと学習データの設定をインスタンスに反映

先ほどのlayer_dicとcandle_dicの設定を以下のメソッドを使ってニューラルネットワークのモデルに反映させる。

  • init_layer_param : layer_dicで設定したパラメータをインスタンス変数に設定する
  • init_loss_param : 損失関数で使う計算方法を設定する。’mse’は平均二乗誤差。損失関数の実装状況は後程紹介するLoss Classで詳しく
  • define_placeholder : init_layer_paramで設定したインスタンス変数をもとにplaceholderを生成。書いている途中で気づきましたが、呼び出しの順番を変えるとたぶんバグるので注意してください。

ここまでで計算グラフに必要なパラメータの設定は終了です。

#時系列データの読み込み

時系列データを読み込みRNNクラスに流し込める形に整形する関数としてload_candle関数を使用します。この関数の実装は使用するデータにより変わるため次回以降に詳細に解説を行います。

#モード選択

次にモードです。RNNクラスを運用するために4つのモードを作りました

  • train : 判別器をトレーニング+評価するモード
  • eval : 学習済みの判別器を評価するモード
  • construct : ニューラルネットワークの生成だけを行うモード、デバッグ用
  • predict : 学習済みの判別器を用いて未知のデータに対して予測を行うモード

今回はtrainモードをもとに説明していきます。

trainモードでは3つのパラメータを設定します

  • batch_size : バッチサイズ
  • keep_prob : 今回は使いません(dropoutは行いません)
  • epochs : エポック数、ループの回数です

#fit関数・evaluate関数

RNNクラスのメソッドであるfit関数がニューラルネットワークのトレーニングを行います。また、evaluate関数でトレーニングしたモデルの評価を行います。

#その他

NNvisualizer Classは解析結果やログの可視化用のクラスです。今回は詳しい紹介は割愛します

結果(為替データの予測結果)

ちなみに実行結果を可視化したものがこちらになります

赤線が真の値で青線が予測したものです。このモデルはあまり精度が良くないみたいです…。

こちら異なるモデルなのですが、ほぼ完ぺきに予測をしています。

損失の時間変動。振動しながら徐々に収束しています

いわゆるハイパーパラメータの設定により予測精度は変わってきます。この部分は現在進行形でいろいろ検証しています。

#まとめ

ひさびさにこんなに文章を書きました…。

今回はとりあえずクラスを作った報告+クラスの概要をつかんでもらえればいいかな…と思いRNNクラスのざっくりとした説明と実行例をもとにまとめてみました。

次回以降はさらに詳細な実装やTensorFlowの使い方に踏み込んでいき、最終的にはデータの解析までやろうと思っています。

とりあえずの予定としては、LSTMを用いて推論を行うinfer関数の説明、tf.train.Saver(セッションを一時保存する機能)の説明を2,3回に分けてするつもりです。ぜひ楽しみにしていてください。

Leave a Reply

Your email address will not be published. Required fields are marked *

This site uses Akismet to reduce spam. Learn how your comment data is processed.

Scroll to top