retarfiの日記

自然言語処理などの研究やゴルフ、音楽など。

HuggingFaceのtransformers.trainerをDeepSpeedと一緒に使うときの注意覚書

事前学習関連で色々試していたらHuggingFaceのtransformersとDeepSpeedのIntegrationでうまくいかないところがあった。 具体的には、transformers.TrainerとDeepSpeedを同時に使っていて、さらにgraidient_accumulation_stepsが1でない場合に、transformers.TrainerとDeepSpeedのそれぞれのglobal step数がずれてしまうというものである。

transformersのv4.24.0ではL1767以降でoptimizerの更新を行う。

https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/trainer.py#L1767-L1771

しかし、DeepSpeedを使っている場合にはここでは更新が行われない(L1801,1802)。 ではどこで更新しているかというと、L1767-L1771のif文の外側(L1764,1765)である。 このとき、epochの真ん中で回っている分には問題はないのだが、transformersではepochの終わりに必ずoptimizerを更新し、global stepを進めてしまう(L1770)。 しかし、DeepSpeedにはepochなんてものはわからない(各ステップ与えられたデータを淡々と処理する)ので、epochの最後だろうが(step + 1) % args.gradient_accumulation_steps != 0であればoptimizerは更新されない。

簡単な例として、gradient_accumulation_stepsが2で1 epochでは3回forwardがされるとする。 HuggingFaceのみ(正常)の場合は、以下のように動く。 count-global_steps(HF)-epoch-steps(HF) 0-0-0-0 1-0-0-1(optimizer.step()) 2-1-0-2(optimizer.step()) 3-2-1-0 4-2-1-1(optimizer.step()) 5-3-1-2

しかし、DeepSpeedを使うと2列目と5列目がずれていることがわかる。 count-global_steps(HF)-epoch-steps(HF)-global_steps(DS)-steps(DS) 0-0-0-0-0-0 1-0-0-1-0-1(optimizer.step()) 2-1-0-2-1-0(epochの最後なのでHuggingFaceではglobal_stepsが1増えるが、DeepSpeedではoptimizer.step()は行われない) 3-2-1-0-1-1(optimizer.step()) 4-3-1-1-2-0 5-3-1-2-2-1(optimizer.step())

一応これはTrainingArgumentsのdataloader_drop_lastをTrueにすることで解決できると思われる。

当該箇所を解決するには、DeepSpeedを利用しない場合についてのif文を追加してepochの最後に関する条件を

if (not self.deepspeed and ((step + 1) % args.gradient_accumulation_steps == 0 or (
    # last step in epoch but step is always smaller than gradient_accumulation_steps
        steps_in_epoch <= args.gradient_accumulation_steps
        and (step + 1) == steps_in_epoch
    ))) or (self.deepspeed and self.deepspeed.tput_timer.total_step_count % args.gradient_accumulation_steps == 0):

などとすると良いと思う。 多分、gradient_accumulation_stepsの境目かの判定でself.deepspeed.tput_timer.total_step_countよりもTrainer側のself.state.global_stepを使ったほうが良いかも。 この辺ちゃんと検証したら、transformersにプルリク出してみたいとは思う(OSSプルリクしたことないが)。