Convolutional Neural Network

Convolutional Neural Network#

畳み込みニューラルネットワーク(Convolutional Neural Network, CNN) は,主に画像データに対して高い性能を発揮する順伝播型ニューラルネットワークであり,畳み込み層とPooling層を複数積み重ねることで局所的な特徴から大域的な特徴を抽出し,その特徴から最終的に線形層から分類や回帰を行う.

このノートブックでは,シンプルなCNNを構築し,MLPと同様にダミーデータを用いて順伝播を計算する.

PyTorchでは,CNNに関しても,nn.Module を使ってニューラルネットワークの構造をクラスとして次のように定義することが一般的である.今回は,RGB画像を想定した3チャネルの画像を入力として受け取り,2回の畳み込みとPoolingを行い,特徴マップをベクトル化して,線形層に入力する構造とする.また畳み込みの後にBatch Normalizationを適用する.

これは次のように実装できる.

import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = nn.functional.relu(x)
        x = self.pool2(x)
        x = self.gap(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

これを in_channels=3num_classes=10 としてインスタンス化する.

in_channels = 3
num_classes = 10
model = CNN(in_channels=in_channels, num_classes=num_classes)

モデルの構造を確認する.

print(model)
CNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (gap): AdaptiveAvgPool2d(output_size=1)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc): Linear(in_features=32, out_features=10, bias=True)
)

今回は活性化関数を nn.functional.relu で作成した.これは,nn.ReLU とは異なり,インスタンス化が不要な関数であり,関数として直接呼び出すことができる.

nn.ReLU は明示的にモデル内部に活性化関数を定義することができ,nn.ReLU(inplace=True) とするとメモリ効率を向上することができるが,基本的にはどちらの実装でも良い.

続いて,MLPと同様に正しく順伝播できるかどうかをダミーデータを与えてチェックしよう.今回は,\(32 \times 32\) サイズの入力を考える.

batch_size = 3
height, width = 32, 32
input_size = (batch_size, in_channels, height, width)
dummy_input = torch.randn(*input_size)

print('dummy_input.shape:', dummy_input.shape)
print('dummy_input = ')
print(dummy_input)
dummy_input.shape: torch.Size([3, 3, 32, 32])
dummy_input = 
tensor([[[[-2.2338e+00,  6.7255e-01,  5.1815e-01,  ..., -7.4766e-01,
            1.0627e+00,  9.5197e-01],
          [ 1.1238e+00,  1.4813e-01, -1.6155e-01,  ...,  1.6468e+00,
            2.2724e-01,  1.0323e+00],
          [-1.1442e+00, -1.3405e+00, -1.6309e-01,  ...,  2.1085e+00,
            5.1301e-01, -1.8134e+00],
          ...,
          [-1.3806e-01, -1.0085e-01,  1.2108e+00,  ..., -2.2120e-01,
            2.1932e-01,  2.6016e-01],
          [ 3.5698e-01, -1.0573e+00, -3.7467e-01,  ...,  2.2146e-01,
            2.4294e-01,  4.1672e-01],
          [-1.4614e+00, -2.7861e+00,  7.2765e-01,  ..., -1.3933e-01,
            1.9755e-01,  2.0388e-01]],

         [[-2.0519e-01,  8.5137e-01, -2.6373e-01,  ..., -7.6838e-01,
            3.8190e-01,  9.5757e-03],
          [-4.0491e-01, -5.5069e-01, -1.3497e+00,  ...,  1.2793e-01,
           -7.0653e-01,  7.6806e-01],
          [-1.0431e+00,  1.5969e+00,  8.7825e-01,  ...,  2.3754e-02,
           -8.7329e-01, -3.2425e-01],
          ...,
          [ 6.1635e-02, -6.3180e-01,  5.7092e-02,  ..., -1.5198e+00,
            1.0524e+00, -6.2341e-01],
          [ 1.1562e+00, -1.4519e+00,  1.3175e+00,  ...,  3.4146e-01,
           -9.2581e-01,  1.0007e+00],
          [ 8.5909e-01,  6.8983e-01, -1.2605e+00,  ..., -4.3246e-01,
            2.4352e+00,  6.7932e-01]],

         [[-7.8446e-01,  8.7600e-01, -2.6004e+00,  ...,  5.2872e-01,
           -3.3704e-01, -9.4218e-01],
          [ 1.5532e+00,  1.7289e+00, -1.0404e-01,  ...,  3.4349e-01,
           -8.5733e-01,  5.3544e-01],
          [-1.6537e+00, -3.4990e-01, -6.5155e-01,  ..., -1.2528e+00,
            1.8858e+00, -6.2303e-01],
          ...,
          [-1.5285e-01,  1.4243e-01,  1.9470e+00,  ...,  8.6001e-02,
            1.5704e+00, -1.0511e+00],
          [-2.7819e+00, -1.0021e+00,  1.5221e-01,  ...,  4.5148e-01,
            9.0905e-01,  1.2654e-01],
          [ 2.3395e-01, -5.8103e-02,  7.1578e-01,  ...,  2.9437e-01,
           -1.3657e+00,  7.3239e-01]]],


        [[[ 5.5136e-02, -1.2081e+00,  6.4435e-01,  ..., -5.6562e-01,
            1.1528e+00, -9.2269e-02],
          [-3.2532e-01, -1.1193e+00, -2.0573e-01,  ...,  1.3423e+00,
           -9.1528e-01, -1.1104e+00],
          [ 4.9955e-01, -2.4621e+00,  7.8787e-01,  ..., -1.7502e+00,
            4.3552e-01, -8.0079e-01],
          ...,
          [ 7.9871e-01, -4.0601e-01,  1.2308e+00,  ...,  2.2081e-01,
            5.0757e-01, -2.5719e-01],
          [ 1.1418e+00,  4.1416e-02,  1.8281e+00,  ...,  3.3097e+00,
            1.1592e+00,  1.4956e+00],
          [ 3.4046e-01, -4.2337e-02,  1.9563e+00,  ...,  5.8742e-01,
           -7.2247e-01,  1.1070e+00]],

         [[ 1.7086e-01,  1.3962e-01,  6.3575e-01,  ..., -1.8086e-02,
           -1.9898e-01,  4.1100e-01],
          [ 7.0437e-01, -3.3213e-01, -1.4582e+00,  ...,  7.5942e-01,
           -7.4249e-01,  1.8627e-01],
          [ 4.8769e-01,  7.5126e-01, -5.3122e-01,  ..., -5.4504e-01,
            2.3885e-01,  1.1886e+00],
          ...,
          [-1.3740e+00,  1.2950e+00, -2.0269e+00,  ..., -4.6823e-01,
            1.3552e+00, -1.5352e+00],
          [ 1.9183e+00, -7.0725e-01,  1.1388e+00,  ...,  1.3259e+00,
            6.1337e-01,  2.5681e-01],
          [-1.1569e+00,  2.4299e-01,  5.6252e-01,  ...,  7.4827e-01,
           -1.2315e+00, -6.2925e-01]],

         [[ 5.2922e-01,  1.2870e-01, -4.5965e-01,  ...,  2.0628e+00,
           -1.6044e+00, -1.0532e+00],
          [ 8.2605e-02,  1.1014e+00, -2.7432e-01,  ..., -2.5177e-06,
            1.0814e+00, -3.5536e-01],
          [ 4.6173e-01, -4.7356e-01,  8.1936e-01,  ...,  1.8746e-01,
           -1.0123e-01,  2.0382e+00],
          ...,
          [ 7.6882e-01, -9.9525e-01, -1.6357e+00,  ...,  7.2419e-01,
           -3.4926e-01, -8.0924e-01],
          [ 7.1890e-01, -1.0214e+00,  7.4638e-01,  ...,  9.4148e-02,
           -8.2216e-01, -9.7149e-02],
          [-1.1805e+00,  9.2248e-02,  5.8988e-01,  ..., -1.0943e+00,
           -7.5587e-01,  2.8965e-01]]],


        [[[-1.6354e+00, -5.0211e-01,  9.9580e-01,  ...,  1.1006e+00,
           -1.3054e+00, -1.1822e+00],
          [-1.9376e+00,  2.3668e-01, -2.6948e-01,  ..., -1.3834e+00,
            4.5937e-01,  1.7955e-01],
          [-4.7553e-01,  4.6342e-01, -2.3756e-01,  ..., -9.0949e-01,
           -1.2622e+00, -2.6954e-02],
          ...,
          [-4.8785e-01,  5.3450e-01,  8.9168e-01,  ...,  2.9505e-01,
            4.0803e-01, -5.2234e-01],
          [-2.8476e-01, -1.6570e+00,  8.3642e-01,  ...,  1.3755e+00,
            2.4468e+00, -8.5187e-01],
          [-1.8817e+00, -7.5856e-01,  1.2030e+00,  ..., -1.4585e+00,
           -6.3247e-01,  4.1942e-01]],

         [[ 7.0672e-01, -5.9478e-01, -2.4069e+00,  ...,  1.3060e+00,
           -1.2666e+00, -9.5439e-03],
          [-4.3280e-01, -4.5998e-01,  1.7029e+00,  ..., -1.9753e-02,
           -1.0294e-01,  1.2070e+00],
          [-1.4817e+00, -5.3051e-03, -2.0739e+00,  ...,  2.9199e-01,
           -2.2163e-01,  3.4185e-01],
          ...,
          [-7.0397e-01,  7.5678e-04, -7.9128e-01,  ...,  7.7903e-01,
           -2.1894e-01,  9.6161e-01],
          [ 1.6249e+00, -1.2602e+00, -1.7916e+00,  ..., -1.9309e+00,
           -1.3579e+00, -1.6807e+00],
          [-1.6074e+00,  4.6682e-01,  2.0595e+00,  ...,  1.0084e+00,
           -1.0570e-01,  1.1311e+00]],

         [[-4.4626e-01, -7.0167e-02, -6.8692e-01,  ..., -1.8823e+00,
           -1.1833e+00,  1.9147e+00],
          [ 3.2459e-01, -8.2440e-02,  1.3713e+00,  ...,  1.3211e+00,
            3.7216e-01, -1.6952e+00],
          [-1.0443e+00,  2.5704e-01, -7.2021e-01,  ...,  1.5119e+00,
           -3.9108e-01, -1.0278e-02],
          ...,
          [-6.3837e-01, -1.0105e+00, -1.1723e+00,  ..., -5.8328e-01,
            1.7788e+00,  4.9207e-01],
          [-3.2711e-02, -3.4311e-01,  1.1519e+00,  ..., -4.6279e-01,
            2.6106e+00,  5.1771e-01],
          [ 1.7222e-02,  9.5191e-01,  1.1658e+00,  ...,  1.3847e+00,
            5.7338e-01, -2.5108e-01]]]])

バッチサイズも含めて4回のテンソルが作成できた.順伝播を行い,(バッチサイズ, クラス数)の出力が得られることを確認する.

output = model(dummy_input)

print('output.shape: ', output.shape)
print('output = ')
print(output)
output.shape:  torch.Size([3, 10])
output = 
tensor([[-0.4440,  0.0198, -0.0612,  0.0512,  0.0562,  0.1882,  0.9370, -0.1396,
          0.5335, -0.3664],
        [-0.5024,  0.1678, -0.0761, -0.0162,  0.0978,  0.2067,  0.9325, -0.1004,
          0.5247, -0.3562],
        [-0.4934,  0.0790, -0.0318,  0.0329, -0.0258,  0.1461,  0.9660, -0.1850,
          0.6089, -0.4052]], grad_fn=<AddmmBackward0>)

このときエラーが発生する場合は forward 関数内の処理(特に,ベクトル化の処理)か畳み込みの入出力チャネル数を間違えていることが多い.

層へのアクセスやパラメータ数の取得方法も確認しておく.基本的にはMLPの場合と同じである.

total_params = 0
for name, param in model.named_parameters():
    print(f'{name} shape: {param.shape}')
    total_params += param.numel()

print(f'Total number of trainable parameters: {total_params}')
conv1.weight shape: torch.Size([16, 3, 3, 3])
conv1.bias shape: torch.Size([16])
bn1.weight shape: torch.Size([16])
bn1.bias shape: torch.Size([16])
conv2.weight shape: torch.Size([32, 16, 3, 3])
conv2.bias shape: torch.Size([32])
bn2.weight shape: torch.Size([32])
bn2.bias shape: torch.Size([32])
fc.weight shape: torch.Size([10, 32])
fc.bias shape: torch.Size([10])
Total number of trainable parameters: 5514