PyTorch Hub (torch.hub) is a PyTorch feature for sharing pre-trained models, intended to make research easier to reproduce.
hubconf.py
- An entrypoint is a function that returns a PyTorch model.
- Entrypoints are defined in hubconf.py using the following format:

```python
def entrypoint_name(pretrained=False, *args, **kwargs):
    ...
```
- The hubconf.py for the ResNet models in pytorch/vision is written as follows:
```python
import torch.utils.model_zoo as model_zoo

# Optional list of dependencies required by the package
dependencies = ['torch', 'math']


def resnet18(pretrained=False, *args, **kwargs):
    """
    Resnet18 model
    pretrained (bool): a recommended kwargs for all entrypoints
    args & kwargs are arguments for the function
    """
    from torchvision.models.resnet import resnet18 as _resnet18
    model = _resnet18(*args, **kwargs)
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    if pretrained:
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model


def resnet50(pretrained=False, *args, **kwargs):
    """
    Resnet50 model
    pretrained (bool): a recommended kwargs for all entrypoints
    args & kwargs are arguments for the function
    """
    from torchvision.models.resnet import resnet50 as _resnet50
    model = _resnet50(*args, **kwargs)
    checkpoint = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    if pretrained:
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model
```
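As a minimal sketch of the same format, the hubconf.py below defines a single entrypoint. Everything in it is hypothetical (the entrypoint tiny_classifier, its num_classes keyword, and WEIGHTS_URL, which points at no real file). For the weights it uses torch.hub.load_state_dict_from_url, of which torch.utils.model_zoo.load_url is an alias in current PyTorch. Because torch.hub.help() returns an entrypoint's docstring, the docstring is a natural place to document the accepted arguments.

```python
# hubconf.py -- a hypothetical example, not taken from any real repository
import torch
from torch import nn

# Packages torch.hub checks are importable before running an entrypoint
dependencies = ['torch']

# Hypothetical checkpoint location; no file is published here
WEIGHTS_URL = 'https://example.com/tiny_classifier.pth'


def tiny_classifier(pretrained=False, *args, **kwargs):
    """Small CNN classifier (hypothetical).

    pretrained (bool): load weights from WEIGHTS_URL
    num_classes (int): number of output classes, default 10
    """
    num_classes = kwargs.pop('num_classes', 10)
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),  # (N, 8, H, W) -> (N, 8, 1, 1)
        nn.Flatten(),             # -> (N, 8)
        nn.Linear(8, num_classes),
    )
    if pretrained:
        # Downloads (and caches) the state_dict, then copies it into the model
        state_dict = torch.hub.load_state_dict_from_url(WEIGHTS_URL, progress=False)
        model.load_state_dict(state_dict)
    return model
```

With pretrained=False (the default) nothing is downloaded, so tiny_classifier(num_classes=5) simply returns a randomly initialized model.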
Loading a model
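- A model published on PyTorch Hub can be loaded as follows. The first argument, 'pytorch/vision:master', names the GitHub repository that contains hubconf.py, in the form repo_owner/repo_name:branch; the second is the entrypoint name, and any remaining positional and keyword arguments (here pretrained=True) are forwarded to the entrypoint.

```python
import torch.hub as hub

hub_model = hub.load(
    'pytorch/vision:master',  # repo_owner/repo_name:branch
    'resnet18',               # entrypoint
    pretrained=True)          # kwargs for the entrypoint
```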
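The same call can point at a local directory instead of GitHub: recent versions of torch.hub.load accept source='local', which is convenient for testing a hubconf.py before publishing it. The sketch below assumes such a version; the hubconf.py contents and the entrypoint name tiny_linear are hypothetical.

```python
import pathlib
import tempfile
import textwrap

import torch

# A hypothetical hubconf.py with a single entrypoint
HUBCONF = textwrap.dedent('''
    import torch.nn as nn

    dependencies = ['torch']

    def tiny_linear(pretrained=False, *args, **kwargs):
        """Hypothetical entrypoint returning a small linear model."""
        out_features = kwargs.pop('out_features', 2)
        return nn.Linear(4, out_features)
''')

with tempfile.TemporaryDirectory() as repo_dir:
    pathlib.Path(repo_dir, 'hubconf.py').write_text(HUBCONF)
    # Same call shape as loading from GitHub: repo, entrypoint, then arguments
    # forwarded to the entrypoint; source='local' reads repo_dir directly.
    model = torch.hub.load(repo_dir, 'tiny_linear', source='local', out_features=3)

print(model)
```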
Experiment
```python
%matplotlib inline
import pickle
import urllib.request

from PIL import Image
import torch.hub as hub
from torchvision import transforms
import matplotlib.pyplot as plt

# Load ResNet-18 with pre-trained ImageNet weights via the pytorch/vision hubconf.py
hub_model = hub.load(
    'pytorch/vision:master',
    'resnet18',
    pretrained=True)

# Mapping from ImageNet class index to a human-readable label
labels = pickle.load(urllib.request.urlopen("https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl"))

# Resize to 224x224 and convert to a tensor with values in [0, 1]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

test_image = Image.open("../images/dog.jpg")
# Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
test_image_tensor = transform(test_image).unsqueeze(dim=0)
plt.imshow(test_image)

# Index of the highest-scoring class
pred_idx = hub_model(test_image_tensor).max(1)[1]
print("pred: ", labels[int(pred_idx)])
```
pred: guinea pig, Cavia cobaya