事前学習済みモデル

事前学習済みモデル#

事前学習(pretraining) はあるタスクに対して事前にモデルを訓練することを示し,その事前学習されたモデルを 事前学習済みモデル(pretrained model) と呼ぶ.

この事前学習モデルの重みを初期値として,別の新規タスクを解けるように再学習することを Fine-tuning と呼び,これはあるタスクから別のタスクへの知識転移であるので転移学習の一種と位置付けられる.

事前学習とFine-tuningの用語はセットで用いられることが多く,画像認識を例にすると,ImageNetなどの大規模データセットを分類損失で事前学習し,その重みを使って別の画像分類タスク(これは小規模データセットである場合が多い)を解くという流れとなる.

PyTorchでの事前学習済みモデルの利用は torchvision または timm を使うと良い.非常に簡単に利用できる.

ここでは torchvision を使って事前学習済みモデルを読み込み,予測に利用してみる.まずは重みとモデル構造を読み込む.

from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

初回実行時は重みをダウンロードするために時間がかかるが,以上の二行で利用は完了である.ただし,モデル構造が異なる場合は読み込めないので注意が必要である.

続いて,入力画像を用意する.

from skimage.data import chelsea
from PIL import Image

img = chelsea()
img = Image.fromarray(img)
img
../../_images/7e7123c661a78b7c204ace4df6b2fc3fff31d1ac4096c1bac97027706aa8bb57.png

続いて,入力画像に前処理を施す.ここで注意であるが,事前学習と同様の前処理を施す必要がある.例えば,入力画像を学習データセットの統計量で標準化している場合,推論時も同じく同様の統計量で標準化する必要がある.

torchvisiontimm では,必要な前処理は次のようにロードできるが,整備されていない事前学習済み重みや自前で事前学習した重みの利用時などは注意されたい.

preprocess = weights.transforms()
preprocess
ImageClassification(
    crop_size=[224]
    resize_size=[232]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)

前処理用の transform が読み込めたので,ミニバッチの次元を追加して前処理を施す

x = preprocess(img).unsqueeze(0)
print('x.shape:', x.shape)
x.shape: torch.Size([1, 3, 224, 224])

重み自体の読み込みは前述のセルですでに完了しているので予測をする.

model.eval()
prediction = model(x).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]

print(f"{category_name}: {100 * score:.1f}%")
Egyptian cat: 34.2%

torchvisionweights.meta["categories"] で学習時のクラスなどのメタデータにアクセスできる.

今回は予測のみを扱ったが,モデルの作成時に重みを指定するだけなので,Fine-tuningも容易に行える.事前学習済みモデルは,データや計算リソースが限られている場合でも高精度なモデルを迅速に構築できる強力な手法であり,一からモデルを設計するよりも,事前学習済みモデルを利用した方が精度が良い場合が多い.