


Pyro is a flexible, scalable deep probabilistic programming library built on PyTorch. Notably, it was designed with these principles in mind:

  • Universal: Pyro is a universal PPL -- it can represent any computable probability distribution.
  • Scalable: Pyro scales to large data sets with little overhead compared to hand-written code.
  • Minimal: Pyro is agile and maintainable. It is implemented with a small core of powerful, composable abstractions.
  • Flexible: Pyro aims for automation when you want it, control when you need it. This is accomplished through high-level abstractions to express generative and inference models, while allowing experts easy access to customize inference.



  • 各ライブラリをimportしておく
%matplotlib inline

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms

import pyro
from pyro.distributions import Normal
from pyro.distributions import Categorical
from pyro.optim import Adam
from pyro.infer import SVI
from pyro.infer import Trace_ELBO

import matplotlib.pyplot as plt
import seaborn as sns


  • 今回はMNISTを実験に用いる
# MNIST images are converted to float tensors in [0, 1] by ToTensor and
# served in mini-batches of 128.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("/tmp/mnist", train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor(),])),
    batch_size=128, shuffle=True)

# Evaluation order does not matter, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST("/tmp/mnist", train=False, download=True,
                   transform=transforms.Compose([transforms.ToTensor(),])),
    batch_size=128, shuffle=False)


  • 一般的なPyTorchの作法でモデルの定義を行う
class BNN(nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear.

    The sub-module names ``fc1`` and ``out`` matter: the prior/guide
    dictionaries in this notebook are keyed by ``'fc1.weight'``,
    ``'fc1.bias'``, ``'out.weight'`` and ``'out.bias'``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return raw (unnormalized) class scores for ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.out(hidden)


# 28*28 flattened pixels -> 1024 hidden units -> 10 output scores.
net = BNN(28 * 28, 1024, 10)


  • ネットワーク内のすべての重みとバイアスに事前分布（確率分布）を割り当てることで、ネットワークのパラメータを確率変数に変換する
def model(x_data, y_data):
    """Pyro model: the network ``net`` with a prior over all its parameters.

    Every parameter tensor of ``net`` gets an elementwise standard-normal
    prior N(0, 1). ``pyro.random_module`` lifts ``net`` into a distribution
    over networks. One network is sampled, and the observed labels are scored
    under a categorical likelihood on its outputs.

    Args:
        x_data: input batch; assumed flattened images of shape
            (batch, 28*28) -- TODO confirm against callers.
        y_data: observed class labels for the batch.
    """
    # For BNN the keys are 'fc1.weight', 'fc1.bias', 'out.weight' and
    # 'out.bias'; deriving them from named_parameters() keeps the priors in
    # sync with the architecture.
    priors = {
        name: Normal(loc=torch.zeros_like(param), scale=torch.ones_like(param))
        for name, param in net.named_parameters()
    }
    lifted_module = pyro.random_module("module", net, priors)
    lifted_reg_model = lifted_module()
    # Normalize over the class dimension of each example (implicit dim is
    # deprecated).
    lhat = F.log_softmax(lifted_reg_model(x_data), dim=1)
    pyro.sample("obs", Categorical(logits=lhat), obs=y_data)

guide functionの定義

  • 事後分布を近似する変分分布（guide）を定義しておく。pyro.paramで登録した各パラメータの平均と標準偏差が最適化の対象となる
def _variational_normal(name, like):
    """Return a learnable elementwise Normal over a tensor shaped like ``like``.

    Registers the Pyro params ``<name>_mu`` and ``<name>_sigma``. Both are
    initialized from a standard normal, and the raw sigma goes through
    softplus so that the scale is positive.
    """
    mu = pyro.param(name + "_mu", torch.randn_like(like))
    sigma = F.softplus(pyro.param(name + "_sigma", torch.randn_like(like)))
    return Normal(loc=mu, scale=sigma)


def guide(x_data, y_data):
    """Variational guide: a mean-field Normal approximation of the posterior.

    Each weight/bias tensor of ``net`` gets its own learnable mean and scale.
    The returned object is one network sampled from that approximation.
    ``x_data`` and ``y_data`` are unused; they only match ``model``'s
    signature, as SVI requires.

    Returns:
        A network sampled via ``pyro.random_module`` from the variational
        distributions.
    """
    # Every site is a plain elementwise Normal, matching the priors in
    # ``model``. The original wrapped only 'out.weight' in .independent(1),
    # which made that site's batch/event shape disagree with the model.
    # Trace_ELBO sums log-probs over all dims, so the loss value is unchanged.
    variational = {
        'fc1.weight': _variational_normal("fc1w", net.fc1.weight),
        'fc1.bias': _variational_normal("fc1b", net.fc1.bias),
        'out.weight': _variational_normal("outw", net.out.weight),
        'out.bias': _variational_normal("outb", net.out.bias),
    }
    lifted_module = pyro.random_module("module", net, variational)
    return lifted_module()


  • optimizerはAdamを用いる
  • SVIはPyroにおける確率的変分推論のための統一的インターフェースとなる
# Adam with learning rate 0.01 drives stochastic variational inference,
# which maximizes the ELBO of `model` w.r.t. the params in `guide`.
optim = Adam(dict(lr=0.01))
svi = SVI(model=model, guide=guide, optim=optim, loss=Trace_ELBO())

Bayesian Neural Networkの学習

  • ネットワークの学習を行う
n_iterations = 5
loss = 0

for j in range(n_iterations):
    # One epoch: add up the loss estimate returned by every SVI step.
    loss = sum(
        svi.step(images.view(-1, 28 * 28), labels)
        for images, labels in train_loader
    )

    # Report the epoch loss per training example.
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    print("Epoch ", j, " Loss ", total_epoch_loss_train)
Epoch  0  Loss  2088.199460048882
Epoch  1  Loss  368.3325182894071
Epoch  2  Loss  157.11579643979073
Epoch  3  Loss  110.36873698452314
Epoch  4  Loss  95.78892364851633

Bayesian Neural Networkによる推論

  • Bayesian Neural Networkでは,モデルを複数サンプリングして推論を行う.
n_samples = 10


def predict(x):
    """Predict class labels by Bayesian model averaging.

    Draws ``n_samples`` networks from the variational posterior (``guide``),
    averages their softmax class probabilities, and returns the arg-max class.

    Args:
        x: input batch; assumed shape (batch, 28*28) -- TODO confirm.

    Returns:
        A NumPy array holding one predicted class index per row of ``x``.
    """
    with torch.no_grad():  # inference only: don't build autograd graphs
        sampled_nets = [guide(None, None) for _ in range(n_samples)]
        # Average probabilities, not raw logits: the posterior predictive is
        # the mean of p(y | x, w) over the sampled weights w.
        probs = [F.softmax(sampled(x), dim=1) for sampled in sampled_nets]
        mean = torch.mean(torch.stack(probs), 0)
    return np.argmax(mean.numpy(), axis=1)

print('Prediction when network is forced to predict')
correct = 0
total = 0
# Accuracy over the whole test set, accumulated batch by batch.
for j, (images, labels) in enumerate(test_loader):
    predicted = predict(images.view(-1, 28 * 28))
    total += len(labels)
    correct += (predicted == labels).sum().item()
print("accuracy: %d %%" % (100 * correct / total))
  • また,Bayesian Neural Networkの推論結果は確率分布として扱うことができる

  • 確率的推論のための関数を定義

def predict_prob(x):
    """Return the network outputs for ``x`` averaged over posterior samples.

    Draws ``n_samples`` networks from ``guide`` and averages their outputs
    elementwise. NOTE: the result is averaged raw outputs (logits), not
    normalized probabilities; apply a softmax if probabilities are needed.

    Args:
        x: input batch; assumed shape (batch, 28*28) -- TODO confirm.

    Returns:
        A detached tensor of shape (batch, num_classes).
    """
    with torch.no_grad():  # inference only: don't build autograd graphs
        sampled_nets = [guide(None, None) for _ in range(n_samples)]
        outputs = [sampled(x) for sampled in sampled_nets]
        return torch.mean(torch.stack(outputs), 0)

def normalize(x):
    """Min-max scale ``x`` into the range [0, 1].

    Works for NumPy arrays and torch tensors (anything with ``min``/``max``
    and elementwise arithmetic). A constant input has zero range and would
    otherwise divide by zero (giving NaN); it is mapped to all zeros instead.
    """
    low = x.min()
    span = x.max() - low
    if span == 0:
        return x - low
    return (x - low) / span
  • 可視化のための関数を定義
def plot(x, yhats):
    """Show class scores next to the input image.

    Left panel: bar chart over the 10 classes of
    ``softmax(normalize(yhats))`` for the first row of ``yhats``. The
    min-max scaling means the softmax sees values in [0, 1].
    Right panel: the input image ``x``.

    Args:
        x: a single image tensor; assumed (1, 28, 28) as produced by
            ToTensor on MNIST -- TODO confirm.
        yhats: tensor of class scores, e.g. shape (1, 10) as returned by
            ``predict_prob`` for one image.
    """
    fig, (axL, axR) = plt.subplots(ncols=2, figsize=(8, 4))
    classes = list(range(10))
    scores = F.softmax(torch.Tensor(normalize(yhats.numpy()))[0], dim=0)
    axL.bar(x=classes, height=scores.numpy())
    # set_xticks(ticks, labels) only exists in matplotlib >= 3.5; in older
    # versions the 2nd positional argument is `minor`. Ticks 0..9 already
    # label themselves, so pass just the ticks.
    axL.set_xticks(classes)
    axR.imshow(x.squeeze().numpy(), cmap="gray")


  • 正解ラベルである"7"の予測確率が最も高くなっている
x, y = test_loader.dataset[0]
yhats = predict_prob(x.view(-1, 28*28))
# int() works whether the dataset yields a plain int label or a 0-d tensor;
# .item() only exists on tensors.
print("ground truth: ", int(y))
print("predicted: ", yhats.numpy())
plot(x, yhats)
ground truth:  7
predicted:  [[  39.87391   -86.72453   112.25181   102.81404   -89.95473    52.947186
  -167.88455   402.66757   -29.799454   95.75331 ]]


  • こちらも,正解ラベルである"0"の予測確率が最も高くなっている
x, y = test_loader.dataset[3]
yhats = predict_prob(x.view(-1, 28*28))
# int() works whether the dataset yields a plain int label or a 0-d tensor;
# .item() only exists on tensors.
print("ground truth: ", int(y))
print("predicted: ", yhats.numpy())
plot(x, yhats)
ground truth:  0
predicted:  [[ 344.0459   -145.05759    41.21102   -51.52296  -123.594765   14.442251
   101.51439    30.152363  -49.83236   -46.353443]]


  • 正解ラベルである"9"と同程度に"7"の確率も高くなっている
  • 今回は学習イテレーション数がかなり少ないため,もっと学習を回すことでより良い結果が得られると思われる
x, y = test_loader.dataset[20]
yhats = predict_prob(x.view(-1, 28*28))
# int() works whether the dataset yields a plain int label or a 0-d tensor;
# .item() only exists on tensors.
print("ground truth: ", int(y))
print("predicted: ", yhats.numpy())
plot(x, yhats)
ground truth:  9
predicted:  [[ -82.09069   -80.67144   -44.980133   96.68532    72.79546    44.7148
  -148.41635   182.37096    57.05369   231.20334 ]]
