PyTorchでAdversarial Attack. モルモットをインドゾウに誤認識させる

PyTorchを用いて分類器に対する攻撃手法であるAdversarial Attackを実装してみる. これは,分類器に対して故意に誤分類を誘発させるような画像を生成する攻撃手法である.例えば,

  • 自動運転車に対する標識の誤検出の誘発
  • 顔認識システムの第三者による誤認証

など,ニューラルネットの社会実装をする上で重要な問題である.

Import Modules

  • 必要なライブラリをimportしておく
%matplotlib inline
import json
import pickle
import urllib
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.gradcheck import zero_gradients
from torchvision import models
from torchvision import transforms

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("white")

ネットワークの準備

  • 今回はImageNetで学習済みのinceptionv3モデルを用いる
  • このネットワークを騙していくことになる
inceptionv3 = models.inception_v3(pretrained=True)
inceptionv3.eval()

ラベルの準備

  • ImageNetのラベルをロードしておく
labels = pickle.load(urllib.request.urlopen('https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl') )

画像の読み込み

  • 手元で準備したモルモットの画像をロードする
img = Image.open("../images/cavy.jpg")
plt.imshow(img)

f:id:noconocolib:20190101185739j:plain

  • PyTorchで扱えるようにTensor型に変換する
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

transforms = transforms.Compose([
    transforms.Resize((299,299)),  
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
img_tensor = transforms(img).unsqueeze(0)
img_tensor.requires_grad_()

モデルで綺麗な画像を分類

  • まずは何も手を加えていない画像の分類結果を見てみる
output = inceptionv3.forward(img_tensor)
label_idx = torch.max(output.data, 1)[1][0]
x_pred = labels[label_idx.item()]
output_probs = F.softmax(output, dim=1)
x_pred_prob = torch.max(output_probs, 1)[0].item()

print(x_pred)
print("output prob: {:.4f}".format(x_pred_prob))
guinea pig, Cavia cobaya
output prob: 0.9860
  • 98.6%の確率でモルモットであると分類できている

Fast Gradient Sign Method

  •  \rm{X}: 元々の画像
  •  \rm{X}_{adv}: adversarial example
  •  \epsilon: adversarial perturbationのサイズ
  •  \nabla_\rm{X} J(\rm{X},\rm{Y}_{true}): 入力画像に対するモデルの勾配

であるとき,Adversarial Exampleは

 \rm{X}^{adv} = \rm{X} + \epsilon sign(\nabla_\rm{X} J(\rm{X}, \rm{Y}_{true}))

で生成することができる.

y_true = label_idx
target = torch.LongTensor([y_true])
loss = nn.CrossEntropyLoss()
loss_cal = loss(output, target)
loss_cal.backward(retain_graph=True)   
eps = 0.04
x_grad = torch.sign(img_tensor.grad.data)
x_adversarial = img_tensor.data + eps * x_grad
output_adv = inceptionv3.forward(x_adversarial)
x_adv_pred = labels[torch.max(output_adv.data, 1)[1][0].item()]
op_adv_probs = F.softmax(output_adv, dim=1)
adv_pred_prob = torch.max(op_adv_probs).item()

print(x_adv_pred)
print("output prob: {:.4f}".format(adv_pred_prob))
guinea pig, Cavia cobaya
output prob: 0.1138
  • 11.3%の確率でモルモットであると分類されている
  • 予測確率を下げることはできているが,まだモルモットであると分類はできてしまっている.

可視化

  • 可視化用の関数を定義する
def visualize(x, x_adv, x_grad, epsilon, clean_pred, adv_pred, clean_prob, adv_prob):
    x = x.squeeze(0)     #remove batch dimension
    x = x.mul(torch.FloatTensor(std).view(3,1,1))
    x = x.add(torch.FloatTensor(mean).view(3,1,1)).detach().numpy()
    x = np.transpose( x , (1,2,0))
    x = np.clip(x, 0, 1)
    
    x_adv = x_adv.squeeze(0)
    x_adv = x_adv.mul(torch.FloatTensor(std).view(3,1,1))
    x_adv = x_adv.add(torch.FloatTensor(mean).view(3,1,1)).detach().numpy()
    x_adv = np.transpose( x_adv , (1,2,0))
    x_adv = np.clip(x_adv, 0, 1)
    
    x_grad = x_grad.squeeze(0).detach().numpy()
    x_grad = np.transpose(x_grad, (1,2,0))
    x_grad = np.clip(x_grad, 0, 1)
    
    figure, ax = plt.subplots(1,3, figsize=(18,8))
    ax[0].imshow(x)
    ax[0].set_title('Clean Example', fontsize=20)
    
    ax[1].imshow(x_grad)
    ax[1].set_title('Perturbation', fontsize=20)
    ax[1].set_yticklabels([])
    ax[1].set_xticklabels([])
    ax[1].set_xticks([])
    ax[1].set_yticks([])
    
    ax[2].imshow(x_adv)
    ax[2].set_title('Adversarial Example', fontsize=20)
    
    ax[0].axis('off')
    ax[2].axis('off')

    ax[0].text(1.1,0.5,
               "+{}*".format(round(epsilon,3)),
               size=15,
               ha="center",
               transform=ax[0].transAxes)
    
    ax[0].text(0.5,-0.13,
               "Prediction: {}\n Probability: {}".format(clean_pred, clean_prob)
               size=15,
               ha="center", 
               transform=ax[0].transAxes)
    
    ax[1].text(1.1,0.5, " = ", size=15, ha="center", transform=ax[1].transAxes)

    ax[2].text(0.5,-0.13,
               "Prediction: {}\n Probability: {}".format(adv_pred, adv_prob),
               size=15,
               ha="center", 
               transform=ax[2].transAxes)
    
    plt.show()
visualize(img_tensor, x_adversarial, x_grad, eps, x_pred, x_adv_pred, x_pred_prob, adv_pred_prob)

f:id:noconocolib:20190101191222p:plain

Iterative Target Class Method

  • 入力画像を狙ったクラスに誤分類させる手法として,Iterative Target Class Methodが存在する

 X^{adv}_{0}=X

 X^{adv}_{N+1} = Clip_{X,\epsilon}(X^{adv}_{N} + \alpha sign(\nabla_\rm{X} J(\rm{X}^{adv}_{N}, \rm{Y}_{true}))

y_target = torch.LongTensor([385])  #385= Indian elephant
y_target.requires_grad = False
epsilon = 0.25
n_steps = 10
alpha = 0.025
for i in range(n_steps):
    zero_gradients(img_tensor)

    output = inceptionv3.forward(img_tensor)
    loss = nn.CrossEntropyLoss()
    loss_cal = loss(output, y_target)
    loss_cal.backward()

    x_grad = alpha * torch.sign(img_tensor.grad.data)

    adv_temp = img_tensor.data - x_grad
    total_grad = adv_temp - img_tensor
    total_grad = torch.clamp(total_grad, -epsilon, epsilon)

    x_adv = img_tensor + total_grad
    img_tensor.data = x_adv
output_adv = inceptionv3.forward(img_tensor)
x_adv_pred = labels[torch.max(output_adv.data, 1)[1][0].item()]
op_adv_probs = F.softmax(output_adv, dim=1)
x_adv_pred_prob = torch.max(op_adv_probs)[0].item()

visualize(img_tensor, img_tensor.data, total_grad, epsilon, x_pred,x_adv_pred, x_pred_prob,  x_adv_pred_prob)

f:id:noconocolib:20190101192203p:plain

  • 99.9%の確率でインドゾウであると分類している
  • 見事モルモットをインドゾウに誤認識させることに成功している
  • 生成されたAdversarial Example画像も,我々人間の目にはほとんど違和感がない

Conclusion

  • 非常に簡単な方法でニューラルネットワークに対するAdversarial Attackを実装することができた
  • 今回試した手法はどちらも,攻撃者がネットワークの入出力だけではなく勾配も得ることができるホワイトボックス攻撃に分類される.これとは対照的に,ネットワークの入出力だけから攻撃を行うブラックボックス攻撃も存在する
  • この他にも様々な攻撃手法や,それに対する防御手法も提案されているため,他にも試してみたい