Batch Normalization

Batch Normalization#

バッチ正規化(Batch Normalization; BN) は各層の入力を正規化することで学習を安定化させる手法である.特に,各特徴量の分布が学習過程で変動すること(内部共変量シフト)に対処できる.

BNへの入力を \(\boldsymbol{x} \in \mathbb{R}^{B \times D}\),ここで \(B\) はミニバッチサイズ,\(D\) は特徴次元数とする.はじめに,BNでは,入力のバッチにわたる平均 \(\mu \in \mathbb{R}^{D}\) と分散 \(\sigma^2 \in \mathbb{R}^{D}\) を計算する.

\[\begin{split} \begin{align}\mu&=\frac{1}{B} \sum_{i=1}^B x_i \\ \sigma^2&=\frac{1}{B} \sum_{i=1}^B\left(x_i-\mu\right)^2 \end{align} \end{split}\]

ここで,\(x_i\)\(i\) 番目の \(D\) 次元のデータポイントである.そして,計算された統計量を基づいてデータを次のように正規化する.この正規化処理 \(N(\boldsymbol{x})\) は次のようになる.

\[ \begin{align} N(\boldsymbol{x})=\frac{\boldsymbol{x}-\mu}{\sqrt{\sigma^2+\varepsilon}} \end{align} \]

正規化された入力 \(\boldsymbol{x}\) の次元は \(\mathbb{R}^{B \times D}\) であり,\(\varepsilon\) はゼロ除算を防ぐための微小定数である.最後に,BNは学習可能なスケールパラメータ \(\gamma \in \mathbb{R}^{D}\) とシフトパラメータ \(\beta \in \mathbb{R}^{D}\) を用いて正規化された入力特徴分布を調整する.

\[ \begin{align} f_{BN}(\boldsymbol{x})=N(\boldsymbol{x})\gamma + \beta \end{align} \]

以上がBNの処理である.学習時は,ミニバッチごとに統計量を計算するが,推論時は学習時に計算した統計量の移動平均を用いる

\[\begin{split} \begin{align} \mu_\text{running} &= (1 - \alpha)\mu_\text{running} + \alpha \mu_\text{batch} \\ \sigma^2_\text{running} &= (1 - \alpha) \sigma^2_\text{running} + \alpha \sigma^2_\text{batch} \end{align} \end{split}\]

ここで,\(\alpha\) は移動平均の更新率を表すハイパーパラメータである.

また説明にはベクトルの入力を扱ったが画像に対しても同様に定義できる.2次元画像に対してBNを適用するには torch.nn.BatchNorm2d を次のように利用する.

torch.nn.BatchNorm2d(num_features)

ここで,num_features\(D\) つまり画像の場合はチャネル数である.

import torch
import torch.nn as nn

x = torch.randn(10, 3, 5, 5)
norm = nn.BatchNorm2d(3)
h = norm(x)

print('x.shape:', x.shape)
print('h.shape:', h.shape)
x.shape: torch.Size([10, 3, 5, 5])
h.shape: torch.Size([10, 3, 5, 5])