Steerable CNNs の紹介

5. 実装

ここでは、実装にあたり、工夫や注意点などについていくつか説明します。

5.1. ネットワーク構造

ResNetの変種である、Wide ResNetに似た構造を用いています。Residual unitの中間層と、他の部分で表現を変えるようになっています。具体的な内容はソースコードを参照してください。

5.2. 使用する表現

実装の容易さや表記の簡略さのため、今回は、同じ表現を複数回繰り返した、以下のように定義される表現をチャネル方向の表現 ( \mathbb{R}^C 上の表現) に用います。 \mathrm{repeat}(\rho, n) = \rho \oplus \rho\cdots \oplus \rho. 右辺では n 個の直和をとっています。なお、論文では capsule という名前でもう少し一般的な構造が定義されています。

5.3. Intertwiners のなす空間の基底の求め方

Convolution 層の入力の表現の次数を C、出力の表現の次数を C' とします。Intertwiners のなす空間 \mathrm{Hom}_H(\pi, \rho') の基底を求めるために、行列として計算を行います。Intertwiner は任意の h \in H に対して

M(f)M(\pi(h)) = M(\rho'(h)) M(f)

を満たすような、 C'C 列の行列 M(f) によって表されます。

両辺の ij 列目に着目すると、等式

\sum_{j' \in \{1,2,\dots, C\}} M(f)_{ij'} \, M(\pi(h))_{j'j} - \sum_{i' \in \{1, 2, \dots C'\}} M(\rho(h))_{ii'} \, M(f)_{i'j} = 0

が成り立ちます。全ての h \in H, i = 1,2,\dots, C', j = 1,2,\dots, C について等式を立てて、連立一次方程式の解空間を求めれば基底を求めることができます。

CC' 個の変数、 |H|CC' 個の制約を持つ連立一次方程式は、 |H|CC'CC' 列の行列として表すことができます。この行列を A とします。 Av = 0 を満たすような CC' 次元ベクトル v の空間、すなわち A の零空間を計算することで、連立方程式の解空間が求められます。

零空間を求める方法は色々ありますが、ここではランク付き QR 分解 (rank revealing QR decomposition) を紹介します。より知名度が高いと思われる特異値分解 (SVD) よりも高速です。ランク付き QR 分解によって A^T

A^TP = QR

と分解します。ここで、 Q は直交行列、 R は上三角行列、 P は置換行列です。 R の対角成分の絶対値は広義単調減少で、非零な対角成分の個数が R のランクとなります。P は等式の入れ替えを行っているだけなので無視して、 A^T = QR とします。 Q は直交行列なので Q^TA^T = R が成り立ちます。これを転置すると AQ = R^T になります。R^T の非零な対角成分の個数を r とすると R の性質より R^Tr 列目以降は零であるので、 Qr 列目以降が零空間になります。

5.4. Convolution 層の重みの初期化

Convolution 層の入力のチャネル方向の表現を \mathrm{repeat}(\rho, m)、出力のチャネル方向の表現を \mathrm{repeat}(\rho', m') とし、それぞれ次数が mc および m'c' であるとします。Convoluiton 層のフィルタサイズは k \times k とします。

\rhok\times k のパッチに拡張した k^2C 次元表現と C' 次元表現 \rho' の間の intertwiner を考えます。Convolution 層は \mathrm{im2col} をして考えると m'C'k^2mC 列の重み行列で書けるような線型変換です。Intertwiner の成す空間が h 次元であるとすると、学習すべきパラメータ数は mm'h 個となります。Steerable CNNs では重み行列の成分数と学習すべきパラメータ数は一般には一致しないように設計していることを思い出しておきます。

mm'h 個のパラメータをどう初期化すべきかを考えます。
例として、重み行列の各要素を分散 2 / k^2\,\mathrm{channels}_\mathit{in} の正規分布で初期化する、MSRA 初期化 (MSRA は Microsoft Research Asia の略称、Heの初期化とも呼ばれる) を考えます。まず、各基底の L2 ノルムが \sqrt{k^2CC'/ h} となるように基底をスケーリングします。この場合、 mm'h 個のパラメータのそれぞれを分散 \sigma^2 の正規分布で初期化すれば、重み行列の各要素の分散も \sigma^2 になります。MSRA 初期化にできるだけ従うためには m'C'k^2mC 列の重み行列の各要素の分散を 2 / k^2mC にしたいため、パラメータは分散 2 / k^2mC の正規分布から初期化するのが適切と考えられます。

5.5. Batch normalization について

ResNet の構造には、batch normalization を使います。Batch normalization はチャネルごとにかつ batch 方向に正規化を行います。そのため、そのまま使うとチャネルを入れ替える今回の作用に関しては、同変になりません。著者らは、この後の論文などで明確にそのことについて触れ対応もしているようです [WGWBC18]。この記事の実装では、簡易化のため、batch normalization はそのまま用いました。

6. 実験

6.1. CIFAR-100による実験

CIFAR-100 データセットで学習を行いました。
  • バッチサイズ: 64
  • オプティマイザ: SGD with momentum= 0.9
  • Weight decay: 5\mathrm{e}-4
  • エポック数: 150
  • 学習率: 0.05 から開始し 60120 エポック目で 0.2 にした
の設定で最もよい結果が得られました。学習終了時の検証セット accuracy は 83.13\% で、最大値は 140 エポック目の 83.32\% でした。著者が報告している 81.18\% という精度よりも、 2\% ほど高い精度を達成できました。学習率のスケジューリングや weight decay の値の違い、あるいは後述の augmentation の違いが、精度に差が出た要因として考えられます。
この画像には alt 属性が指定されておらず、ファイル名は accuracy.png です
訓練セットおよび検証セットに対する accuracy の学習曲線
Augmentation は以下の通りです。
  • 確率 0.5 で水平反転
  • 確率 0.5 で平行移動
    • 平行移動する場合、縦方向横方向それぞれについて、 -4 以上 4 以下の整数値から一様に乱択した値を移動量とする
    • パディングは 0(黒)で行う
同じ著者による別論文 [CW16] の実装では、平行移動量として実数を許容して、bicubic 法でリサンプリングしつつ平行移動を行っていました。この論文の実装でも同じ augmentation を行っていると推測されます。

6.2. 特徴量マップの様子

学習したモデルが Steerable なネットワークになっていることを確認するために、画像の特徴量マップと回転した画像の特徴量マップを比較しました。
  • 画像内に回転中心のピクセルを作るために、CIFAR-100 の 32 \times 32 の画像を左と上に 1 pxパディングし、 33\times 33 に変換したものを入力画像として使用した。
  • 以下の 1. と 2. の特徴量マップを計算した。
    1. 未回転の画像をネットワークに入れ、出力側表現の90度回転に相当する作用を施して得られる特徴量マップ
    2. 90度回転した画像をネットワークに通して得た特徴量マップ
  • Batch normalization を使うと steerable にならないことが分かっているので、1. と 2. を計算するときに batch normalization による正規化を行わないようにした。
以下の図は特徴量マップの様子を各チャネルごとに画像としたものです。上段の画像が 1. 下段が 2. で、横軸がチャネルに対応しています。ただし、特徴量マップはほとんどのチャネルが 0 となっていたので、値が入っているチャネルを恣意的に選んでいます。上下を比べてみると見た目にはほとんど変わらないことが分かります。実際、1. と 2. の二乗誤差は 約 1.67\mathrm{E}-19 であり、誤差の範囲と言えるかと思います。
この画像には alt 属性が指定されておらず、ファイル名は feature_map.png です
学習したモデルが出力する特徴量マップの様子。(左) 入力画像 (右上段)入力画像の特徴量マップに対して出力側表現の 90 度回転に相当する作用を施したもの (右下段)90 度回転した入力画像の特徴量マップ

古川

最近は主に動画像の分析を行なっています。深層学習に興味があります。大学院時代は位相幾何学をしていました。