PyTorchでGrad-CAMによるCNNの可視化.

Grad-CAMはConvolutional Neural Networksの可視化手法の一種.CNNが画像のどの情報を元にして分類を行なっているのかを可視化するのに用いられる.

今回はこのGrad-CAMをPyTorchで試してみる.

モジュールのインポート

  • 必要なライブラリをimportしておく
%matplotlib inline
import urllib
import pickle
import cv2
import numpy as np
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("white")

ImageNetのラベルをロード

  • 実験にはImageNetでの学習済みモデルを用いるので,ImageNetのラベル情報をロードしておく
labels = pickle.load(urllib.request.urlopen('https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl') )

GradCAMクラスの定義

GradCAMは,

  • 畳み込みの最終層でGlobalAveragePoolingを行う
  • あるクラスにおける最終層の各チャンネルの重要度を決定
  • 重要度に応じて各チャンネルをかけて足し合わせる
  • 足し合わせたものをRelu関数に通す

で求められる.

class GradCam:
    def __init__(self, model):
        self.model = model.eval()
        self.feature = None
        self.gradient = None

    def save_gradient(self, grad):
        self.gradient = grad

    def __call__(self, x):
        image_size = (x.size(-1), x.size(-2))
        feature_maps = []
        
        for i in range(x.size(0)):
            img = x[i].data.cpu().numpy()
            img = img - np.min(img)
            if np.max(img) != 0:
                img = img / np.max(img)

            feature = x[i].unsqueeze(0)
            
            for name, module in self.model.named_children():
                if name == 'classifier':
                    feature = feature.view(feature.size(0), -1)
                feature = module(feature)
                if name == 'features':
                    feature.register_hook(self.save_gradient)
                    self.feature = feature
                    
            classes = F.sigmoid(feature)
            one_hot, _ = classes.max(dim=-1)
            self.model.zero_grad()
            one_hot.backward()

            weight = self.gradient.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
            
            mask = F.relu((weight * self.feature).sum(dim=1)).squeeze(0)
            mask = cv2.resize(mask.data.cpu().numpy(), image_size)
            mask = mask - np.min(mask)
            
            if np.max(mask) != 0:
                mask = mask / np.max(mask)
                
            feature_map = np.float32(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET))
            cam = feature_map + np.float32((np.uint8(img.transpose((1, 2, 0)) * 255)))
            cam = cam - np.min(cam)
            
            if np.max(cam) != 0:
                cam = cam / np.max(cam)
                
            feature_maps.append(transforms.ToTensor()(cv2.cvtColor(np.uint8(255 * cam), cv2.COLOR_BGR2RGB)))
            
        feature_maps = torch.stack(feature_maps)
        
        return feature_maps

入力画像の読み込み

  • 実験で使うテスト画像を読み込んでおく
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

test_image = Image.open("../images/cavy.jpg")
test_image_tensor = (transform((test_image))).unsqueeze(dim=0)

image_size = test_image.size
print("image size: ", image_size)

plt.imshow(test_image)

f:id:noconocolib:20190107061938p:plain

モデルの定義

  • 今回はImageNetで学習済みのVGG19を用いる
model = models.vgg19(pretrained=True)

Grad-CAMによる可視化

  • Grad-CAMによる可視化を行う
grad_cam = GradCam(model)

feature_image = grad_cam(test_image_tensor).squeeze(dim=0)
feature_image = transforms.ToPILImage()(feature_image)

pred_idx = model(test_image_tensor).max(1)[1]
print("pred: ", labels[int(pred_idx)])
plt.title("Grad-CAM feature image")
plt.imshow(feature_image.resize(image_size))

f:id:noconocolib:20190107062133p:plain
pred: guinea pig, Cavia cobaya

  • めっちゃ顔を見てる

別の画像で試す

  • 別の画像でも試してみる
test_image = Image.open("../images/cavy2.jpg")
test_image_tensor = (transform((test_image))).unsqueeze(dim=0)

image_size = test_image.size
print("image size: ", image_size)

plt.imshow(test_image)

f:id:noconocolib:20190107062221p:plain
image size: (640, 427)

feature_image = grad_cam(test_image_tensor).squeeze(dim=0)
feature_image = transforms.ToPILImage()(feature_image)

pred_idx = model(test_image_tensor).max(1)[1]
print("pred: ", labels[int(pred_idx)])
plt.title("Grad-CAM feature image")
plt.imshow(feature_image.resize(image_size))

f:id:noconocolib:20190107062301p:plain
pred: hamster

  • 同じく顔を見てる(ハムスターに誤分類されてて悲しい)