データ拡張

データ拡張#

データ拡張(data augmentation) は,モデルの学習データに対して幾何変化など様々な変換を施すことでデータセットの多様性を増やして,過学習を抑制し,モデルの汎化性能を改善させるテクニックである.特に,画像分野においてはこのデータ拡張は必須のテクニックとなっている.画像に対する一般的なデータ拡張は,

  • 画像を左右または上下にランダムに反転する(RandomVerticalFlip/RandomHorizontalFlip

  • 画像を一定の範囲内でランダムに回転させる(RandomRotation

  • 画像の一部領域をランダムに切り取る(RandomResizedCrop

  • 画像の色合いや明るさなどを変更する(ColorJitter

あたりがよく利用される.

データ拡張は torchvision を使うと容易に利用できる.

import matplotlib.pyplot as plt
import numpy as np
from skimage import data
from torchvision import transforms
from PIL import Image

img = Image.fromarray(data.astronaut())

transforms = {
    'Original': None,
    'Vertical Flip': transforms.RandomVerticalFlip(p=1),
    'Horizontal Flip': transforms.RandomHorizontalFlip(p=1),
    'Rotation': transforms.RandomRotation((30, 60)),
    'Resized Crop': transforms.RandomResizedCrop(256, scale=(0.5, 0.7)),
    'Color Jitter': transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2)
}

transformed_images = {'Original': img}
for name, transform in transforms.items():
    if transform is not None:
        transformed_image = transform(img)
        transformed_images[name] = np.array(transformed_image)

fig, axes = plt.subplots(1, len(transformed_images), figsize=(15, 5))
for ax, (name, img) in zip(axes, transformed_images.items()):
    ax.imshow(img)
    ax.set_title(name)
    ax.axis('off')

plt.tight_layout()
../../_images/af518fa7f150be30080247c12eb3884957d0c0b7ef5bf51d15390fd0aafa5cb2.png

データ拡張は一般的に学習データにのみ適用するので,検証や評価データには含めないようにする.また,文字や数字に対して左右反転させるなどデータ拡張によって画像の意味合いが変わる場合には適用してはいけない.また回転など画像が存在しない領域に対しては0で値が埋められることが多い.

強いデータ拡張を適用すると学習が収束しなくなる場合もあるので,もし新しいデータセットを使って新しいモデルを学習させるときは,まずシンプルなデータ拡張のみを使って学習・検証を始めると良い.