Self-Attention#

このノートブックでは,Transformerのコア要素である 自己注意機構(Self-Attention) について説明する.自己注意機構は自身のトークンと関連するトークンの情報を使って,自身のトークンの情報を強調したり,逆に減衰したり更新することができる仕組みである.そのため,自己注意機構内ではトークン間の類似度の計算と特徴の強調・減衰を行う加重和の計算が行われる.

Query, Key, Value#

../../_images/fig_qkv.png

まず,Self-Attentionは,ある入力系列 \(\boldsymbol{X}=\left(\boldsymbol{x}_1, \boldsymbol{x}_2, \ldots, \boldsymbol{x}_N\right) \in \mathbb{R}^{N \times D}\) が与えられたとき,\(D\) 次元の各トークン \(\boldsymbol{x}_i\) は他のトークン \(\boldsymbol{x}_j\) と関連付けられる.具体的に,各トークン \(\boldsymbol{x}_i\) は Query(\(\boldsymbol{Q}\)),Key(\(\boldsymbol{K}\)),および Value(\(\boldsymbol{V}\))のベクトルに線形層によって次のようにマッピングされる.

\[ \boldsymbol{Q} = \boldsymbol{X} \boldsymbol{W}_Q, \quad \boldsymbol{K} = \boldsymbol{X} \boldsymbol{W}_K, \quad \boldsymbol{V} = \boldsymbol{X} \boldsymbol{W}_V \]

ここで,\(D\) 次元から \(D'\) 次元へ変換する \(\boldsymbol{W}_Q\)\(\boldsymbol{W}_K\)\(\boldsymbol{W}_V\) はそれぞれの線形層の重み行列である.

それぞれの役割のイメージについて簡単に述べておこう.Query は他のトークンからどの情報を引き出すべきかを決定するための「質問」を表すベクトルである.各トークンが他のトークンに「何を求めているか」,つまりどのような情報を重視するかを示している.Key は各トークンが「どんな情報を持っているか」を示すベクトルであり,Queryからの質問に対する「鍵」を表す.Keyは各トークンが持つ情報の特徴を表しており,他のトークンが自身のQueryと照らし合わせたときに,関連性を評価する指標となっている.そして,Value は,実際に注意を向けるべき情報の内容を持つベクトルであり,各トークンが提供する「値」である.

では,この線形層の実装を行おう.まずはダミーの入力を作成する.

import torch

num_tokens = 10
embed_dims = 32
x = torch.randn((num_tokens, embed_dims))

print('x.shape:', x.shape)
x.shape: torch.Size([10, 32])

各線形層を次のように定義する.

import torch.nn as nn

proj_q = nn.Linear(embed_dims, embed_dims, bias=False)
proj_k = nn.Linear(embed_dims, embed_dims, bias=False)
proj_v = nn.Linear(embed_dims, embed_dims, bias=False)

線形層を適用して,Query(\(\boldsymbol{Q}\)),Key(\(\boldsymbol{K}\)),および Value(\(\boldsymbol{V}\))を計算する.

q = proj_q(x)
k = proj_k(x)
v = proj_v(x)

print('q.shape:', q.shape)
print('k.shape:', k.shape)
print('v.shape:', v.shape)
q.shape: torch.Size([10, 32])
k.shape: torch.Size([10, 32])
v.shape: torch.Size([10, 32])

類似度#

では,QueryとKeyの関連,つまり,質問に対するもっとも関連するトークンを求めたい.線形層によって変換されたQueryとKeyはそれぞれ \(N \times D'\) の行列であるので,それぞれの行列の \(i\) 番目と \(j\) 番目のトークン \(\boldsymbol{Q}_i\)\(\boldsymbol{K}_j\) を考える.

これらのトークンはベクトルであるので,ベクトル間の類似度は一般的に,内積 もしくは コサイン類似度 を用いて計算される.内積を用いた類似度計算は,次のように表される.

\[ \operatorname{similarity}\left(\boldsymbol{Q}_i, \boldsymbol{K}_j\right)=\sum_{d=1}^{D'} Q_{i, d} \cdot K_{j, d} \]

ベクトル同士が同じ方向に向いている場合は大きな値,逆方向に向いている場合は小さな値(負の値)になる.そして,内積が0のときはベクトルが直交しており,ベクトル間が独立していることを示す.

また,ベクトルの長さ(ノルム)による影響を除くために,コサイン類似度を用いることが一般的である.コサイン類似度は以下のように定式化される.

\[ \operatorname{cosine\_ similarity }\left(\boldsymbol{Q}_i, \boldsymbol{K}_j\right)=\frac{\boldsymbol{Q}_i \cdot \boldsymbol{K}_j}{\left\|\boldsymbol{Q}_i\right\|\left\|\boldsymbol{K}_j\right\|} \]

コサイン類似度も内積と同様に値から類似度を解釈できる.

ViTでは内積を用いて類似度を計算する.QueryとKeyの全トークンに関する内積は次のように転置してから行列計算すれば各トークン間の関連を一つの行列にまとめることができる.

\[ \operatorname{similarity \_ matrix}\left(\boldsymbol{Q}, \boldsymbol{K}\right)=\boldsymbol{Q}\boldsymbol{K}^\top \in \mathbb{R}^{N \times N} \]

計算された Query(\(\boldsymbol{Q}\))とKey(\(\boldsymbol{K}\))の内積を上式から計算してみよう.

sim_mat = q @ k.T

print('sim_mat.shape:', sim_mat.shape)
sim_mat.shape: torch.Size([10, 10])

Attention Weight#

../../_images/fig_attn.png

類似度計算から得られる類似度を表す行列は,この後Valueとの加重和を計算するので,次のようにsoftmax関数を適用してスケーリングとピークを持つような行列にする.

\[ \boldsymbol{A} = \operatorname{softmax} \left( \frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{D'}} \right) \in \mathbb{R}^{N \times N} \]

このとき,ベクトルの次元が大きくなるほど内積の値は大きくなるので,大きくなりすぎないように \(\sqrt{D'}\) で割っている.

では,このAttention Weightを実装しよう.

import torch.nn.functional as F
attn_weight = F.softmax((q @ k.T) * embed_dims ** -0.5, dim=1)

これをHeatmapとして可視化する.

import matplotlib.pyplot as plt

plt.imshow(attn_weight.detach().numpy(), cmap='hot')
plt.ylabel('Query')
plt.xlabel('Key')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x154ff3e48640>
../../_images/8fb1158d8be27fcf7dffb1ac88fadafbe51488d61982f79230691a1b9727bcb9.png

未学習なので適当な類似度が格納されたAttention Weightの可視化である点に注意されたいが,このような可視化からAttentionのイメージを掴んでもらえたらと思う.

加重和#

../../_images/fig_sa.png

そして,このAttention Weightから実際に注意を向けるべき情報の内容を持つベクトル Value の値を取り出すために,次のように加重和を計算する.

\[ \boldsymbol{x}'_i = \sum_j \boldsymbol{A}_{ij} \cdot \boldsymbol{V}_j \]

上記は各トークンに対する処理であるが,以下のように行列計算もできる.

\[ \boldsymbol{X} = \boldsymbol{A} \boldsymbol{V} \in \mathbb{R}^{N \times D'} \]

ここまでが1つのSelf-Attentionの処理となる.

これを実装すると次のようになる.

output = attn_weight @ v

print('output.shape:', output.shape)
output.shape: torch.Size([10, 32])

Multi-Head Self-Attentionを含むSelf-Attentionの実装#

../../_images/fig_mhsa.png

ここまでの処理の問題はsoftmaxを適用するため1つのトークンに対しておおよそ1つの値のピークしか生じないことである.また複数のAttentionがあるほうが直感的にも性能が上がりそうである.そしてこのSelf-Attentionを複数連ねたものが Multi-Head Self-Attention(MHSA) であり,異なる注意パターンを同時に学習できる.

MHSAの実装は \(D' = D_{head} * N_{head}\) というように設定すれば,次のように計算結果をヘッドの個数分だけ分割することで,複数のQuery,Key,Valueを用意できる.これを踏まえて,Self-Attention層としてPyTorchで実装する.

class Attention(nn.Module):
    def __init__(self, dim, num_heads=4, dropout=0.5):
        super().__init__()
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        
        self.proj_q = nn.Linear(dim, dim, bias=False)
        self.proj_k = nn.Linear(dim, dim, bias=False)
        self.proj_v = nn.Linear(dim, dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        bs, num_tokens, dim = x.shape
        
        q = self.proj_q(x)
        k = self.proj_q(x)
        v = self.proj_q(x)
        
        q = q.reshape(bs, num_tokens, self.num_heads, self.head_dim)
        k = k.reshape(bs, num_tokens, self.num_heads, self.head_dim)
        v = v.reshape(bs, num_tokens, self.num_heads, self.head_dim)
    
        attn_weight = q @ k.transpose(-2, -1) * dim ** -0.5
        attn_weight = F.softmax(attn_weight, dim=-1)
        attn_weight = self.dropout(attn_weight)
        x = attn_weight @ v
        
        x = x.transpose(1, 2).reshape(bs, num_tokens, dim)
        x = self.proj(x)
        x = self.dropout(x)
        return x

bs = 1
num_tokens = 10
embed_dims = 32
x = torch.randn((bs, num_tokens, embed_dims))

attention = Attention(embed_dims)
y = attention(x)

print('y:', y.shape)
y: torch.Size([1, 10, 32])

そして,最終的に全てのヘッドの出力を結合した後,線形層とドロップアウトを適用することが多い.また,計算効率等を無視した非常に簡単な実装なので注意されたい.