PyTorchでTorch Hubに自作モデルの登録

PyTorchの自作モデルをTorch Hubに登録してみる.

hubconf.pyの設置

  • 以下のようにhubconf.pyをレポジトリの直下に設置する.
# -*- coding: utf-8 -*-
from models import MobileNet_v2


def mobilenet_v2(pretrained=False, *args, **kwargs):
    model = MobileNet_v2()

    if pretrained:
        raise NotImplementedError

    return model
  • 今回は学習済みの重みは作ってないので,pretrained == Trueの時はエラーを出すようにしておく
  • 実際に学習済みモデルを持っている時はif pretrained:の部分で重みをロードするようにすれば良い

自作モデルの読み込み

  • GitHubレポジトリ上にhubconf.pyを設置したら,torch.hubモジュールによってモデルのロードができるようになっている
%matplotlib inline
import torch.hub as hub

hub_model = hub.load(
    'nocotan/lightweight_models.torch:master',
    'mobilenet_v2',
    pretrained=False)

print(hub_model)
MobileNet_v2(
  (features): Sequential(
    (0): ConvBlock(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): InvResBlock(
      (conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace)
        (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvResBlock(
      (conv): Sequential(
        (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace)
        (3): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
        (4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU6(inplace)
        (6): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (7): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

...
  • 正常にモデルの読み込みができている
  • 試しにpretrained=Trueにしてみると,期待通りにエラーが出る
hub_model = hub.load(
    'nocotan/lightweight_models.torch:master',
    'mobilenet_v2',
    pretrained=True)

print(hub_model)