Layer Normalization#
Layer Normalization(LN) は,主に再起的な構造を持つRecurrent Neural Network(RNN)やTransformerモデルで利用される主に系列データに対する正規化手法である.Batch Normalization(BN)がミニバッチ全体で統計量(平均と分散)を計算するのに対し,LNは各データポイント内の特徴次元ごとに統計量を計算し正規化する.
具体的に,バッチサイズ \(B\) で \(D\) 次元の入力を \(\boldsymbol{x} \in \mathbb{R}^{B \times D}\) としたとき,各特徴次元の平均と分散を以下のように計算する.
ここで,\(i\) はバッチ内のデータポイントのインデックスである.この計算は各データポイントごとに行われるため,BNとは異なり,バッチサイズ \(B\) に依存しない.
次に,計算された統計量を使って各データポイントを正規化し,学習可能なスケールパラメータ \(\gamma \in \mathbb{R}^D\) とシフトパラメータ \(\beta \in \mathbb{R}^D\) を用いて調整する.
ここで,\(\varepsilon\) はゼロ除算を防ぐための微小定数である.
LNではデータポイントごとの統計量を用いるため,BNの移動平均による統計量の記録が必要ない.
LNは画像の場合は全チャネル・全画素の統計量を各データ単位で計算し,正規化する.ViTのような画像をパッチ化した場合,埋め込み次元に対して各パッチ単位で正規化される.LNを適用するには torch.nn.LayerNorm を次のように利用する.
torch.nn.LayerNorm(normalized_shape)
ここで,normalized_shape は特徴次元 \(D\) つまりViTの場合は埋め込み次元である.
import torch
import torch.nn as nn
x = torch.randn(10, 32)
norm = nn.LayerNorm(32)
h = norm(x)
print('x.shape:', x.shape)
print('h.shape:', h.shape)
x.shape: torch.Size([10, 32])
h.shape: torch.Size([10, 32])