This post tries out a Bayesian neural network using PyTorch and Pyro, a probabilistic programming framework.
Pyro
Pyro is a flexible, scalable deep probabilistic programming library built on PyTorch. Notably, it was designed with these principles in mind:
- Universal: Pyro is a universal PPL -- it can represent any computable probability distribution.
- Scalable: Pyro scales to large data sets with little overhead compared to hand-written code.
- Minimal: Pyro is agile and maintainable. It is implemented with a small core of powerful, composable abstractions.
- Flexible: Pyro aims for automation when you want it, control when you need it. This is accomplished through high-level abstractions to express generative and inference models, while allowing experts easy-access to customize inference.
Following the steps below, a neural network written in PyTorch can be treated as a Bayesian neural network:
Importing the modules
- Import the required libraries
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
import pyro
from pyro.distributions import Normal
from pyro.distributions import Categorical
from pyro.optim import Adam
from pyro.infer import SVI
from pyro.infer import Trace_ELBO
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
Preparing the dataset
- MNIST is used for this experiment
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("/tmp/mnist", train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("/tmp/mnist", train=False,
                   transform=transforms.Compose([transforms.ToTensor()])),
    # The original listing is cut off after the transform; batch_size=128 is
    # an assumption that mirrors the training loader.
    batch_size=128)
Defining the model
- Define the model in the usual PyTorch way
class BNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        output = self.fc1(x)
        output = F.relu(output)
        output = self.out(output)
        return output

net = BNN(28*28, 1024, 10)
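To get a feel for the size of this model, the following dependency-free sketch counts the parameters of the two `nn.Linear` layers in `BNN(28*28, 1024, 10)`. The helper name `linear_param_count` is made up for illustration, and the doubling assumes a guide that keeps one mean and one scale per parameter, as the guide defined later in this post does.

```python
def linear_param_count(in_features, out_features):
    """Weights plus biases of one fully connected layer (hypothetical helper)."""
    return in_features * out_features + out_features


# Layer sizes taken from BNN(28*28, 1024, 10).
fc1 = linear_param_count(28 * 28, 1024)
out = linear_param_count(1024, 10)
n_params = fc1 + out

# Assumption: the guide keeps one mean and one scale for every parameter.
n_variational = 2 * n_params

print(n_params, n_variational)
```

Under these assumptions the Bayesian version optimizes roughly twice as many numbers as the plain network, which is worth keeping in mind when choosing the hidden size.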
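Assigning probability distributions to the weights and biases
- Assigning a probability distribution (a prior) to every weight and bias in the network turns those parameters into random variables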
def model(x_data, y_data):
    # Define the prior distributions: a standard Normal for every weight and bias
    fc1w_prior = Normal(loc=torch.zeros_like(net.fc1.weight),
                        scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias),
                        scale=torch.ones_like(net.fc1.bias))
    outw_prior = Normal(loc=torch.zeros_like(net.out.weight),
                        scale=torch.ones_like(net.out.weight))
    outb_prior = Normal(loc=torch.zeros_like(net.out.bias),
                        scale=torch.ones_like(net.out.bias))
    priors = {
        'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
        'out.weight': outw_prior, 'out.bias': outb_prior}

    # Lift the module so that its parameters are drawn from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # Sample one concrete network
    lifted_reg_model = lifted_module()

    # Log-probabilities over the 10 classes (dim=1 is the class dimension)
    lhat = F.log_softmax(lifted_reg_model(x_data), dim=1)
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)
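What placing a standard Normal prior on a layer means can be shown without Pyro. The sketch below uses only the standard library; `sample_layer` and `normal_log_prob` are hypothetical helpers, and the 2x3 layer is a toy stand-in for `fc1`. It draws one weight matrix from the prior and sums the log-density over its entries.

```python
import math
import random


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def sample_layer(n_out, n_in, rng):
    """Draw an (n_out x n_in) weight matrix from the standard Normal prior."""
    return [[rng.gauss(0.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]


rng = random.Random(0)             # fixed seed so the sketch is repeatable
weights = sample_layer(2, 3, rng)  # toy layer, far smaller than fc1

# With independent entries, the layer's prior log-probability is a sum.
log_prior = sum(normal_log_prob(w) for row in weights for w in row)
print(weights)
print(log_prior)
```

Every call to `sample_layer` gives a different network, which is the sense in which the weights have become random variables.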
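Defining the guide function
- Define the guide, which sets up (initializes) the variational distribution that is optimized to approximate the posterior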
def guide(x_data, y_data):
    # The *_prior names are kept from model(), but here each distribution is
    # the learnable approximate posterior for that parameter.
    # softplus keeps every scale positive.

    # First layer: weights
    fc1w_mu = torch.randn_like(net.fc1.weight)
    fc1w_sigma = torch.randn_like(net.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = F.softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_prior = Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)

    # First layer: biases
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = F.softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

    # Output layer: weights
    outw_mu = torch.randn_like(net.out.weight)
    outw_sigma = torch.randn_like(net.out.weight)
    outw_mu_param = pyro.param("outw_mu", outw_mu)
    outw_sigma_param = F.softplus(pyro.param("outw_sigma", outw_sigma))
    # .independent(1) is kept from the original; newer Pyro spells it .to_event(1)
    outw_prior = Normal(loc=outw_mu_param, scale=outw_sigma_param).independent(1)

    # Output layer: biases
    outb_mu = torch.randn_like(net.out.bias)
    outb_sigma = torch.randn_like(net.out.bias)
    outb_mu_param = pyro.param("outb_mu", outb_mu)
    outb_sigma_param = F.softplus(pyro.param("outb_sigma", outb_sigma))
    outb_prior = Normal(loc=outb_mu_param, scale=outb_sigma_param)

    priors = {
        'fc1.weight': fc1w_prior, 'fc1.bias': fc1b_prior,
        'out.weight': outw_prior, 'out.bias': outb_prior}
    lifted_module = pyro.random_module("module", net, priors)
    # Return one network sampled from the approximate posterior
    return lifted_module()
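The guide passes every `*_sigma` parameter through `F.softplus` before using it as a scale. A small standard-library sketch of why: softplus maps any real number, including the negative values that `randn_like` produces, to a strictly positive one. The overflow-safe formula below is one common way to write it and is not taken from the PyTorch source.

```python
import math


def softplus(x):
    """log(1 + exp(x)), written so that large |x| does not overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Unconstrained values, like the randn-initialized *_sigma parameters.
raw_values = [-5.0, -1.0, 0.0, 1.0, 5.0]
scales = [softplus(v) for v in raw_values]
print(scales)
```

Because the optimizer works on the unconstrained value, it can move freely while the scale handed to `Normal` stays valid.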
Defining the optimizer
- Adam is used as the optimizer
- SVI is Pyro's unified interface for stochastic variational inference
optim = Adam({"lr": 0.01})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
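`Trace_ELBO` estimates the evidence lower bound (ELBO) from samples drawn from the guide, and SVI minimizes its negative. The sketch below illustrates that quantity on a one-dimensional toy model chosen here for convenience (w ~ N(0, 1), y | w ~ N(w, 1)); it is not the network above and does not reproduce Pyro's implementation. For this toy model the exact posterior is known, so the ELBO can be checked against the log evidence.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def elbo_estimate(y, q_loc, q_scale, n_samples, rng):
    """Monte Carlo ELBO for the toy model w ~ N(0, 1), y | w ~ N(w, 1)."""
    total = 0.0
    for _ in range(n_samples):
        w = rng.gauss(q_loc, q_scale)  # sample from the guide q(w)
        log_joint = normal_log_prob(w, 0.0, 1.0) + normal_log_prob(y, w, 1.0)
        total += log_joint - normal_log_prob(w, q_loc, q_scale)
    return total / n_samples


rng = random.Random(0)
y = 1.0
# For this toy model the exact posterior is N(y / 2, sqrt(1 / 2)).
elbo_exact = elbo_estimate(y, y / 2.0, math.sqrt(0.5), 2000, rng)
# A deliberately poor guide, far from the posterior.
elbo_poor = elbo_estimate(y, 3.0, 1.0, 2000, rng)
# The marginal of y is N(0, sqrt(2)).
log_evidence = normal_log_prob(y, 0.0, math.sqrt(2.0))
print(elbo_exact, elbo_poor, log_evidence)
```

When the guide equals the posterior, the ELBO matches the log evidence; a worse guide gives a lower value, which is the gap the optimizer tries to close.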
Training the Bayesian neural network
- Train the network
n_iterations = 5  # number of passes (epochs) over the training set
for j in range(n_iterations):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        # Flatten each 28x28 image to 784 values and take one SVI step
        loss += svi.step(data[0].view(-1, 28*28), data[1])
    # Report the epoch loss per training example
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    print("Epoch ", j, " Loss ", total_epoch_loss_train)
Epoch 0 Loss 2088.199460048882
Epoch 1 Loss 368.3325182894071
Epoch 2 Loss 157.11579643979073
Epoch 3 Loss 110.36873698452314
Epoch 4 Loss 95.78892364851633
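The reported number is the sum of the per-batch losses returned by `svi.step`, divided by the size of the training set. The arithmetic can be sketched as follows; the helper names and the per-batch loss values are hypothetical and only show the bookkeeping, assuming the final, smaller batch is kept (the `DataLoader` default).

```python
import math


def batches_per_epoch(n_examples, batch_size):
    """Number of mini-batches when the last, smaller batch is kept."""
    return math.ceil(n_examples / batch_size)


def per_example_loss(batch_losses, n_examples):
    """Sum the per-batch losses and divide by the dataset size."""
    return sum(batch_losses) / n_examples


n_train = 60000  # size of the MNIST training split
n_batches = batches_per_epoch(n_train, 128)

# Hypothetical per-batch losses, all equal, only to show the arithmetic.
fake_losses = [12800.0] * n_batches
print(n_batches, per_example_loss(fake_losses, n_train))
```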
Inference with the Bayesian neural network
- In a Bayesian neural network, inference is performed by sampling multiple models.
n_samples = 10

def predict(x):
    # Each call to guide() samples one network from the approximate posterior
    sampled_models = [guide(None, None) for _ in range(n_samples)]
    yhats = [sampled(x).data for sampled in sampled_models]
    # Average the outputs of the sampled networks, then pick the best class
    mean = torch.mean(torch.stack(yhats), 0)
    return np.argmax(mean.numpy(), axis=1)

print('Prediction when network is forced to predict')
correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images.view(-1, 28*28))
    total += labels.size(0)
    # predicted is a NumPy array, so compare it with NumPy labels
    correct += (predicted == labels.numpy()).sum().item()
print("accuracy: %d %%" % (100 * correct / total))
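The averaging step in `predict` can be seen on a tiny example. The sketch below uses plain Python lists and made-up outputs from three sampled networks for a single image; `mean_logits` and `argmax` are hypothetical helpers, not Pyro or PyTorch functions.

```python
def mean_logits(sampled_logits):
    """Average the output vectors produced by several sampled networks."""
    n = len(sampled_logits)
    return [sum(col) / n for col in zip(*sampled_logits)]


def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=lambda i: values[i])


# Hypothetical outputs from three sampled networks: one image, three classes.
samples = [
    [2.0, 0.5, -1.0],
    [0.2, 1.5, -0.5],
    [1.8, 0.1, 0.3],
]
avg = mean_logits(samples)
votes = [argmax(s) for s in samples]
print(avg, argmax(avg), votes)
```

The individual networks can disagree (`votes`), while the averaged output still yields a single answer. That disagreement is the information a single deterministic network does not provide.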
The prediction of a Bayesian neural network can also be treated as a probability distribution
- Define a function for probabilistic inference
def predict_prob(x):
    sampled_models = [guide(None, None) for _ in range(n_samples)]
    yhats = [sampled(x).data for sampled in sampled_models]
    # Returns the averaged raw network outputs (one row per input)
    mean = torch.mean(torch.stack(yhats), 0)
    return mean

def normalize(x):
    # Min-max scaling to the range [0, 1]
    return (x - x.min()) / (x.max() - x.min())
- Define a function for visualization
def plot(x, yhats):
    fig, (axL, axR) = plt.subplots(ncols=2, figsize=(8, 4))
    # Left: softmax over the min-max normalized outputs of the first input
    probs = F.softmax(torch.Tensor(normalize(yhats.numpy()))[0], dim=0)
    axL.bar(x=list(range(10)), height=probs.numpy())
    axL.set_xticks(list(range(10)))
    # Right: the input image
    axR.imshow(x.numpy()[0])
    plt.show()
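`plot` applies `normalize` before the softmax, and that choice changes how the bars look. The standard-library sketch below compares a softmax over the raw averaged outputs with a softmax over their min-max normalized version, using the values printed later in this post for the first test image. The `softmax` helper is written here for illustration.

```python
import math


def softmax(values):
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def normalize(values):
    """Min-max scaling to [0, 1], mirroring normalize() in this post."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo) for v in values]


# Averaged outputs printed in this post for the first test image (label 7).
logits = [39.87391, -86.72453, 112.25181, 102.81404, -89.95473,
          52.947186, -167.88455, 402.66757, -29.799454, 95.75331]

raw = softmax(logits)
squashed = softmax(normalize(logits))
print(max(raw), max(squashed))
```

Both versions rank class 7 first, but the normalized one is much flatter, so the bar heights are best read as a relative ranking and not as calibrated probabilities.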
Inference and visualization
- The predicted probability is highest for the correct label, "7"
x, y = test_loader.dataset[0]
yhats = predict_prob(x.view(-1, 28*28))
# int(y) works whether the label is a tensor or a plain int
print("ground truth: ", int(y))
print("predicted: ", yhats.numpy())
plot(x, yhats)
ground truth: 7
predicted: [[ 39.87391 -86.72453 112.25181 102.81404 -89.95473 52.947186 -167.88455 402.66757 -29.799454 95.75331 ]]
- Here too, the predicted probability is highest for the correct label, "0"
x, y = test_loader.dataset[3]
yhats = predict_prob(x.view(-1, 28*28))
# int(y) works whether the label is a tensor or a plain int
print("ground truth: ", int(y))
print("predicted: ", yhats.numpy())
plot(x, yhats)
ground truth: 0
predicted: [[ 344.0459 -145.05759 41.21102 -51.52296 -123.594765 14.442251 101.51439 30.152363 -49.83236 -46.353443]]
- The probability for "7" is about as high as that for the correct label, "9"
- Only a few training iterations were run here, so training for longer would likely give better results
x, y = test_loader.dataset[20]
yhats = predict_prob(x.view(-1, 28*28))
# int(y) works whether the label is a tensor or a plain int
print("ground truth: ", int(y))
print("predicted: ", yhats.numpy())
plot(x, yhats)
ground truth: 9
predicted: [[ -82.09069 -80.67144 -44.980133 96.68532 72.79546 44.7148 -148.41635 182.37096 57.05369 231.20334 ]]
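One simple way to put a number on how close the runner-up is would be the margin between the two largest averaged outputs. The sketch below applies this to the values printed above for test images 0 and 20. The `top_two` helper and the idea of using the margin as an ambiguity signal are assumptions made for illustration, not part of the original experiment.

```python
def top_two(values):
    """Indices of the largest and second-largest entries."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    return order[0], order[1]


# Averaged outputs printed in this post for test images 0 and 20.
outputs_7 = [39.87391, -86.72453, 112.25181, 102.81404, -89.95473,
             52.947186, -167.88455, 402.66757, -29.799454, 95.75331]
outputs_9 = [-82.09069, -80.67144, -44.980133, 96.68532, 72.79546,
             44.7148, -148.41635, 182.37096, 57.05369, 231.20334]

margins = {}
for name, outputs in [("image 0", outputs_7), ("image 20", outputs_9)]:
    best, runner_up = top_two(outputs)
    margins[name] = outputs[best] - outputs[runner_up]
    print(name, best, runner_up, round(margins[name], 2))
```

For image 0 the winning class leads by a wide margin, while for image 20 the classes "9" and "7" are much closer, which matches what the bar plots suggest.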