Xgboost の train のラウンドごとに 関数を呼び出すための callbacks の利用法を備忘録として残しておく。
Callbackの利用例
ラウンドごとに日時を表示する例を作成した。
Gist https://gist.github.com/kunsen-an/f09966dc2ab7f3975f983091c3800fb5 にテストコードを含む notebook を置いた。実行結果例もそちらで確認できる。
callback を戻り値とする関数の定義
ここでは、return_callback()という関数を定義し、その中で定義したprint_time関数を返している。
train の evals で評価すべきデータセットを指定していて、評価結果がある場合には、ラウンド回数(iteration)、日時(dt_string)とともにその結果(evaluation_result_list)も出力する。
評価結果がない場合には、ラウンド回数(iteration)と日時(dt_string)を表示する。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
# 日時を表示する関数 print_timeを戻り値とする return_callbackを定義する import datetime def return_callback(): def print_time(env): now = datetime.datetime.now() dt_string = now.strftime("%Y/%m/%d %H:%M:%S") i = env.iteration if env.rank != 0 or len(env.evaluation_result_list) == 0: print(i,dt_string) return msg = '\t'.join([str(x) for x in env.evaluation_result_list]) print(i,dt_string, msg) return print_time |
ラウンドごとに呼び出されるコールバックのリスト設定
print_time()の戻り値であるcallbackを、ラウンドごとに呼び出される callbackのリスト(ここでは、callbacks)に設定する。
1 2 3 4 |
# 日時を表示する関数をリスト callbacks にセットする callbacks=[ return_callback() ] |
trainの引数
xgboost.train の引数 callbacksに、上に示したリスト callbacksを設定して、呼び出す。
1 2 3 4 5 6 7 8 9 |
booster = xgboost.train(train_params, train_dmat, callbacks=callbacks, # evals=evals, ... の行をコメントアウトすれば、evalsがなく、 # 検証用データの評価値はcallback関数で表示されなくなる(そもそも値がない)。 evals=evals, evals_result=evals_result, # verbose_evalをTrueにすると evalsで指定された検証用データの評価値がラウンドごとに表示されるが # callback の出力との区別を明確にするために Falseにしている verbose_eval=False ) |
参考
次の記事が参考になった。
また、上に示した例では直接使っていないが、Callback APIのドキュメントも当然ながら参考になる。
https://xgboost.readthedocs.io/en/latest/python/python_api.html#callback-api