kntty.hateblo.jp

ここに何か書く。

Retentive Networkについてのメモ

Microsoft Researchから、Retentive Network が提案された(2023年7月)。 興味深いので、その意義についてざっくり理解したことを、つらつらとメモしておく。

(諸々正確でない点は、ご容赦いただきたい。)

[1] Yutao Sun et al., Retentive Network: A Successor to Transformer for Large Language Models, 2023.

[2307.08621] Retentive Network: A Successor to Transformer for Large Language Models

Retention

Retentive Networkの中核を担うのが、Retention機構である。

初め、「Retentive=[訳]保持力のある」って何かと思ったが、 Attention の後継を意識しての Retention だと思われる。

現在、LLM (大規模言語モデル)の技術としては、Transformer (Attention) が席巻している。
後述するが、学習を並列的に行えたことも1つの要因となって、高い精度を達成している。

一方、文の長さ N に対して、計算量が大きく、長文を扱いにくい面がある。 この点、実は、ひと世代前のRNNの方が優れている点である。

Retentionの肝は、同じ単純な再帰式からの式変形によって、 RNN的にもAttention的にも表現(計算)できる機構を作ったことにある。

導出は論文の式(1)から(7)を参照。

Figure 3を引用する。

以下、Nは文の長さとし、便宜上、トークン (d次元)を "単語"と呼ぶことにする。

(a) 並列表現 (=Attention的な表現)

  • Q, K, Vで構成されていて、Transformerに明るい人は、Self-Attentionでよく見る構図だと思う。
  • Q, K, Vはサイズ N×d であり、入力の単語列 X = (X_1 ... X_N) から作る。
    Attentionと同様、Q, K, Vは単語単位で並列計算ができる。
  • D は、Transformerにおいて先読みを防ぐためのmaskと、再帰の時に掛ける定数 γ を兼ねたような行列である。
  • 図では表現されていないが、Q_nとK_nには $ e ^ {inθ} $ を乗算しておく。
    • これは、何単語目であるか、という、再帰の代わりの情報(の一部)に相当する。
    • これが、結果的に、TransformerでいうところのPosition Encodingの派生形(xPosというらしい)に類した形になっているらしく、面白いところ。

(b) 再帰表現 (=RNN的な表現)

  • Q_n, K_n, V_n は、(a)における Q, K, V の n行目に対応する。
  • 入力 X_n と中間状態 S_{n-1} さえあれば、n 番目より前の入力単語を全く使わずに、出力 O を計算できる。

(a)と(b)でQ, K, Vを共有できているので、学習時には (a) で並列的に学習、推論時には (b) で再帰的に推論、 と、文字通り良いとこ取りができる。

論文では、更に、(a)と(b)のハイブリッド型((c) Chunkwise表現)を提案し、学習にはこれを用いている。
例えば、長さN=4096の文を、チャンク長 C=512 で処理する場合。
文をチャンクに分け、各チャンクは並列計算した上で、得られる中間状態を次のチャンクに再帰的に渡すようにする。 こうすると、メモリ計算量も大きく削減できる(C2 << N2)。

ここでは説明しないが、以上から更に手を加えて、最終的な提案法としている。

  • Transformerで使われる、Multi-head化。
  • head毎に異なるγの値(定数)を与えることによる、マルチスケール化対応。
  • 正規化による計算の安定化。(QKT → QKT/√d )。など。

学習結果のグラフなどは論文を見て頂ければと思うが、計算量O(1)の威力はすごい。

RNNについての補足

RNNは、推論効率、メモリ効率が高い。これの意味することを説明しておく。 例として、文書生成のようなタスクで、「I read the book.」を生成することを 考える。

初期状態 s0 を持っておいて、

 ① model(s0, <s>) → (s1, "I")
 ② model(s1, "I")   → (s2, "read")
 ③ model(s2, "read") → (s3, "the")
 ④ model(s3, "the")  → (s4, "book")
 ⑤ model(s4, "book") → (s5, ".")

ここで、<s>は、最初であることを意味する特殊なトークン(ベクトル)。

これは、メモリ効率的には非常に優れている。 入力と s_{n-1} の2つのベクトルさえあれば、次の出力が計算できる。 これは、文の長さに依存しない。

一方、再帰計算であるからには、 上記を並列的に計算するのは困難である。 ⑤の入力に必要な s4 は、仮に全て正解を持っていたとしても、 ①、②、③、④と、順番に計算しない訳にはいかない。

Transformerについての補足

先ほどのRNNの説明の裏返しが、Transformerである。

Transformerは、中間状態を持たずに、現在の単語までの情報を 辞書的に拾ってくる、という戦略を取る。

先ほどの書き方に倣うと、次のようになる。

 ① model(<s>, -, -, -, -) → ("I", -, -, -, -)
 ② model(<s>, "I", -, -, -)   → ("I", "read", -, -, -)
 ③ model(<s>, "I", "read", -, -) → ("I", "read", "the", -, -)
 ④ model(<s>, "I", "read", "the", -) → ("I", "read", "the", "book", -)
 ⑤ model(<s>, "I", "read", "the", "book") → ("I", "read", "the", "book", ".")

「中間状態を入力する必要が無い」ということがポイントで、 正解が分かっている状態(つまり、学習時)であれば、 ①~⑤をそのまま並列に処理できる。

ただし、この並列性は、推論時は役立たないので、1単語ずつ、入力を増やしながら 再帰的に推論していくしかない。

更に悪いことに、n 番目の単語を予測するときに n - 1 の辞書引きをするために、 入力の長さの2乗に比例するテーブルを作る必要がある。 特に、文の長さ n が増大すると、GPUメモリ不足も深刻になってくる訳である。

docopt は、docopt-ng として生き続けている、という話

Pythonコマンドラインパーサと言ったら何を思い浮かべるだろう。

標準ライブラリ argparse の他、clickやfire などいろいろ選択肢があるが、私は docopt をずっと気に入って使っている。

で、docoptの(元祖)サイトを見に行くと、更新が止まっていてどうしたものか…と思ってしまうが、心配無用。

結論としては、docoptは、docopt-ng に引き継がれて メンテナンスが続いているよ、という話。

github.com

日本語で言及している人が少なそうなので、記事にしておく。

docoptの概要

ごく簡単に、docoptを紹介しておく。

docoptは、「規則に従って書いたヘルプメッセージから、 コマンドラインパーサを作る」という、通常とは逆転の発想を 持ったライブラリである。

具体的な書き方の例を示す。pythonスクリプトの冒頭のコメントに、Usageと、オプションの説明を書けば良い。

   ### 'Usage:'から空行まで、Usageパターンを認識 ###
   """
   Usage:
       my_program.py <file> [options]

   ### 空白除いて '-' から始まる行は、オプションとして認識 ###
   # 空白1個区切りで短いオプションと長いオプション。
   # 空白2個以上の後に、説明を記載。
   # [default: ]によってデフォルト値指定。
   Options:
       -p --port=<n>  何らかの番号
       -o <dir>       出力ディレクトリの指定 [default: ./output]
       --quiet        静かに実行
   """
   from docopt import docopt 

   args = docopt(__doc__)
   # print(args)

このように書いておけば、あとは、args["<file>"]として文字列を拾ったり、 args["--quiet"]としてboolを拾ったりすることができる。

インストールは、pipで。pip install docopt-ngとする。

その他、詳細はいろいろあるので、GitHubページを参照のこと。

なお、docopt自体はこれ以上の機能を持たないシンプルな設計なので、 型変換や範囲チェックなどのバリデーションは、(必要ならschemaなどを用いて)別途行う。

docoptからdocopt-ngに至る経緯

私の記憶と断片的な情報をもとに書いているので、全く以って正確でないとは思うが、記しておく。

オリジナルのdocoptは、PyCon UK 2012で出現して、注目を浴びた。 その発想から、他の様々な言語にも移植された。

そんなdocoptだが、(少なくとも)元の Python実装は、GitHub2018年でメンテが止まっていて久しい。

一方、それより前か後か、docoptからForkされた改良版の1つ、 docopt-ng が存在しており、それが jazzband という、Python系プロジェクトの維持を目的とした (メタ) プロジェクトのメンテナンス下に収まったようである。

ちなみに、メンテナンス版(docopt-ng)の方が docopt の名前を踏襲できないか、という issue に挙がったりしているけれど、先例 (PIL→pillowとか) に見られるように、名前を変えるのはそう簡単ではない模様。

というわけで

ともあれ、docoptの民は、今後も、(docopt-ngで)生きていける。

実装後に改めてコメントを書く労力なしに、 ファイル冒頭説明に使い方が書かれていると、 ソースコードの格好がついて整っている感じで、 非常に気持ちが良い。

複雑なオプション解析が必要な場合とか、 高度なことをするには不足かもしれないが、 手軽に使う分には、最良の選択肢の一つだと思っている。

100 / (100+α) みたいな割り算の手計算を、近似でラクする

100や1000などのキリの良い数を、「それプラスちょっと」の数で割る、という手計算を迫られる場面が、 少なからず(少なからず?)ある。

ぱっと浮かぶ例としては、税込金額における税別金額の比率を求める、とか。

100 / 108 = ? (※食品テイクアウトは、いまも、8 %のはず。)

電卓なければ筆算なりをする訳だが、 108という桁多き数に9を掛けて、引き算して…と、初手から気が進まない。
しかも、分母分子が逆なら一瞬て答えが出るだけに、なんか癪に思う。

で、それなら割り算しないで近似すれば良いじゃん、という 気づきがあったので、ここに紹介する。

計算方法

  1. キリのいい数字を$p$、加える"ちょっと"の数を$α$ とし、元の分数を $\frac{p}{p + α}$ で表す。
    ここで、$x = \frac{α}{p}$ を計算する。
    これはすなわち、$\frac{p}{p + α} = \frac{1}{1 + x}$を満たすxを求めていることに他ならない。

      # 先ほどの例(p=100、α=8)。
      100 / 108 = 100 / (100 + 8) = 1 / (1 + 0.08)
      x = 0.08
    
  2. 近似式 $\frac{1}{1 + x} ≃ 1 - x + x ^ 2$ に基づいて、右辺を計算する。以上。

     1 - x + x^2 = 1 - 0.08 + 0.08^2
                 = 1.0000 - 0.0800 + 0.0064
                 = 0.9264
     よって、1/(1 + x) ≒ 0.9264
    

電卓で求めると 100 / 108 = 0.925925925... となって、誤差は 0.00048 ぐらいに収まる。

理屈

$f(x) = \frac{1}{1 + x}$のマクローリン展開による。

$\frac{1}{1 + x} = 1 - x + x ^ 2 - x ^ 3 + ...$

xのべき乗が並び、符号が交互に変わることに注意。 2乗の項までで近似したのが、先ほどの近似式である。

誤差$E$を計算すると、

$E = \left| \frac{1}{1 + x} - (1 - x + x ^ 2) \right| $
$\quad = \left| \frac{1}{1 + x} \{ 1 - (1 - x ^ 3) \} \right| = \left| \frac{x ^ 3}{1 + x} \right| $

となって、おおよそxの3乗のレベルの誤差となることが分かる。 $1 \gg x$なら、十分近似になる。

もちろん、精度を求めるなら、更に$x ^ 3$乗を引けば良い。αが1桁なら暗算でも何とかなるはず。

発展

p / (p - α) のパターンにも応用できる(例:1000/997)。
符号を変えた場合($x → -x$)でも当然成り立つから、

$\frac{p}{p - α} = \frac{1}{1 - x} ≃ 1 + x + x ^ 2$

を使えば良い。

また、α / (p + α) のパターン(例:8 / 108)であれば、 1から引く形に式変形すれば良い。

8 / 108 = 1 - 100/108 ≃ 1 - 0.9264 = 0.0736 (近似)
8 / 108 = 0.0740740... (真値)

tf.Kerasの事前学習済みモデルに、正則化を加える正しい方法

TensorFlow(2系)の、 tf.keras.applicationsの学習済みモデルに、L2等の正規化を加える方法。 次の記事[1]で解決した。簡単にその要約を記す。

[1] Silva TS. How to Add Regularization to Keras Pre-trained Models the Right Way. 2019.

sthalles.github.io

  • Step 1. 学習済みモデルmodelが持つのlayerを走査する。 もしlayerがkernel_regularizerプロパティを持っていれば、そこにregularizerオブジェクトを代入する。
     (ここで加えた変更は、まだ、モデルそのものには反映されていない。)
  • Step 2. 一旦、config情報をjsonに書き出す。
  • Step 3. 重みについても、一時ファイルに書き出す。
  • Step 4. Step 3で書き出したconfig情報(json)を読み戻す。
    (regularizerは反映されるが、重みがリセットされてしまう。)
  • Step 5. Step 4で書き出した重みを、一時ファイルから読み戻す。

https://gist.github.com/sthalles/d4e0c4691dc2be2497ba1cdbfe3bc2eb から引用)

import os
import tempfile

def add_regularization(model, regularizer=tf.keras.regularizers.l2(0.0001)):

    if not isinstance(regularizer, tf.keras.regularizers.Regularizer):
      print("Regularizer must be a subclass of tf.keras.regularizers.Regularizer")
      return model

    for layer in model.layers:
        for attr in ['kernel_regularizer']:
            if hasattr(layer, attr):
              setattr(layer, attr, regularizer)

    # When we change the layers attributes, the change only happens in the model config file
    model_json = model.to_json()

    # Save the weights before reloading the model.
    tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
    model.save_weights(tmp_weights_path)

    # load the model from the config
    model = tf.keras.models.model_from_json(model_json)

    # Reload the model weights
    model.load_weights(tmp_weights_path, by_name=True)
    return model

動作確認として、model.lossesにアクセスすれば、正則化の値をリストで取得することができる。 正則化を正しく反映できていれば、リストは空にならないはずである。

なお、プロパティとして、kernel_regularizerだけでなく bias_regularizeractivity_regularizerも加えることが可能である。

Transformerを理解するまでに私が陥った、3つの勘違い

自然言語処理でお馴染み、他の分野も席巻しつつあるTransformerについて、 やっと自分の理解が追いついてきた。

Transformerとは何か、については、良い記事がたくさんあるのでそちらを参照されたい。

念のため、最小限の説明をすると、次のような感じ。

  • 入力列から、辞書的に情報を拾い上げて出力列に伝える仕組み、もしくは、それを使ったモデル構造のこと。
  • Attentionと呼ばれる機構が、「辞書的に情報を拾い上げ」る役割を実現。

以前に調べたときは、上辺だけをさらっと知ることを優先していたために、 結構勘違いが多かったことに気づかされた。そのことについて、ここにメモしておく。

(#上辺じゃなくて最初からちゃんと理解せえ、という自分への戒めも込めて。)

(2021.7.16 5月頃にメモとして書いていたものを、タイトル含め書き直して再投稿)

Transformerモデルについて

1. 「Transformer」という語が指す範囲は、文脈による

特に、TransformerとBERTをいっぺんに理解しようとしてはまった点。

私の理解の上では、(Original)Transformerの1層とBERTの1層が、 図のように対応する。

f:id:kntty:20210713092013p:plain
Transformer[1]とBERT[2]の図の関係

Attention Is All You Needの論文の解説記事だと、そもそも左の図全体を (Originalの) Transformerと呼ぶことが多い。一方で、その応用技術では、図の1層分をTransformerと捉えている印象がある。

なお、上の左右の図は、そもそも見比べるべきではないことに注意が必要である。

右図は、先行研究のELMo等の概念と比べるために書かれた図であって、 BERTの図が間違っている、とか、そういう批判をしたい訳ではないことを補足しておく。

2. 「BERT」は「(単なる)双方向Transformer」ではない

特に双方向RNN等を知っている上で、"Bidirectional Transformer"と言われると、 何か計算上の改良が加わっているのではないか、と勘違いが働きやすい(私だけ?)がそうではない。

では、BERTの何が双方向性を持つのか、というと、「学習対象のタスクが」である。

  • 従来よく行われていた、「文章の n-1 語めまでを入力して、n 語めを推測する」タスクでは、前向き方向にしか推定が行えない (*1)。
  • 対して、BERTで設定されたタスク(2つのうち)の1つは、「文章の n 語めを隠して入力し、隠された n 語めを推測する」というタスクである。これなら、前向きにも後ろ向きにも単語情報を活用する、という恩恵が得られる。

つまり、双方向の予測が働くように設定したタスクによって、Transformerが双方向性の予測にも強くなった、ということである。

(*1) 補足:

  • 文をひっくり返して、後ろ向きも予測すればいいという考え方であれば、双方向RNNで導入されており、これは前身のELMoで採用されていた方法である。

3. Decoderは、再帰せずに文書生成を行える訳ではない

「TransformerはRNNでないから、並列処理できる」的な説明だけを見て これも誤認してしまった。これは「学習時」にのみ、当てはまる。

例として、原文「彼/は/あの/通り/沿い/に/住んでいる」とその訳文「He/lives/along/that/street」を考える。 この場合、Transformer翻訳モデルの入力と出力の関係は、次のようになる。

(Encoderへの入力/Decoderへの入力 → /Decoderに期待される出力)

  • [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>] → [He]
  • [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He] → [He][lives]
  • [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He][lives] → [He][lives][along]
  • [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He][lives][along] → [He][lives][along][that]
  • [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He][lives][along][that] → [He][lives][along][that][street]
  • [彼][は][あの][通り][沿い][に][住んでいる][<終>] / [<空>][He][lives][along][that][street] → [He][lives][along][that][street][<終>]

学習時は、途中までの予測が完璧だと仮定して、次の1単語の予測を行う。 RNNみたいに内部状態が要るわけではないので、6つ分をいっぺんに並列計算できる。

一方、テスト時は、当然「途中までの予測」なんてものは手元にないので、 6つ分をいっぺんに、ではなく、RNNのDecoderと同じようなイメージで、 1つずつ単語を増やして予測していく。

参考文献