kntty.hateblo.jp

ここに何か書く。

Retentive Networkについてのメモ

Microsoft Researchから、Retentive Network が提案された(2023年7月)。 興味深いので、その意義についてざっくり理解したことを、つらつらとメモしておく。

(諸々正確でない点は、ご容赦いただきたい。)

[1] Yutao Sun et al., Retentive Network: A Successor to Transformer for Large Language Models, 2023.

[2307.08621] Retentive Network: A Successor to Transformer for Large Language Models

Retention

Retentive Networkの中核を担うのが、Retention機構である。

初め、「Retentive=[訳]保持力のある」って何かと思ったが、 Attention の後継を意識しての Retention だと思われる。

現在、LLM (大規模言語モデル)の技術としては、Transformer (Attention) が席巻している。
後述するが、学習を並列的に行えたことも1つの要因となって、高い精度を達成している。

一方、文の長さ N に対して、計算量が大きく、長文を扱いにくい面がある。 この点、実は、ひと世代前のRNNの方が優れている点である。

Retentionの肝は、同じ単純な再帰式からの式変形によって、 RNN的にもAttention的にも表現(計算)できる機構を作ったことにある。

導出は論文の式(1)から(7)を参照。

Figure 3を引用する。

以下、Nは文の長さとし、便宜上、トークン (d次元)を "単語"と呼ぶことにする。

(a) 並列表現 (=Attention的な表現)

  • Q, K, Vで構成されていて、Transformerに明るい人は、Self-Attentionでよく見る構図だと思う。
  • Q, K, Vはサイズ N×d であり、入力の単語列 X = (X_1 ... X_N) から作る。
    Attentionと同様、Q, K, Vは単語単位で並列計算ができる。
  • D は、Transformerにおいて先読みを防ぐためのmaskと、再帰の時に掛ける定数 γ を兼ねたような行列である。
  • 図では表現されていないが、Q_nとK_nには $ e ^ {inθ} $ を乗算しておく。
    • これは、何単語目であるか、という、再帰の代わりの情報(の一部)に相当する。
    • これが、結果的に、TransformerでいうところのPosition Encodingの派生形(xPosというらしい)に類した形になっているらしく、面白いところ。

(b) 再帰表現 (=RNN的な表現)

  • Q_n, K_n, V_n は、(a)における Q, K, V の n行目に対応する。
  • 入力 X_n と中間状態 S_{n-1} さえあれば、n 番目より前の入力単語を全く使わずに、出力 O を計算できる。

(a)と(b)でQ, K, Vを共有できているので、学習時には (a) で並列的に学習、推論時には (b) で再帰的に推論、 と、文字通り良いとこ取りができる。

論文では、更に、(a)と(b)のハイブリッド型((c) Chunkwise表現)を提案し、学習にはこれを用いている。
例えば、長さN=4096の文を、チャンク長 C=512 で処理する場合。
文をチャンクに分け、各チャンクは並列計算した上で、得られる中間状態を次のチャンクに再帰的に渡すようにする。 こうすると、メモリ計算量も大きく削減できる(C2 << N2)。

ここでは説明しないが、以上から更に手を加えて、最終的な提案法としている。

  • Transformerで使われる、Multi-head化。
  • head毎に異なるγの値(定数)を与えることによる、マルチスケール化対応。
  • 正規化による計算の安定化。(QKT → QKT/√d )。など。

学習結果のグラフなどは論文を見て頂ければと思うが、計算量O(1)の威力はすごい。

RNNについての補足

RNNは、推論効率、メモリ効率が高い。これの意味することを説明しておく。 例として、文書生成のようなタスクで、「I read the book.」を生成することを 考える。

初期状態 s0 を持っておいて、

 ① model(s0, <s>) → (s1, "I")
 ② model(s1, "I")   → (s2, "read")
 ③ model(s2, "read") → (s3, "the")
 ④ model(s3, "the")  → (s4, "book")
 ⑤ model(s4, "book") → (s5, ".")

ここで、<s>は、最初であることを意味する特殊なトークン(ベクトル)。

これは、メモリ効率的には非常に優れている。 入力と s_{n-1} の2つのベクトルさえあれば、次の出力が計算できる。 これは、文の長さに依存しない。

一方、再帰計算であるからには、 上記を並列的に計算するのは困難である。 ⑤の入力に必要な s4 は、仮に全て正解を持っていたとしても、 ①、②、③、④と、順番に計算しない訳にはいかない。

Transformerについての補足

先ほどのRNNの説明の裏返しが、Transformerである。

Transformerは、中間状態を持たずに、現在の単語までの情報を 辞書的に拾ってくる、という戦略を取る。

先ほどの書き方に倣うと、次のようになる。

 ① model(<s>, -, -, -, -) → ("I", -, -, -, -)
 ② model(<s>, "I", -, -, -)   → ("I", "read", -, -, -)
 ③ model(<s>, "I", "read", -, -) → ("I", "read", "the", -, -)
 ④ model(<s>, "I", "read", "the", -) → ("I", "read", "the", "book", -)
 ⑤ model(<s>, "I", "read", "the", "book") → ("I", "read", "the", "book", ".")

「中間状態を入力する必要が無い」ということがポイントで、 正解が分かっている状態(つまり、学習時)であれば、 ①~⑤をそのまま並列に処理できる。

ただし、この並列性は、推論時は役立たないので、1単語ずつ、入力を増やしながら 再帰的に推論していくしかない。

更に悪いことに、n 番目の単語を予測するときに n - 1 の辞書引きをするために、 入力の長さの2乗に比例するテーブルを作る必要がある。 特に、文の長さ n が増大すると、GPUメモリ不足も深刻になってくる訳である。