PyTorchでTorch Hubを使ったpre-trained modelの共有

Torch Hubは,研究の再現性を高めることを目的とした,PyTorchのpre-trained modelを共有するための機能.

hubconf.py

  • 以下のフォーマットでentrypointを定義する
  • entrypointはPyTorchのモデルを返す関数
def entrypoint_name(pretrained=False, *args, **kwargs):
    # Signature template only: a real entrypoint replaces the placeholder
    # below with code that builds and returns a PyTorch model.
    ...
import torch.utils.model_zoo as model_zoo

# Optional list of dependencies required by the package.
# NOTE(review): the entrypoints in this file import torchvision inside their
# bodies, but it is not named here -- confirm whether it should be listed.
dependencies = ['torch', 'math']


def resnet18(pretrained=False, *args, **kwargs):
    """Build a torchvision ResNet-18, optionally with pretrained weights.

    Args:
        pretrained (bool): when truthy, a state dict is fetched from the
            checkpoint URL with ``model_zoo.load_url`` and loaded into the
            model before it is returned.
        *args, **kwargs: passed through unchanged to torchvision's
            ``resnet18`` constructor.

    Returns:
        The model object produced by torchvision's ``resnet18``.
    """
    # Imported lazily so torchvision is only needed when the entrypoint runs.
    from torchvision.models.resnet import resnet18 as build_resnet18

    weights_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    net = build_resnet18(*args, **kwargs)
    if not pretrained:
        return net

    state = model_zoo.load_url(weights_url, progress=False)
    net.load_state_dict(state)
    return net


def resnet50(pretrained=False, *args, **kwargs):
    """Build a torchvision ResNet-50, optionally with pretrained weights.

    Args:
        pretrained (bool): when truthy, a state dict is fetched from the
            checkpoint URL with ``model_zoo.load_url`` and loaded into the
            model before it is returned.
        *args, **kwargs: passed through unchanged to torchvision's
            ``resnet50`` constructor.

    Returns:
        The model object produced by torchvision's ``resnet50``.
    """
    # Imported lazily so torchvision is only needed when the entrypoint runs.
    from torchvision.models.resnet import resnet50 as build_resnet50

    weights_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    net = build_resnet50(*args, **kwargs)
    if not pretrained:
        return net

    state = model_zoo.load_url(weights_url, progress=False)
    net.load_state_dict(state)
    return net

モデルのロード

  • PyTorch Hubのモデルは以下のようにロードできる
  • 'pytorch/vision:master'は,hubconf.pyが置いてあるGitHubレポジトリを「オーナー名/レポジトリ名:ブランチ名」の形式で指定したもの
# Usage sketch for loading an entrypoint through Torch Hub. `hub` is
# torch.hub (this article's experiment code imports it with
# `import torch.hub as hub`).
hub_model = hub.load(
    'pytorch/vision:master', # GitHub repo holding hubconf.py: repo_owner/repo_name:branch
    'resnet18', # name of the entrypoint function to call
    1234, # positional args for the entrypoint [illustration only; not applicable to resnet]
    pretrained=True) # keyword args for the entrypoint

実験

%matplotlib inline
import pickle
import urllib
from PIL import Image
import torch.hub as hub
from torchvision import transforms

import matplotlib.pyplot as plt

hub_model = hub.load(
    'pytorch/vision:master',
    'resnet18',
    pretrained=True)

labels = pickle.load(urllib.request.urlopen("https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl"))

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
test_image = Image.open("../images/dog.jpg")
test_image_tensor = (transform((test_image))).unsqueeze(dim=0)

plt.imshow(test_image)

pred_idx = hub_model(test_image_tensor).max(1)[1]
print("pred: ", labels[int(pred_idx)])
pred:  guinea pig, Cavia cobaya

[f:id:noconocolib:20190105043454p:plain]