Retentive Networkについてのメモ
Microsoft Researchから、Retentive Network が提案された(2023年7月)。 興味深いので、その意義についてざっくり理解したことを、つらつらとメモしておく。
(諸々正確でない点は、ご容赦いただきたい。)
[1] Yutao Sun et al., Retentive Network: A Successor to Transformer for Large Language Models, 2023.
[2307.08621] Retentive Network: A Successor to Transformer for Large Language Models
Retention
Retentive Networkの中核を担うのが、Retention機構である。
初め、「Retentive=[訳]保持力のある」って何かと思ったが、 Attention の後継を意識しての Retention だと思われる。
現在、LLM (大規模言語モデル)の技術としては、Transformer (Attention) が席巻している。
後述するが、学習を並列的に行えたことも1つの要因となって、高い精度を達成している。
一方、文の長さ N に対して、計算量が大きく、長文を扱いにくい面がある。 この点、実は、ひと世代前のRNNの方が優れている点である。
Retentionの肝は、同じ単純な再帰式からの式変形によって、 RNN的にもAttention的にも表現(計算)できる機構を作ったことにある。
導出は論文の式(1)から(7)を参照。
Figure 3を引用する。
以下、Nは文の長さとし、便宜上、トークン (d次元)を "単語"と呼ぶことにする。
(a) 並列表現 (=Attention的な表現)
- Q, K, Vで構成されていて、Transformerに明るい人は、Self-Attentionでよく見る構図だと思う。
- Q, K, Vはサイズ N×d であり、入力の単語列 X = (X_1 ... X_N) から作る。
Attentionと同様、Q, K, Vは単語単位で並列計算ができる。 - D は、Transformerにおいて先読みを防ぐためのmaskと、再帰の時に掛ける定数 γ を兼ねたような行列である。
- 図では表現されていないが、Q_nとK_nには $ e ^ {inθ} $ を乗算しておく。
- これは、何単語目であるか、という、再帰の代わりの情報(の一部)に相当する。
- これが、結果的に、TransformerでいうところのPosition Encodingの派生形(xPosというらしい)に類した形になっているらしく、面白いところ。
(b) 再帰表現 (=RNN的な表現)
- Q_n, K_n, V_n は、(a)における Q, K, V の n行目に対応する。
- 入力 X_n と中間状態 S_{n-1} さえあれば、n 番目より前の入力単語を全く使わずに、出力 O を計算できる。
(a)と(b)でQ, K, Vを共有できているので、学習時には (a) で並列的に学習、推論時には (b) で再帰的に推論、 と、文字通り良いとこ取りができる。
論文では、更に、(a)と(b)のハイブリッド型((c) Chunkwise表現)を提案し、学習にはこれを用いている。
例えば、長さN=4096の文を、チャンク長 C=512 で処理する場合。
文をチャンクに分け、各チャンクは並列計算した上で、得られる中間状態を次のチャンクに再帰的に渡すようにする。
こうすると、メモリ計算量も大きく削減できる(C2 << N2)。
ここでは説明しないが、以上から更に手を加えて、最終的な提案法としている。
- Transformerで使われる、Multi-head化。
- head毎に異なるγの値(定数)を与えることによる、マルチスケール化対応。
- 正規化による計算の安定化。(QKT → QKT/√d )。など。
学習結果のグラフなどは論文を見て頂ければと思うが、計算量O(1)の威力はすごい。
RNNについての補足
RNNは、推論効率、メモリ効率が高い。これの意味することを説明しておく。 例として、文書生成のようなタスクで、「I read the book.」を生成することを 考える。
初期状態 s0 を持っておいて、
① model(s0, <s>) → (s1, "I") ② model(s1, "I") → (s2, "read") ③ model(s2, "read") → (s3, "the") ④ model(s3, "the") → (s4, "book") ⑤ model(s4, "book") → (s5, ".")
ここで、<s>
は、最初であることを意味する特殊なトークン(ベクトル)。
これは、メモリ効率的には非常に優れている。 入力と s_{n-1} の2つのベクトルさえあれば、次の出力が計算できる。 これは、文の長さに依存しない。
一方、再帰計算であるからには、 上記を並列的に計算するのは困難である。 ⑤の入力に必要な s4 は、仮に全て正解を持っていたとしても、 ①、②、③、④と、順番に計算しない訳にはいかない。
Transformerについての補足
先ほどのRNNの説明の裏返しが、Transformerである。
Transformerは、中間状態を持たずに、現在の単語までの情報を 辞書的に拾ってくる、という戦略を取る。
先ほどの書き方に倣うと、次のようになる。
① model(<s>, -, -, -, -) → ("I", -, -, -, -) ② model(<s>, "I", -, -, -) → ("I", "read", -, -, -) ③ model(<s>, "I", "read", -, -) → ("I", "read", "the", -, -) ④ model(<s>, "I", "read", "the", -) → ("I", "read", "the", "book", -) ⑤ model(<s>, "I", "read", "the", "book") → ("I", "read", "the", "book", ".")
「中間状態を入力する必要が無い」ということがポイントで、 正解が分かっている状態(つまり、学習時)であれば、 ①~⑤をそのまま並列に処理できる。
ただし、この並列性は、推論時は役立たないので、1単語ずつ、入力を増やしながら 再帰的に推論していくしかない。
更に悪いことに、n 番目の単語を予測するときに n - 1 の辞書引きをするために、 入力の長さの2乗に比例するテーブルを作る必要がある。 特に、文の長さ n が増大すると、GPUメモリ不足も深刻になってくる訳である。