PyTorchの自作モデルをTorch Hubに登録してみる.
hubconf.pyの設置
- 以下のようにhubconf.pyをレポジトリの直下に設置する.
# -*- coding: utf-8 -*- from models import MobileNet_v2 def mobilenet_v2(pretrained=False, *args, **kwargs): model = MobileNet_v2() if pretrained: raise NotImplementedError return model
- 今回は学習済みの重みは作ってないので,pretrained == Trueの時はエラーを出すようにしておく
- 実際に学習済みモデルを持っている時は
if pretrained:
の部分で重みをロードするようにすれば良い
自作モデルの読み込み
- GitHubレポジトリ上にhubconf.pyを設置したら,
torch.hub
モジュールによってモデルのロードができるようになっている
%matplotlib inline import torch.hub as hub hub_model = hub.load( 'nocotan/lightweight_models.torch:master', 'mobilenet_v2', pretrained=False) print(hub_model)
MobileNet_v2( (features): Sequential( (0): ConvBlock( (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvResBlock( (conv): Sequential( (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False) (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace) (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (2): InvResBlock( (conv): Sequential( (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU6(inplace) (3): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False) (4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (5): ReLU6(inplace) (6): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (7): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) ...
- 正常にモデルの読み込みができている
- 試しに
pretrained=True
にしてみると,期待通りにエラーが出る
hub_model = hub.load( 'nocotan/lightweight_models.torch:master', 'mobilenet_v2', pretrained=True) print(hub_model)