Jupyter Notebookで時間のかかる機械学習やデータ解析の処理を走らせているとき、その経過状況をグラフにプロットして観察したい、ということがある。そういうとき、エポックごとに 1 からグラフ描画処理を普通に走らせると Notebook 上にグラフが「追記」されてしまう。これを表示しないようにする方法を今日は調べたので備忘録。
結論は簡単で、描画したグラフをファイル保存した後、matplotlib.pyplot.close()
関数に matploblib.figure.Figure
オブジェクトを渡してやれば Notebook 上でそのグラフが表示されるのを止められる。たとえば深層学習(ディープラーニング)の学習中に損失関数の値 (loss) をリアルタイムにプロットしたい場合、次のような感じになる。
for epoch in max_epoch: # 学習 losses, accuracies = train_and_validation() # 損失関数の値の推移をプロット fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), dpi=96) epochs = [str(e + 1) for e in range(epoch+1)] ax1.set_title("loss") ax1.set_xlabel("epoch") ax1.plot(epochs, losses) ax2.set_title("accuracy") ax2.set_xlabel("epoch") ax2.plot(epochs, accuracies) fig.tight_layout() fig.savefig("learning_curve.png") plt.close(fig)
上記の例では learning_curve.png
をエポックごとに上書き出力するので、これを画像ファイルを自動リロードしてくれる nomacs などで開いておくと、リアルタイムに状況を観察できる。
目的は状況の観察なので、Jupyter を前提にするなら汎用的に使える Live Loss Plot の方が使いやすいだろうし、TensorFlow ベースの機械学習なら TensorBoard などといったもっと良い代替手段がいくらでもあると思う。ただ、まあ、こういう単純な方法はつぶしが効く。一つ、こういう手段を覚えておくと「なんか TensorBoard が動かなくなって困った」などといった時に役に立つ…かもしれない。立たないかもしれないけれど。