Conditional GANs in PyTorch

In this post we implement Conditional GANs in PyTorch, which inject label information explicitly into the GAN training framework so that the generator can be told which class of image to produce.

  • Conditional Generative Adversarial Nets (Mirza & Osindero, 2014)
  • Paper: https://arxiv.org/abs/1411.1784
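
For reference, the objective optimized in the paper is the standard GAN minimax game with both players conditioned on the extra information y (here, the class label):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\left[\log D(x \mid y)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]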

Importing modules

  • Import the required libraries up front.
  • Note that in a Jupyter notebook, using from tqdm import tqdm_notebook as tqdm makes the progress bar render correctly.
%matplotlib inline
from tqdm import tqdm_notebook as tqdm
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

データセット

  • We use FashionMNIST for this experiment. The images are normalized to [-1, 1] so that they match the Tanh output range of the generator defined below.
transform = transforms.Compose([
        transforms.ToTensor(),
        # FashionMNIST is grayscale (1 channel), so mean/std take one value each;
        # this maps pixels to [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])

trainset = FashionMNIST(download=True, train=True, transform=transform, root="/tmp/data")
testset = FashionMNIST(download=True, train=False, transform=transform, root="/tmp/data")

train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Processing...
Done!
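
As a quick sanity check (a minimal sketch, not part of the original training code), we can pull one batch and confirm the shapes and value range:

images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([64, 1, 28, 28])
print(labels[:8])     # class indices in 0..9
print(images.min().item(), images.max().item())  # approximately -1.0 1.0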

Defining the networks

  • Define the Generator, which produces images, and the Discriminator, which classifies images as real or fake. Both receive the class label through a learned 10-dimensional embedding concatenated to their main input.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Embed the 10 class labels into 10-dimensional vectors
        self.label_emb = nn.Embedding(10, 10)
        
        self.model = nn.Sequential(
            nn.Linear(794, 1024),  # 784 image pixels + 10 label dims
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability that the input image is real
        )
    
    def forward(self, x, labels):
        x = x.view(x.size(0), 784)   # flatten the image
        c = self.label_emb(labels)   # look up the label embedding
        x = torch.cat([x, c], 1)     # condition on the label
        out = self.model(x)
        return out.squeeze()

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.label_emb = nn.Embedding(10, 10)
        
        self.model = nn.Sequential(
            nn.Linear(110, 256),  # 100-dim noise + 10 label dims
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized data
        )
    
    def forward(self, z, labels):
        z = z.view(z.size(0), 100)
        c = self.label_emb(labels)
        x = torch.cat([z, c], 1)   # condition the noise on the label
        out = self.model(x)
        return out.view(z.size(0), 28, 28)
generator = Generator()
discriminator = Discriminator()
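
A minimal smoke test (not in the original post) confirms that both forward passes produce the expected shapes:

z = torch.randn(4, 100)
labels = torch.LongTensor([0, 1, 2, 3])
fake = generator(z, labels)
print(fake.shape)                         # torch.Size([4, 28, 28])
print(discriminator(fake, labels).shape)  # torch.Size([4])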

Defining the loss function and optimizers

  • Binary cross entropy is used as the loss function, and Adam as the optimizer for both networks.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
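
BCELoss expects probabilities in [0, 1] (hence the Sigmoid at the end of the discriminator) and targets of 1 for "real" and 0 for "fake". A tiny worked example:

p = torch.tensor([0.9, 0.1])  # discriminator outputs
t = torch.tensor([1.0, 0.0])  # targets: real, fake
print(nn.BCELoss()(p, t))     # -(log 0.9 + log 0.9) / 2 ≈ 0.1054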

Training framework

  • Define one parameter-update step for each network: the generator is trained so that the discriminator assigns the label "real" (1) to its fake images, while the discriminator is trained to output 1 for real images and 0 for generated ones.

def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion):
    g_optimizer.zero_grad()
    z = torch.randn(batch_size, 100)  # noise input
    fake_labels = torch.LongTensor(np.random.randint(0, 10, batch_size))
    fake_images = generator(z, fake_labels)
    validity = discriminator(fake_images, fake_labels)
    # Train G to make D output 1 ("real") for generated images
    g_loss = criterion(validity, torch.ones(batch_size))
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels):
    d_optimizer.zero_grad()

    # Real images should be classified as 1
    real_validity = discriminator(real_images, labels)
    real_loss = criterion(real_validity, torch.ones(batch_size))
    
    # Generated images should be classified as 0; detach so that no
    # gradients are propagated into the generator during the D update
    z = torch.randn(batch_size, 100)
    fake_labels = torch.LongTensor(np.random.randint(0, 10, batch_size))
    fake_images = generator(z, fake_labels)
    fake_validity = discriminator(fake_images.detach(), fake_labels)
    fake_loss = criterion(fake_validity, torch.zeros(batch_size))
    
    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()
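
The code above runs on the CPU. A minimal sketch of what changes for GPU training, assuming a CUDA device is available (every tensor created inside the step functions would need the same treatment):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)
# e.g. inside the step functions:
# z = torch.randn(batch_size, 100, device=device)
# real_images, labels = real_images.to(device), labels.to(device)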

Training

num_epochs = 30

for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch + 1))
    for i, (real_images, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
        generator.train()
        batch_size = real_images.size(0)
        # Update D first, then G, once per mini-batch
        d_loss = discriminator_train_step(batch_size, discriminator,
                                          generator, d_optimizer, criterion,
                                          real_images, labels)
        
        g_loss = generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion)

    # At the end of each epoch, sample one image for each of classes 0-8
    generator.eval()
    print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
    z = torch.randn(9, 100)
    labels = torch.LongTensor(np.arange(9))
    sample_images = generator(z, labels).unsqueeze(1).data.cpu()
    grid = make_grid(sample_images, nrow=3, normalize=True).permute(1, 2, 0).numpy()
    plt.imshow(grid)
    plt.show()
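
After (or during) training you will usually want to persist the weights; a minimal sketch, with an arbitrary path:

torch.save(generator.state_dict(), '/tmp/cgan_generator.pt')
# Later: generator.load_state_dict(torch.load('/tmp/cgan_generator.pt'))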

[Figure: generated samples at epoch 1] g_loss: 2.1179, d_loss: 0.4485

[Figure: generated samples at epoch 4] g_loss: 2.1488, d_loss: 0.5789

[Figure: generated samples at epoch 7] g_loss: 1.778, d_loss: 0.9159

Generated images after training

generator.eval()
print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
z = torch.randn(40, 100)
labels = torch.LongTensor(np.random.randint(0, 10, 40))
sample_images = generator(z, labels).unsqueeze(1).data.cpu()
grid = make_grid(sample_images, nrow=8, normalize=True).permute(1,2,0).numpy()

plt.figure(figsize=(8, 10))
plt.imshow(grid)
plt.show()

[Figure: 40 generated samples with random labels] g_loss: 1.0171, d_loss: 1.2502
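
Because the generator is label-conditioned, we can also request specific classes. A minimal sketch that renders eight samples for each of the ten FashionMNIST classes (class 0 = T-shirt/top through class 9 = Ankle boot):

generator.eval()
with torch.no_grad():
    z = torch.randn(80, 100)
    labels = torch.LongTensor(np.repeat(np.arange(10), 8))  # 0,0,...,9,9
    sample_images = generator(z, labels).unsqueeze(1).cpu()
grid = make_grid(sample_images, nrow=8, normalize=True).permute(1, 2, 0).numpy()
plt.figure(figsize=(8, 10))
plt.imshow(grid)  # each row of the grid shows one class
plt.show()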