AdamW#

このノートブックでは,主にViTの学習で用いられるAdamWについて紹介する.AdamWはAdamと呼ばれるパラメータ単位で適応的に学習率を調整するオプティマイザとWeight Decayの関連性について言及し,改良したオプティマイザである.

Adam#

Adamは確率的最適化アルゴリズムの一種であり,生成モデルなどの学習が不安定なモデルの学習に有効なオプティマイザである.Adamの特徴は過去の勾配の1次モーメントと2次モーメントをそれぞれ推定し,これらを用いてパラメータ単位の学習率を適応的に調整する.\(t\) ステップ目で得られた勾配 \(g_t\) を用いてパラメータ \(\theta_t\) を更新するためのAdamの更新式は以下である.

\[\begin{split} \begin{aligned} m_t & =\beta_1 \cdot m_{t-1}+\left(1-\beta_1\right) \cdot g_t \\ v_t & =\beta_2 \cdot v_{t-1}+\left(1-\beta_2\right) \cdot g_t^2 \\ \hat{m}_t & =\frac{m_t}{1-\beta_1^t} \\ \hat{v}_t & =\frac{v_t}{1-\beta_2^t} \\ \theta_{t+1} & =\theta_t-\beta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon} \end{aligned} \end{split}\]

ここで,\(\beta\) は学習率,\(\beta_1, \beta_2\) は1次と2次モーメントの係数,\(\epsilon\) は数値計算のための安定化項である.

AdamはSGDと同様にtorch.optim内で実装されており,次のように利用できる.

optimizer = optim.Adam(model.parameters())

AdamW#

証明は省略するが,Adamにおいて,\(L_2\) 正則化とWeight DecayがAdamでは等価ではない(SGDでは等価)ので,ナイーブな \(L_2\) 正則化の実装

\[ g_t = g_t + \lambda \theta \]

を使って,Adamの更新式に基づきパラメータ更新を行うと,\(L_2\) 正則化がうまく働かない.そこで,モーメンタムによる勾配更新とWeight Decayを切り離すことが考えられる.これはAdamWと呼ばれ,以下の更新式で与えられる.

\[\begin{split} \begin{aligned} m_t & =\beta_1 \cdot m_{t-1}+\left(1-\beta_1\right) \cdot g_t \\ v_t & =\beta_2 \cdot v_{t-1}+\left(1-\beta_2\right) \cdot g_t^2 \\ \hat{m}_t & =\frac{m_t}{1-\beta_1^t} \\ \hat{v}_t & =\frac{v_t}{1-\beta_2^t} \\ \theta_{t+1} & =\theta_t-\beta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon} -\lambda\theta_t \end{aligned} \end{split}\]

このAdamWが現在ViTの学習で標準的に用いられているオプティマイザとなる.

AdamWもtorch.optim内で実装されており,次のように利用できる.

optimizer = optim.AdamW(model.parameters())

補足:指数移動平均#

AdamとAdamWの更新式に現れる過去の勾配の一次と二次モーメント

\[\begin{split} \begin{aligned} m_t & =\beta_1 \cdot m_{t-1}+\left(1-\beta_1\right) \cdot g_t \\ v_t & =\beta_2 \cdot v_{t-1}+\left(1-\beta_2\right) \cdot g_t^2 \\ \end{aligned} \end{split}\]

について補足する.まず一次モーメントと二次モーメントは統計学では平均と分散を表しており,ここでは過去の勾配 \(g_t\)\(g_t^2\) を考慮した \(m_t\)\(v_t\) を表している.しかし,Adamの更新式に現れる \(\beta_1\) を使った更新は統計学の説明のモーメントとは異なる.Adamでは,勾配のモーメントを 指数移動平均(Exponential Moving Average, EMA) を使って推定し,その式が上式となる.具体的に,減衰パラメータ \(\beta_1\) を使って,逐次展開してみる.

\[ m_t = \beta_1 g_t + \beta_1 (1 - \beta_1) g_{t-1} + \beta_1 (1 - \beta_1)^2 g_{t-2} + \cdots + \beta_1 (1 - \beta_1)^{t-2} g_2 + \beta_1 (1 - \beta_1)^{t-1} g_1 \]

となり,現在の勾配 \(g_t\) の値を重要視し,過去の勾配の情報は指数的に減衰させたうえで推定に利用している.これが更新式の気持ちである.

またこの指数移動平均は統計量の推定や勾配の推定だけでなく,Mean Teacher と呼ばれる半教師あり学習で用いられるパラメータ更新の手法などでも用いられている普遍的かつ効果的なテクニックである.