matplotlibで描いたグラフをファイル保存しつつJupyterでは表示させない方法

Jupyter Notebookで時間のかかる機械学習やデータ解析の処理を走らせているとき、その経過状況をグラフにプロットして観察したい、ということがある。そういうとき、エポックごとに 1 からグラフ描画処理を普通に走らせると Notebook 上にグラフが「追記」されてしまう。これを表示しないようにする方法を今日は調べたので備忘録。

結論は簡単で、描画したグラフをファイル保存した後、matplotlib.pyplot.close() 関数に matploblib.figure.Figure オブジェクトを渡してやれば Notebook 上でそのグラフが表示されるのを止められる。たとえば深層学習(ディープラーニング)の学習中に損失関数の値 (loss) をリアルタイムにプロットしたい場合、次のような感じになる。

for epoch in max_epoch:
    # 学習
    losses, accuracies = train_and_validation()

    # 損失関数の値の推移をプロット
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), dpi=96)
    epochs = [str(e + 1) for e in range(epoch+1)]
    ax1.set_title("loss")
    ax1.set_xlabel("epoch")
    ax1.plot(epochs, losses)
    ax2.set_title("accuracy")
    ax2.set_xlabel("epoch")
    ax2.plot(epochs, accuracies)
    fig.tight_layout()
    fig.savefig("learning_curve.png")
    plt.close(fig)

上記の例では learning_curve.png をエポックごとに上書き出力するので、これを画像ファイルを自動リロードしてくれる nomacs などで開いておくと、リアルタイムに状況を観察できる。

f:id:sgryjp:20190707212147g:plain

目的は状況の観察なので、Jupyter を前提にするなら汎用的に使える Live Loss Plot の方が使いやすいだろうし、TensorFlow ベースの機械学習なら TensorBoard などといったもっと良い代替手段がいくらでもあると思う。ただ、まあ、こういう単純な方法はつぶしが効く。一つ、こういう手段を覚えておくと「なんか TensorBoard が動かなくなって困った」などといった時に役に立つ…かもしれない。立たないかもしれないけれど。