Retentive Networkについてのメモ
Microsoft Researchから、Retentive Network が提案された(2023年7月)。 興味深いので、その意義についてざっくり理解したことを、つらつらとメモしておく。
(諸々正確でない点は、ご容赦いただきたい。)
[1] Yutao Sun et al., Retentive Network: A Successor to Transformer for Large Language Models, 2023.
[2307.08621] Retentive Network: A Successor to Transformer for Large Language Models
Retention
Retentive Networkの中核を担うのが、Retention機構である。
初め、「Retentive=[訳]保持力のある」って何かと思ったが、 Attention の後継を意識しての Retention だと思われる。
現在、LLM (大規模言語モデル)の技術としては、Transformer (Attention) が席巻している。
後述するが、学習を並列的に行えたことも1つの要因となって、高い精度を達成している。
一方、文の長さ N に対して、計算量が大きく、長文を扱いにくい面がある。 この点、実は、ひと世代前のRNNの方が優れている点である。
Retentionの肝は、同じ単純な再帰式からの式変形によって、 RNN的にもAttention的にも表現(計算)できる機構を作ったことにある。
導出は論文の式(1)から(7)を参照。
Figure 3を引用する。
以下、Nは文の長さとし、便宜上、トークン (d次元)を "単語"と呼ぶことにする。
(a) 並列表現 (=Attention的な表現)
- Q, K, Vで構成されていて、Transformerに明るい人は、Self-Attentionでよく見る構図だと思う。
- Q, K, Vはサイズ N×d であり、入力の単語列 X = (X_1 ... X_N) から作る。
Attentionと同様、Q, K, Vは単語単位で並列計算ができる。 - D は、Transformerにおいて先読みを防ぐためのmaskと、再帰の時に掛ける定数 γ を兼ねたような行列である。
- 図では表現されていないが、Q_nとK_nには $ e ^ {inθ} $ を乗算しておく。
- これは、何単語目であるか、という、再帰の代わりの情報(の一部)に相当する。
- これが、結果的に、TransformerでいうところのPosition Encodingの派生形(xPosというらしい)に類した形になっているらしく、面白いところ。
(b) 再帰表現 (=RNN的な表現)
- Q_n, K_n, V_n は、(a)における Q, K, V の n行目に対応する。
- 入力 X_n と中間状態 S_{n-1} さえあれば、n 番目より前の入力単語を全く使わずに、出力 O を計算できる。
(a)と(b)でQ, K, Vを共有できているので、学習時には (a) で並列的に学習、推論時には (b) で再帰的に推論、 と、文字通り良いとこ取りができる。
論文では、更に、(a)と(b)のハイブリッド型((c) Chunkwise表現)を提案し、学習にはこれを用いている。
例えば、長さN=4096の文を、チャンク長 C=512 で処理する場合。
文をチャンクに分け、各チャンクは並列計算した上で、得られる中間状態を次のチャンクに再帰的に渡すようにする。
こうすると、メモリ計算量も大きく削減できる(C2 << N2)。
ここでは説明しないが、以上から更に手を加えて、最終的な提案法としている。
- Transformerで使われる、Multi-head化。
- head毎に異なるγの値(定数)を与えることによる、マルチスケール化対応。
- 正規化による計算の安定化。(QKT → QKT/√d )。など。
学習結果のグラフなどは論文を見て頂ければと思うが、計算量O(1)の威力はすごい。
RNNについての補足
RNNは、推論効率、メモリ効率が高い。これの意味することを説明しておく。 例として、文書生成のようなタスクで、「I read the book.」を生成することを 考える。
初期状態 s0 を持っておいて、
① model(s0, <s>) → (s1, "I") ② model(s1, "I") → (s2, "read") ③ model(s2, "read") → (s3, "the") ④ model(s3, "the") → (s4, "book") ⑤ model(s4, "book") → (s5, ".")
ここで、<s>
は、最初であることを意味する特殊なトークン(ベクトル)。
これは、メモリ効率的には非常に優れている。 入力と s_{n-1} の2つのベクトルさえあれば、次の出力が計算できる。 これは、文の長さに依存しない。
一方、再帰計算であるからには、 上記を並列的に計算するのは困難である。 ⑤の入力に必要な s4 は、仮に全て正解を持っていたとしても、 ①、②、③、④と、順番に計算しない訳にはいかない。
Transformerについての補足
先ほどのRNNの説明の裏返しが、Transformerである。
Transformerは、中間状態を持たずに、現在の単語までの情報を 辞書的に拾ってくる、という戦略を取る。
先ほどの書き方に倣うと、次のようになる。
① model(<s>, -, -, -, -) → ("I", -, -, -, -) ② model(<s>, "I", -, -, -) → ("I", "read", -, -, -) ③ model(<s>, "I", "read", -, -) → ("I", "read", "the", -, -) ④ model(<s>, "I", "read", "the", -) → ("I", "read", "the", "book", -) ⑤ model(<s>, "I", "read", "the", "book") → ("I", "read", "the", "book", ".")
「中間状態を入力する必要が無い」ということがポイントで、 正解が分かっている状態(つまり、学習時)であれば、 ①~⑤をそのまま並列に処理できる。
ただし、この並列性は、推論時は役立たないので、1単語ずつ、入力を増やしながら 再帰的に推論していくしかない。
更に悪いことに、n 番目の単語を予測するときに n - 1 の辞書引きをするために、 入力の長さの2乗に比例するテーブルを作る必要がある。 特に、文の長さ n が増大すると、GPUメモリ不足も深刻になってくる訳である。
docopt は、docopt-ng として生き続けている、という話
Pythonでコマンドラインパーサと言ったら何を思い浮かべるだろう。
標準ライブラリ argparse の他、clickやfire などいろいろ選択肢があるが、私は docopt をずっと気に入って使っている。
で、docoptの(元祖)サイトを見に行くと、更新が止まっていてどうしたものか…と思ってしまうが、心配無用。
結論としては、docopt
は、docopt-ng
に引き継がれて
メンテナンスが続いているよ、という話。
日本語で言及している人が少なそうなので、記事にしておく。
docoptの概要
ごく簡単に、docoptを紹介しておく。
docoptは、「規則に従って書いたヘルプメッセージから、 コマンドラインパーサを作る」という、通常とは逆転の発想を 持ったライブラリである。
具体的な書き方の例を示す。pythonスクリプトの冒頭のコメントに、Usageと、オプションの説明を書けば良い。
### 'Usage:'から空行まで、Usageパターンを認識 ### """ Usage: my_program.py <file> [options] ### 空白除いて '-' から始まる行は、オプションとして認識 ### # 空白1個区切りで短いオプションと長いオプション。 # 空白2個以上の後に、説明を記載。 # [default: ]によってデフォルト値指定。 Options: -p --port=<n> 何らかの番号 -o <dir> 出力ディレクトリの指定 [default: ./output] --quiet 静かに実行 """ from docopt import docopt args = docopt(__doc__) # print(args)
このように書いておけば、あとは、args["<file>"]
として文字列を拾ったり、
args["--quiet"]
としてboolを拾ったりすることができる。
インストールは、pipで。pip install docopt-ng
とする。
その他、詳細はいろいろあるので、GitHubページを参照のこと。
なお、docopt自体はこれ以上の機能を持たないシンプルな設計なので、
型変換や範囲チェックなどのバリデーションは、(必要ならschema
などを用いて)別途行う。
docoptからdocopt-ngに至る経緯
私の記憶と断片的な情報をもとに書いているので、全く以って正確でないとは思うが、記しておく。
オリジナルのdocoptは、PyCon UK 2012で出現して、注目を浴びた。 その発想から、他の様々な言語にも移植された。
そんなdocoptだが、(少なくとも)元の Python実装は、GitHub2018年でメンテが止まっていて久しい。
一方、それより前か後か、docoptからForkされた改良版の1つ、 docopt-ng が存在しており、それが jazzband という、Python系プロジェクトの維持を目的とした (メタ) プロジェクトのメンテナンス下に収まったようである。
ちなみに、メンテナンス版(docopt-ng
)の方が docopt
の名前を踏襲できないか、という issue に挙がったりしているけれど、先例 (PIL→pillowとか) に見られるように、名前を変えるのはそう簡単ではない模様。
というわけで
ともあれ、docoptの民は、今後も、(docopt-ngで)生きていける。
実装後に改めてコメントを書く労力なしに、 ファイル冒頭説明に使い方が書かれていると、 ソースコードの格好がついて整っている感じで、 非常に気持ちが良い。
複雑なオプション解析が必要な場合とか、 高度なことをするには不足かもしれないが、 手軽に使う分には、最良の選択肢の一つだと思っている。
100 / (100+α) みたいな割り算の手計算を、近似でラクする
100や1000などのキリの良い数を、「それプラスちょっと」の数で割る、という手計算を迫られる場面が、 少なからず(少なからず?)ある。
ぱっと浮かぶ例としては、税込金額における税別金額の比率を求める、とか。
100 / 108 = ? (※食品テイクアウトは、いまも、8 %のはず。)
電卓なければ筆算なりをする訳だが、
108という桁多き数に9を掛けて、引き算して…と、初手から気が進まない。
しかも、分母分子が逆なら一瞬て答えが出るだけに、なんか癪に思う。
で、それなら割り算しないで近似すれば良いじゃん、という 気づきがあったので、ここに紹介する。
計算方法
キリのいい数字を$p$、加える"ちょっと"の数を$α$ とし、元の分数を $\frac{p}{p + α}$ で表す。
ここで、$x = \frac{α}{p}$ を計算する。
これはすなわち、$\frac{p}{p + α} = \frac{1}{1 + x}$を満たすxを求めていることに他ならない。# 先ほどの例(p=100、α=8)。 100 / 108 = 100 / (100 + 8) = 1 / (1 + 0.08) x = 0.08
近似式 $\frac{1}{1 + x} ≃ 1 - x + x ^ 2$ に基づいて、右辺を計算する。以上。
1 - x + x^2 = 1 - 0.08 + 0.08^2 = 1.0000 - 0.0800 + 0.0064 = 0.9264 よって、1/(1 + x) ≒ 0.9264
電卓で求めると 100 / 108 = 0.925925925... となって、誤差は 0.00048 ぐらいに収まる。
理屈
$f(x) = \frac{1}{1 + x}$のマクローリン展開による。
$\frac{1}{1 + x} = 1 - x + x ^ 2 - x ^ 3 + ...$
xのべき乗が並び、符号が交互に変わることに注意。 2乗の項までで近似したのが、先ほどの近似式である。
誤差$E$を計算すると、
$E = \left| \frac{1}{1 + x} - (1 - x + x ^ 2) \right| $
$\quad = \left| \frac{1}{1 + x} \{ 1 - (1 - x ^ 3) \} \right| = \left| \frac{x ^ 3}{1 + x} \right| $
となって、おおよそxの3乗のレベルの誤差となることが分かる。 $1 \gg x$なら、十分近似になる。
もちろん、精度を求めるなら、更に$x ^ 3$乗を引けば良い。αが1桁なら暗算でも何とかなるはず。
発展
p / (p - α) のパターンにも応用できる(例:1000/997)。
符号を変えた場合($x → -x$)でも当然成り立つから、
$\frac{p}{p - α} = \frac{1}{1 - x} ≃ 1 + x + x ^ 2$
を使えば良い。
また、α / (p + α) のパターン(例:8 / 108)であれば、 1から引く形に式変形すれば良い。
8 / 108 = 1 - 100/108 ≃ 1 - 0.9264 = 0.0736 (近似)
8 / 108 = 0.0740740... (真値)
tf.Kerasの事前学習済みモデルに、正則化を加える正しい方法
TensorFlow(2系)の、 tf.keras.applicationsの学習済みモデルに、L2等の正規化を加える方法。 次の記事[1]で解決した。簡単にその要約を記す。
[1] Silva TS. How to Add Regularization to Keras Pre-trained Models the Right Way. 2019.
- Step 1. 学習済みモデル
model
が持つのlayerを走査する。 もしlayerがkernel_regularizer
プロパティを持っていれば、そこにregularizerオブジェクトを代入する。
(ここで加えた変更は、まだ、モデルそのものには反映されていない。) - Step 2. 一旦、config情報をjsonに書き出す。
- Step 3. 重みについても、一時ファイルに書き出す。
- Step 4. Step 3で書き出したconfig情報(json)を読み戻す。
(regularizerは反映されるが、重みがリセットされてしまう。) - Step 5. Step 4で書き出した重みを、一時ファイルから読み戻す。
(https://gist.github.com/sthalles/d4e0c4691dc2be2497ba1cdbfe3bc2eb から引用)
import os import tempfile def add_regularization(model, regularizer=tf.keras.regularizers.l2(0.0001)): if not isinstance(regularizer, tf.keras.regularizers.Regularizer): print("Regularizer must be a subclass of tf.keras.regularizers.Regularizer") return model for layer in model.layers: for attr in ['kernel_regularizer']: if hasattr(layer, attr): setattr(layer, attr, regularizer) # When we change the layers attributes, the change only happens in the model config file model_json = model.to_json() # Save the weights before reloading the model. tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5') model.save_weights(tmp_weights_path) # load the model from the config model = tf.keras.models.model_from_json(model_json) # Reload the model weights model.load_weights(tmp_weights_path, by_name=True) return model
動作確認として、model.losses
にアクセスすれば、正則化の値をリストで取得することができる。
正則化を正しく反映できていれば、リストは空にならないはずである。
なお、プロパティとして、kernel_regularizer
だけでなく
bias_regularizer
とactivity_regularizer
も加えることが可能である。
Transformerを理解するまでに私が陥った、3つの勘違い
自然言語処理でお馴染み、他の分野も席巻しつつあるTransformerについて、 やっと自分の理解が追いついてきた。
Transformerとは何か、については、良い記事がたくさんあるのでそちらを参照されたい。
念のため、最小限の説明をすると、次のような感じ。
- 入力列から、辞書的に情報を拾い上げて出力列に伝える仕組み、もしくは、それを使ったモデル構造のこと。
- Attentionと呼ばれる機構が、「辞書的に情報を拾い上げ」る役割を実現。
以前に調べたときは、上辺だけをさらっと知ることを優先していたために、 結構勘違いが多かったことに気づかされた。そのことについて、ここにメモしておく。
(#上辺じゃなくて最初からちゃんと理解せえ、という自分への戒めも込めて。)
(2021.7.16 5月頃にメモとして書いていたものを、タイトル含め書き直して再投稿)
Transformerモデルについて
1. 「Transformer」という語が指す範囲は、文脈による
特に、TransformerとBERTをいっぺんに理解しようとしてはまった点。
私の理解の上では、(Original)Transformerの1層とBERTの1層が、 図のように対応する。
Attention Is All You Needの論文の解説記事だと、そもそも左の図全体を (Originalの) Transformerと呼ぶことが多い。一方で、その応用技術では、図の1層分をTransformerと捉えている印象がある。
なお、上の左右の図は、そもそも見比べるべきではないことに注意が必要である。
右図は、先行研究のELMo等の概念と比べるために書かれた図であって、 BERTの図が間違っている、とか、そういう批判をしたい訳ではないことを補足しておく。
2. 「BERT」は「(単なる)双方向Transformer」ではない
特に双方向RNN等を知っている上で、"Bidirectional Transformer"と言われると、 何か計算上の改良が加わっているのではないか、と勘違いが働きやすい(私だけ?)がそうではない。
では、BERTの何が双方向性を持つのか、というと、「学習対象のタスクが」である。
- 従来よく行われていた、「文章の n-1 語めまでを入力して、n 語めを推測する」タスクでは、前向き方向にしか推定が行えない (*1)。
- 対して、BERTで設定されたタスク(2つのうち)の1つは、「文章の n 語めを隠して入力し、隠された n 語めを推測する」というタスクである。これなら、前向きにも後ろ向きにも単語情報を活用する、という恩恵が得られる。
つまり、双方向の予測が働くように設定したタスクによって、Transformerが双方向性の予測にも強くなった、ということである。
(*1) 補足:
- 文をひっくり返して、後ろ向きも予測すればいいという考え方であれば、双方向RNNで導入されており、これは前身のELMoで採用されていた方法である。
3. Decoderは、再帰せずに文書生成を行える訳ではない
「TransformerはRNNでないから、並列処理できる」的な説明だけを見て これも誤認してしまった。これは「学習時」にのみ、当てはまる。
例として、原文「彼/は/あの/通り/沿い/に/住んでいる」とその訳文「He/lives/along/that/street」を考える。 この場合、Transformer翻訳モデルの入力と出力の関係は、次のようになる。
(Encoderへの入力/Decoderへの入力 → /Decoderに期待される出力)
- [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>] → [He]
- [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He] → [He][lives]
- [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He][lives] → [He][lives][along]
- [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He][lives][along] → [He][lives][along][that]
- [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He][lives][along][that] → [He][lives][along][that][street]
- [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He][lives][along][that][street] → [He][lives][along][that][street][<終>]
学習時は、途中までの予測が完璧だと仮定して、次の1単語の予測を行う。 RNNみたいに内部状態が要るわけではないので、6つ分をいっぺんに並列計算できる。
一方、テスト時は、当然「途中までの予測」なんてものは手元にないので、 6つ分をいっぺんに、ではなく、RNNのDecoderと同じようなイメージで、 1つずつ単語を増やして予測していく。
参考文献
- [1] Vaswani A, Shazeer N et al. Attention is all you need. NIPS, 2017.
- [2] Devlin J, Chang MW et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. 2018.
- [3] 作って理解する Transformer / Attention - Qiita
- [4] 深層学習界の大前提Transformerの論文解説! - Qiita
- [5] 【世界一分かりやすい解説】イラストでみるTransformer|Beginaid