GANsの学習フレームワークの中に明示的にラベル情報を入れるConditional GANsをPyTorchで実装してみる.
- Conditional Generative Adversarial Nets
- paper link: https://arxiv.org/abs/1411.1784
モジュールのインポート
- 必要なライブラリをimportしておく
- ちなみにjupyter notebookの場合は
from tqdm import tqdm_notebook as tqdm
しておくと正常にプログレスバーが表示される
%matplotlib inline from tqdm import tqdm_notebook as tqdm import numpy as np import pandas as pd from PIL import Image import torch import torch.nn as nn from torchvision import transforms from torchvision.datasets import FashionMNIST from torchvision.utils import make_grid import matplotlib.pyplot as plt import seaborn as sns sns.set()
データセット
- 今回の実験ではFashionMNISTを用いる
# Build the FashionMNIST train/test datasets and loaders.
# NOTE: FashionMNIST images are single-channel, so Normalize must receive a
# 1-element mean/std — the original 3-tuples (0.5, 0.5, 0.5) raise a
# RuntimeError on grayscale tensors in torchvision. (0.5,)/(0.5,) maps pixel
# values from [0, 1] to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
trainset = FashionMNIST(download=True, train=True, transform=transform, root="/tmp/data")
testset = FashionMNIST(download=True, train=False, transform=transform, root="/tmp/data")
train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz Processing... Done!
ネットワークの定義
- 画像を生成するGeneratorと分類するDiscriminatorを定義する
class Discriminator(nn.Module):
    """Conditional discriminator: scores a 28x28 image as real/fake given its label.

    The image is flattened to 784 features and concatenated with a 10-dim
    class embedding, then passed through an MLP ending in a sigmoid so the
    output is a per-sample probability in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # One 10-dim embedding vector per FashionMNIST class.
        self.label_emb = nn.Embedding(10, 10)
        layers = []
        in_features = 784 + 10  # flattened image + label embedding
        for out_features in (1024, 512, 256):
            layers += [
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ]
            in_features = out_features
        layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x, labels):
        """Return realness probabilities, one per sample (squeezed to 1-D)."""
        flat = x.view(x.size(0), 784)
        embedded = self.label_emb(labels)
        joined = torch.cat([flat, embedded], 1)
        return self.model(joined).squeeze()
class Generator(nn.Module):
    """Conditional generator: maps (100-dim noise, class label) to a 28x28 image.

    The noise vector is concatenated with a 10-dim class embedding and passed
    through an MLP ending in Tanh, so pixel values land in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # One 10-dim embedding vector per FashionMNIST class.
        self.label_emb = nn.Embedding(10, 10)
        blocks = []
        widths = (110, 256, 512, 1024)  # 110 = 100 noise dims + 10 label dims
        for d_in, d_out in zip(widths, widths[1:]):
            blocks += [nn.Linear(d_in, d_out), nn.LeakyReLU(0.2, inplace=True)]
        blocks += [nn.Linear(widths[-1], 784), nn.Tanh()]
        self.model = nn.Sequential(*blocks)

    def forward(self, z, labels):
        """Return a batch of generated images shaped (batch, 28, 28)."""
        noise = z.view(z.size(0), 100)
        embedded = self.label_emb(labels)
        joined = torch.cat([noise, embedded], 1)
        flat = self.model(joined)
        return flat.view(joined.size(0), 28, 28)
# Instantiate the two networks with fresh random weights.
generator = Generator()
discriminator = Discriminator()
損失関数とオプティマイザの定義
- 損失関数はbinary cross entropy,オプティマイザはAdamを用いる
# Binary cross-entropy matches the discriminator's sigmoid output.
criterion = nn.BCELoss()
# Separate Adam optimizers so the two networks are updated independently.
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
学習フレームワーク
def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion):
    """Run one optimization step on the generator.

    Samples `batch_size` noise vectors and random class labels, generates fake
    images, and updates the generator so the discriminator scores them as real.

    Returns the generator loss as a Python float.
    """
    g_optimizer.zero_grad()
    z = torch.randn(batch_size, 100)
    fake_labels = torch.LongTensor(np.random.randint(0, 10, batch_size))
    fake_images = generator(z, fake_labels)
    validity = discriminator(fake_images, fake_labels)
    # The generator's target is "real" (1) for every generated sample.
    g_loss = criterion(validity, torch.ones(batch_size))
    g_loss.backward()
    g_optimizer.step()
    # .item() replaces the removed `Tensor.data[0]` idiom: 0-dim tensors
    # cannot be indexed in modern PyTorch.
    return g_loss.item()
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion,
                             real_images, labels):
    """Run one optimization step on the discriminator.

    Scores the given real batch against target 1 and a freshly generated fake
    batch against target 0, and updates the discriminator on the summed loss.

    Returns the total discriminator loss as a Python float.
    """
    d_optimizer.zero_grad()

    # Real samples should be classified as real (1).
    real_validity = discriminator(real_images, labels)
    real_loss = criterion(real_validity, torch.ones(batch_size))

    # Fake samples (new noise + random labels) should be classified as fake (0).
    z = torch.randn(batch_size, 100)
    fake_labels = torch.LongTensor(np.random.randint(0, 10, batch_size))
    fake_images = generator(z, fake_labels)
    fake_validity = discriminator(fake_images, fake_labels)
    fake_loss = criterion(fake_validity, torch.zeros(batch_size))

    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    # .item() replaces the removed `Tensor.data[0]` idiom: 0-dim tensors
    # cannot be indexed in modern PyTorch.
    return d_loss.item()
学習
# Main training loop: alternate one discriminator step and one generator step
# per batch, then visualize one generated image per class (0-8) each epoch.
num_epochs = 30
for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch))
    for i, (images, labels) in tqdm(enumerate(train_loader)):
        real_images = images
        generator.train()
        batch_size = real_images.size(0)
        d_loss = discriminator_train_step(len(real_images), discriminator, generator,
                                          d_optimizer, criterion, real_images, labels)
        g_loss = generator_train_step(batch_size, discriminator, generator,
                                      g_optimizer, criterion)

    generator.eval()
    print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
    # One sample per class for a 3x3 progress grid.
    z = torch.randn(9, 100)
    labels = torch.LongTensor(np.arange(9))
    # no_grad replaces the deprecated `.data` escape hatch: visualization
    # needs no autograd graph, and skipping it saves memory.
    with torch.no_grad():
        sample_images = generator(z, labels).unsqueeze(1).cpu()
    grid = make_grid(sample_images, nrow=3, normalize=True).permute(1, 2, 0).numpy()
    plt.imshow(grid)
    plt.show()
学習後の生成画像
# Generate a 5x8 grid of samples with random class labels from the trained
# generator.
generator.eval()
print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
z = torch.randn(40, 100)
labels = torch.LongTensor(np.random.randint(0, 10, 40))
# no_grad replaces the deprecated `.data` escape hatch: inference only,
# no autograd graph needed.
with torch.no_grad():
    sample_images = generator(z, labels).unsqueeze(1).cpu()
grid = make_grid(sample_images, nrow=8, normalize=True).permute(1, 2, 0).numpy()
plt.figure(figsize=(8, 10))
plt.imshow(grid)
plt.show()