Layer Normalization

Layer Normalization#

Layer Normalization(LN) は,主に再起的な構造を持つRecurrent Neural Network(RNN)やTransformerモデルで利用される主に系列データに対する正規化手法である.Batch Normalization(BN)がミニバッチ全体で統計量(平均と分散)を計算するのに対し,LNは各データポイント内の特徴次元ごとに統計量を計算し正規化する.

具体的に,バッチサイズ \(B\)\(D\) 次元の入力を \(\boldsymbol{x} \in \mathbb{R}^{B \times D}\) としたとき,各特徴次元の平均と分散を以下のように計算する.

\[\begin{split} \begin{align} \mu_i &= \frac{1}{D} \sum_{j=1}^D x_{ij} \\ \sigma_i^2 &= \frac{1}{D} \sum_{j=1}^D \left(x_{ij} - \mu_i \right)^2 \end{align} \end{split}\]

ここで,\(i\) はバッチ内のデータポイントのインデックスである.この計算は各データポイントごとに行われるため,BNとは異なり,バッチサイズ \(B\) に依存しない.

次に,計算された統計量を使って各データポイントを正規化し,学習可能なスケールパラメータ \(\gamma \in \mathbb{R}^D\) とシフトパラメータ \(\beta \in \mathbb{R}^D\) を用いて調整する.

\[ f_{LN}(x_{ij}) = \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}} \cdot \gamma + \beta \]

ここで,\(\varepsilon\) はゼロ除算を防ぐための微小定数である.

LNではデータポイントごとの統計量を用いるため,BNの移動平均による統計量の記録が必要ない.

LNは画像の場合は全チャネル・全画素の統計量を各データ単位で計算し,正規化する.ViTのような画像をパッチ化した場合,埋め込み次元に対して各パッチ単位で正規化される.LNを適用するには torch.nn.LayerNorm を次のように利用する.

torch.nn.LayerNorm(normalized_shape)

ここで,normalized_shape は特徴次元 \(D\) つまりViTの場合は埋め込み次元である.

import torch
import torch.nn as nn

x = torch.randn(10, 32)
norm = nn.LayerNorm(32)
h = norm(x)

print('x.shape:', x.shape)
print('h.shape:', h.shape)
x.shape: torch.Size([10, 32])
h.shape: torch.Size([10, 32])