Xgboost の train のラウンドごとに 関数を呼び出すための callbacks の利用法を備忘録として残しておく。
目次
Callbackの利用例
ラウンドごとに日時を表示する例を作成した。
Gist https://gist.github.com/kunsen-an/f09966dc2ab7f3975f983091c3800fb5 にテストコードを含む notebook を置いた。実行結果例もそちらで確認できる。
callback を戻り値とする関数の定義
ここでは、return_callback()という関数を定義し、その中で定義したprint_time関数を返している。
train の evals で評価すべきデータセットを指定していて、評価結果がある場合には、ラウンド回数(iteration)、日時(dt_string)とともにその結果(evaluation_result_list)も出力する。
評価結果がない場合には、ラウンド回数(iteration)と日時(dt_string)を表示する。
# 日時を表示する関数 print_timeを戻り値とする return_callbackを定義する
import datetime
def return_callback():
def print_time(env):
now = datetime.datetime.now()
dt_string = now.strftime("%Y/%m/%d %H:%M:%S")
i = env.iteration
if env.rank != 0 or len(env.evaluation_result_list) == 0:
print(i,dt_string)
return
msg = '\t'.join([str(x) for x in env.evaluation_result_list])
print(i,dt_string, msg)
return print_time
ラウンドごとに呼び出されるコールバックのリスト設定
print_time()の戻り値であるcallbackを、ラウンドごとに呼び出される callbackのリスト(ここでは、callbacks)に設定する。
# 日時を表示する関数をリスト callbacks にセットする
callbacks=[
return_callback()
]
trainの引数
xgboost.train の引数 callbacksに、上に示したリスト callbacksを設定して、呼び出す。
booster = xgboost.train(train_params, train_dmat,
callbacks=callbacks,
# evals=evals, ... の行をコメントアウトすれば、evalsがなく、
# 検証用データの評価値はcallback関数で表示されなくなる(そもそも値がない)。
evals=evals, evals_result=evals_result,
# verbose_evalをTrueにすると evalsで指定された検証用データの評価値がラウンドごとに表示されるが
# callback の出力との区別を明確にするために Falseにしている
verbose_eval=False
)
参考
次の記事が参考になった。
また、上に示した例では直接使っていないが、Callback APIのドキュメントも当然ながら参考になる。
https://xgboost.readthedocs.io/en/latest/python/python_api.html#callback-api