retarfiの日記

自然言語処理などの研究やゴルフ、音楽など。

transformersのDataCollatorForWholeWordMaskについての覚書き

以前BERTやELECTRAを日本語で事前学習するリポジトリを作った
(https://github.com/retarfi/language-pretraining)のだが、
その際に参考にしたtransoformersのversionは4.7.2だった。
v4.7.2では、DataCollatorForWholeWordMaskの実装が間違っていたため、自分で書き直していた
(https://github.com/retarfi/language-pretraining/blob/v1.0/utils/data_collator.py#L49)。
今回バージョンアップにあたって再度v4.22.2の実装を見ていると、実用できそうな実装になっていた。 その思考過程をメモとして残しておく。

まずそもそものMaskingの概要として、BERTを例にとって説明する。
BERTは全tokenのうち15%をMaskingの対象とする。
ここで注意すべきはこれらが全て[MASK]トークンに置き換わるわけではないということ。
この15%のトークンのうち、更に80%(全体で15%x80%=12%)が[MASK]トークンに置き換わる。
残りの20%(全体の3%)のトークンのうち、更に10%(全体で1.5%)がRandom Replaceの対象に、
また最後に残った10%(全体の1.5%)がそのままのトークンとして残る(As it is)。
これら3種類のトークンが事前学習におけるMasked Language Modelという、もとの単語を予測するタスクの対象となる。
そして、Whole Word Mask(以下WWM)では、サブワードのトークンではなく、単語ごとにMaskを行うかの判定をする。
https://github.com/google-research/bertと同様に、以下に例示する。
例えば以下のInput Textに対して通常はその下のOriginal Masked Inputを与える。

Input Text: the man jumped up , put his basket on phil ##am ##mon ' s head
Original Masked Input: [MASK] man [MASK] up , put his [MASK] on phil [MASK] ##mon ' s head

ここで、phil / ## am / ##mon という1語に注目すると、真ん中の ##amのみが[MASK]トークンに置換されている。
これだともとのトークンを予測するのが簡単なため、WWMでは phil / ## am / ##mon の全トークンを[MASK]で置換する。

Whole Word Masked Input: the man [MASK] up , put his basket on [MASK] [MASK] [MASK] ' s head

ここで問題となるのは、ReplacedやAs it isなトークンも、1語単位で行うべきかという点である。

まず、以前の自分の実装では[MASK]トークンに置き換わる確率(上記では12%)を計算し、ランダムに1語ずつ積んでいって合計のトークン数が12%になったところまでを[MASK]トークンになるように置換している。
https://github.com/retarfi/language-pretraining/blob/v1.0/utils/data_collator.py#L118
一方、ReplacedやAs it isについては単語ではなくサブワード単位で行っている(つまり1語内でReplacedされるトークンとそうでないトークンが共存しうる)。

次に、transformersのv4.22.2の実装を見てみる。
v4.10頃から、PyTorch, TensorFlow, Numpyと分けて実装されているため見づらいが、ここではPyTorchに絞る。
https://github.com/huggingface/transformers/blob/v4.22.2/src/transformers/data/data_collator.py#L769
すると、masked_indicesで単語ごとにmaskingの対象にしたトークン列(masked_indices)のうちから、ランダムで80%を[MASK]トークンに置換する実装になっている。
これでは、1単語の中で[MASK]トークンに置換されるトークンとそうでないトークンが共存してしまう。
また、ReplacedやAs it isについてはmasked_indicesの残りから選ばれるようになっている。
そのため、[MASK], Replaced, As it isを合わせるとWWMが維持されるが、その中ではばらばらになってしまう。

最後に、本家であるgoogleの実装を見てみる。
https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L369
この実装では、transformersのv4.22.2の実装と同様、まず単語ごとにMaskingの対象となるトークンを抽出する。
そこから(1単語ごとの制約を設けずに)各トークンについて80-10-10の割合で分けている。
そのため、この実装でも1単語内の[MASK], Replaced, As it isがばらばらになることとなる。

以上を踏まえると、transformersの現在の実装はgoogle本家と同じで正しいといえる。
自作のWWMは廃止するか、、、