kntty.hateblo.jp

ここに何か書く。

Attentionと、全結合・畳み込みとの関係

Transformerに使われるAttentionと、FFN(Position-wise Dense)、全結合(Dense)、畳み込み(Convolution)の 関係を俯瞰するために、お絵描きを試みたので、ここに載せる。

もしかしたら不正確な表現があるかもしれないが、ご容赦いただきたい。

なお、図では、下側を入力、上側を出力としている。

(2021/6/30: 図のインデックスの誤りを修正)

1. Dense (全結合)

全結合は、基本的に重み行列$ \bm W $の乗算で表される。

なお、通常はもうひとつ別の重みベクトル(バイアス)$\bm b$を用意して、 $\bm y = \bm W \bm x + \bm b$とするが、本記事では省略する。

f:id:kntty:20210630144745p:plain
Dense (no bias)

2. Convolution (畳み込み)

簡単のため、1次元の畳み込み演算(カーネルサイズ3、入出力の次元は同じ)とする。

畳み込み演算は、一部の重み(1次元畳み込みの場合は、行列の対角方向の要素)が共有され、かつ、 重みの外側がゼロである行列を用意することで、全結合と同様に行列演算として表すことができる。

バイアスを組み入れることの方が多いが、その場合は$\bm y = \bm C \bm x + \bm b$とする。

f:id:kntty:20210630144909p:plain
Convolution (1D)

3. Position-wise Dense

ここから2次元(Depth軸とPosition軸)。

Transformerブロックでも使われるPosition-wise Denseでは、 Depth軸だけを入力とした全結合(重み$\tilde \bm W$)を、 全てのPosition位置で行う。このとき、重み$\tilde \bm W$は、Position位置に依らず共有とする。

f:id:kntty:20210630144959p:plain
Position-wise Dense

4. Depth-wise Convolution

Transformerとは関係ないが、比較対象として。

Depth-wise Convolutionでは、全てのDepth位置で、Position方向に広がった畳み込みを行う。

畳み込む対象がスカラーからベクトルになった、と考えた方が分かりやすいかも。

なお、Depth-wise Convolutionの2次元版の応用例としては、 フルのConvolutionをDepth-wiseとPoint-wiseに分割しすることによる パラメータ数削減(モデル軽量化)が挙げられる。

f:id:kntty:20210630145033p:plain
Depth-wise Convolution (1D)

5. Self Attention / Source-Target Attention

以上を踏まえて、Attentionの基本的な部分を図にすると、このような感じ。

ValueもKeyもQueryも、Position-wise Denseに相当する方法で生成する。

Depth-wise Convolutionのような、ベクトルを重み付けして集約する機能があり、 その重み付けは、QueryとKeyの内積ベースで決まる。

完全に全結合として行数×列数分の要素を重みパラメータとするよりも、 行(q)と列(k)に分けて用意する方が、(少なくとも)重みパラメータは少なくて済むことが多いはず。

f:id:kntty:20210630145614p:plain
Attention

6. Multi-headed Attention

ついでに、Multi-headの場合。

図では全てを示せていないが、Valueだけではなく、KeyとQueryも同じように分割した後、 Attentionの計算をして、得られる結果を単純に結合する。

f:id:kntty:20210630145156p:plain
Multi-head Attention

指数移動平均

指数移動平均(Exponential Moving Average; EMA)を得るためのテクニックの話。

以前、確か、Temporal Ensembling [1]の論文を読んだときに、なるほどー、と思ったときのメモ書きが出てきたので、文字に起こしておく。

[1] Laine S, Aila T. Temporal Ensembling for Semi-Supervised Learning. 2017.


何某かの変数 $z _ n (n \ge 1)$ を反復更新しながら、その指数移動平均

x _ n = \frac {z_n + α z _ {n - 1} + α^2 z _ {n - 2} + ... + α^{n-1} z _ {1}} {1 + α + α^2 + ... + α^{n-1}}

を得たいとする。 ここで、αは定数($0 < α < 1$)。

もちろん$z _ i (1 \le i \le n)$を全て記録しておけば良いのだが、nが大きいとか、zが巨大ベクトルとかであれば、効率が悪い。

そこで、 zの更新と同時に、次のように$X$を更新することを考える。

$X _ n = z _ {n} + α X _ {n - 1} \quad (n \ge 1),$

$X _ 0 = 0 \quad (n =0)$

ここで、右辺の$X$に、式自身を繰り返し代入すると、

  • $X _ n = z _ {n} + α z _ {n - 1} + α ^ 2 X _ {n - 2}$
  • $X _ n = z _ {n} + α z _ {n - 1} + α ^ 2 z _ {n - 2} + α ^ 3 X _ {n - 3}$
    ...
  • $X _ n = (z _ {n} + αz _ {n - 1} + α ^ 2 z _ {n - 2} + ... α ^ {i} z _ {n - i} + ... +α ^ {n - 1} z _ {1} ) + α ^ {n} X _ {0}$

$X _ 0 = 0$であるから、

X _ n = \sum _ {i = 0} ^ {n - 1} α ^ {i} z _ {n - i}

これで、めでたく、最初の$x _ n$式の分子が得られる。 分母は、等比数列の和

$A _ n = \sum _ {i = 0} ^ {n-1} α ^ {i} = (1 - α ^ n)/(1 - α)$

であるから、次のように$x _ n$を得る。

x _ n = \frac{1 - α}{1 - α ^ n} X _ n

分散と不偏分散と標準偏差と標準誤差と

それぞれどちらを使うべきか、時々混乱するので、整理する。

分散と不偏分散の定義

以下、手元に$n$個のサンプル $x _ i$ があって、その平均が $ m $ であるとする。

  • (通常の)分散: $ σ ^ 2 _ P = \frac {1}{n} \sum _ i {(x _ i - m)} $
  • 不偏分散: $ σ ^ 2 _ S = \frac {1}{n - 1} \sum _ i {(x _ i - m)} $

Excelだと、前者が「VAR.P」、後者が「VAR.S」という関数に対応するので、 添え字はそれに倣っている。

分散と不偏分散の使い分け

  • 分散 $σ ^ 2 _ P$は、手元のサンプル n 個が、計算対象の全て(母集団そのもの)であるとき
  • 不偏分散 $σ ^ 2 _ S$は、ある母集団から任意個抽出してきた n 個が計算対象であるとき

に使う。

(不偏分散の方は、厳密には、手元の $n$ 個から、母集団$N (> n)$個の分散を 予想したい場合に使う、と書いた方が正確か。)

ちなみに、なぜ $n$ より小さい値($ n - 1$) で割るのか、その理由については、 次のリンク先の「数式を使わない感覚的な説明」の図が 非常に直感的で好き。

不偏標本分散の意味とn-1で割ることの証明 | 高校数学の美しい物語

標準偏差と標準誤差の定義

  • 標準偏差: $ SD = σ $
  • 標準誤差: $ SE = \frac{σ}{\sqrt n}$ (厳密には、標本平均の標準誤差)

標準偏差と標準誤差の使い分け

  • 標準偏差は、データのばらつき具合をの大きさを数値化して比べたいとき
  • 標準誤差は、ある推定量(特にここでは標本平均)の精度を見極めたいとき

に使う。

標準偏差の方は、「データそのもの」のばらつき具合であるから、 nを増やせば増やすほど、母集団のばらつき具合の値(つまり、標準誤差)に近づく。

標準誤差の方は「データの(推定)平均値」のばらつき具合、
言い換えれば「母集団のN個からランダムにn個取って、平均値を推定」してみる試行を 何回も行ったときの、推定値のばらつき具合である。
したがって、nを増やせば増やすほど確度が高まり、0に近づいていく。

なお、定義については次のサイトを参照した。

標準偏差と標準誤差の違いをわかりやすく!計算式やエラーバーでの使い分けは?|いちばんやさしい、医療統計

t-SNEやUMAPで、裾の広い分布を使う理由(混雑問題)

UMAPの記事で、距離ではなく「近さ」という言葉を使った。

単純に「距離」を当てはめるのでは上手くいかない理由の1つが、混雑問題(Crowding Problem)である。

混雑問題とは、「高次元で"同じ距離"を表せる範囲が、低次元では極端に狭くなる」(という解釈で良いのだと思う)。

f:id:kntty:20201208202951j:plain

この事実の根拠として、同時に等間隔に配置できる点の数での説明がよくなされる。

  • 3次元では4個まで(正四面体の配置)
  • 2次元では3個まで(正三角形の配置)

つまり、高次元の距離を保ったまま、低次元に移すことは、本質的に無理な例が生じる、ということである。

t-SNEでは、Crowding Probremの対策として、高次元と低次元で異なる分布関数を使っている(恐らく、UMAPでも同じ恩恵に預かっている)。

f:id:kntty:20201208202800j:plain