残差学習

残差学習#

残差学習(Residual Learning) は,多層ニューラルネットワークの学習において,勾配を出力から入力まで効果的に逆伝播する仕組みであり,ResNetで提案されたアイデアである.

ニューラルネットワークは多層になるにつれて,より複雑なパターンを効果的に学習できるようになるが,同時に入力層まで逆伝播する間に勾配が消失または発散する問題が知られており,実際にはニューラルネットの層数を極端に増やすと学習に失敗することが報告されていた.

残差学習では,層に ショートカット接続(skip connection) と呼ばれる経路を追加することで,入力をそのまま次の層に伝えるというアイデアを導入している.これを数式で書くと,ある層への入力を \(x\),層の変換を \(F(x)\) としたとき,その層の出力 \(y\) を次のように定義する.

\[ y = F(x) + x \]

これを残差学習(Residual Learning)という.入力を層の適用結果に加算しただけであるが,層のパラメータは入力と出力の差分(残差)

\[ F(x) = y - x \]

を学習するだけでよく,また出力 \(y\) を入力 \(x\) で微分すると,

\[ \frac{\partial y}{\partial x}=\frac{\partial F(x)}{\partial x}+\frac{\partial x}{\partial x}=\frac{\partial F(x)}{\partial x}+1 \]

となり,後段から逆伝播されてきた勾配が上式の第二項目にかけられて次の層へ直接逆伝播する.これが残差学習の利点であり,勾配の消失・発散問題を回避する理由である.

実際,ResNetでは,この残差のアイデアで100層以上のニューラルネットワークを学習させることに成功し,高い性能を達成した.そして,この残差はResNetだけでなくViTの実装でも現れたように,モデル構造の要素として普遍的なテクニックとなっている.