SNIP: Single-Shot Network Pruning Based Oo Connection Sensitivity

ICLR2019採択論文"SNIP: Single-Shot Network Pruning Based Oo Connection Sensitivity"をレビュー.

元論文はこちら:

openreview.net

パラメータ数の多い畳み込みニューラルネットワークを,分類精度をほとんど落とさずに大幅にスパース化するnetwork pruningについての研究.

Abstract

network pruningは,与えられた巨大なニューラルネットワークを,精度を維持したまま小規模化することが目的.

既存のnetwork pruningの手法は,再学習のプロセスや追加のハイパーパラメータが必要など,pruningそのものが高コストになってしまうものが多い.

これを受けて,学習前にデータに基づいてネットワークの接続の重要度を識別するnetwork pruningの手法を提案.

提案手法は以下の特徴を満たす.

  • Simplicity: ネットワークを学習前に一度だけpruningするだけでよく,事前の学習や面倒なpruningのためのスケジューリングは必要ない.
  • Versatility: 提案する構造的に重要な接続を発見する指標は,ネットワークアーキテクチャによらず汎用的に適用できる
  • Interpretability: 提案手法は,データのミニバッチを用いてsingle-shotで重要な接続を発見できる

つまり,提案手法は既存のnetwork pruningの手法とは異なり,学習前からネットワークのスパース化を達成できる.

Neural Network Pruning

Neural network pruningの研究背景には,多くのネットワークはタスクに対してパラメータ過多であり,より小さなネットワークで同様の精度を達成できるはずであるという仮定に基づいている.

与えられたデータセット D = {x_{i}, y_{i}}^{n}_{i=1},求められるsparsity levelを \kappaとする.

neural network pruningは以下の最適化問題を解くことになる.

 \min_{w} L(w; D) = \min_{w} \frac{1}{n} \sum^{n}_{i=1} l(w; (x_{i}, y_{i})),

 s.t. w\in{\mathbb{R}^m}, |w|_0 \leq \kappa

 l(\cdot)はcross-entropyのような一般的な損失関数,wはネットワークのパラメータ集合, mはパラメータ数, |\cdot|=0 L_0ノルム.

重要度ベースのpruning手法では,冗長なパラメータや接続を重要度に基づいて削除することでネットワークの軽量化を実現する.こうした手法では,パラメータや接続の重要度を測る指標が非常に重要になる.

よく用いられるものとしては,重みパラメータの絶対値が閾値以下のものを重要でないと扱うものや,高いHessianを持つ重みの重要度が高いとする指標がある.

 s_{j} = \begin{cases}
|w_{j}|,\ for\ magnitude\ based \
\frac{w\_{j}^{2}H_{jj}}{2}\ or\ \frac{w_{j}^{2}}{2H^{-1}_{jj}},\ for\ Hessian\ based
\end{cases}

 s_jは接続 jに対する重要度.

これらの指標を用いた手法は,ネットワークの再学習が必要であったり,ネットワークアーキテクチャに依存しすぎたりする.

Single-Shot Network Pruning Based on Connection Sensitivity

ネットワークとデータセットが与えられた時,タスクとデータに基づいて冗長な接続の枝刈りをしたい.

Connection Sensitivity: Architectural Perspective

提案する指標では,接続の重要度をその重みに依存せずに決定したい.

パラメータ wの接続の状況を表現するインジケータを c\in{{0, 1}^m}とする.

network pruningの目的関数は,

 \min_{c, w} L(c\odot w; D) = \min_{c, w} \frac{1}{n} \sum^{n}_{i=1} l(c\odot w; (x_{i}, y_{i})),

 s.t.\ w\in{\mathbb{R}^m},\ c\in{{0, 1}^m},\ |c|_0 \leq \kappa

ここで, c_jは接続が有効かどうかを表現しているので, c_jを切り替えた時の目的関数の差分に注目することで,その接続の重要度を判断できる.

 \bigtriangleup{L_j (w; D)} = L(1\odot w; D) - L((1-e_j)\odot w; D)

 c_jは2値のため微分できず,最適化が難しいため実際にはこれを緩和した問題を解くことになる,

 \bigtriangleup{L_j (w; D)} = g_j (w; D) = \frac{\partial L(c\odot w; D)}{\partial{c_j}}

 g_jの勾配が大きい時は,ネットワークとタスクにとって重要な接続であるはずであるという仮説を立てる.これに基づくと,

 s_{j} = \frac{|g_{j} (w; D)|}{\sum^{m}_{k=1} |g_{k} (w; D)|}

Single-Shot Pruning at Initialization

Single-shot Network Pruning (SNIP)のアルゴリズムを以下に示す.

f:id:noconocolib:20190227130014p:plain

ネットワークアーキテクチャの変化に対するロバスト性についての議論

ニューラルネットワークの重みの初期値は通常,正規分布を用いてランダムに初期化される.

しかし,ネットワークのすべての層の重みの分散が決まった値をとっている場合,各レイヤを通る信号が同じ分散を持つことは保障されないため,勾配や本論文で注目している顕著性指標がネットワークアーキテクチャに強く依存してしまう.

これを避けるため,論文ではネットワークの初期化にvariance scaling手法を用いることを推奨している.

Pruningに用いるミニバッチに対するロバスト性についての議論

本手法ではミニバッチを用いてpruning対象のニューロンを選択するが,これはミニバッチに含まれるデータに依存してしまう.

よって,複数バッチにまたがって接続の重要度を蓄積してから最終的なpruning対象を決定するテクニックを紹介.

Experiments

SNIPを適用した複数のネットワークアーキテクチャに対して,MNISTとCIFAR-10の分類タスクを用いて実験.

LeNetのpruning

  • LeNetを,複数のsparsity levelでpruningした際の実験結果
  • 分類精度をほとんど落とさずにネットワークを大幅にスパースにできている
  • さらに,sparsity levelによっては元のネットワークより汎化性能が高くなっているケースもある.

f:id:noconocolib:20190227130102p:plain
fe Figure 1: Test errors of LeNets pruned at varying sparsity levels κ¯, where κ¯ = 0 refers to the reference network trained without pruning. Our approach performs as good as the reference network across varying sparsity levels on both the mod

先行研究との比較実験

f:id:noconocolib:20190227130127p:plain
Table 1: Pruning results on LeNets and comparisons to other approaches. Here, “many” refers to an arbitrary number often in the order of total learning steps, and “soft” refers to soft pruning in Bayesian based methods. Our approach is capable of pruning up to 98% for LeNet-300-100 and 99% for LeNet-5-Caffe with marginal increases in error from the reference network. Notably, our approach is considerably simpler than other approaches, with no requirements such as pretraining, additional hyperparameters, augmented training objective or architecture dependent constraints.