ディープラーニングブログ

Mine is deeper than yours!

論文解説 Depthwise Separable Convolution for Neural Machine Translation (SliceNet)

こんにちは Ryobot (りょぼっと) です.

テンソル分解は 2017 年の密かなブームだったかもしれない.
論文数は多くないが,テンソル分解を用いた手法が中規模言語モデル [1],大規模言語モデル [2],機械翻訳 (本紙) [3],動作認識 [4] で軒並み SOTA を達成している.

テンソル分解

テンソル分解 (Tensor Decomposition, Tensor Factorization) は {n} ランクのテンソル{n} 個の因子行列 (Factor Matrix) と 1 個のコアテンソル (Core Tensor, なくても OK) に分解し,パラメータを削減する手法である (ソース).すべての因子行列 (+ コアテンソル) を内積すると分解前のテンソルに近似する.

f:id:Ryobot:20171222185355p:plain:w500

身近な例では 2 階のテンソル分解が行列分解 (Matrix Factorization) であり,0 階のテンソル (スカラー) の分解は中学校で習う因数分解である.

大規模データは大規模モデルで愚直に殴るのが最も有効であり,現実的な時間内で訓練するためにテンソル分解や条件付き計算が注目を集めている.

大規模言語モデルの SOTA

テンソル分解は大規模モデルのパラメータ削減で力を発揮する.成功例として巨大な LSTM 層を行列分解する手法を紹介したい.

LSTM は次のような関数である.

f:id:Ryobot:20171222185414p:plain:w300

LSTM の各ゲートは次式によって表される.

f:id:Ryobot:20171222185421p:plain:w250

ここで {x_t \in \mathbb{R}^{p}}{h_t \in \mathbb{R}^{p}} であり,{T : \mathbb{R}^{2p} \rightarrow \mathbb{R}^{4n}} はアフィン変換 {T = W * [x_t, h_{h-1}] + b} である.

アフィン変換 {T} の重み行列 {W \in \mathbb{R}^{4n \times 2p}} の計算コストが大きいので行列分解によってパラメータ数を削減したい.

下図は左から 2 層の一般的な LSTM,2 層の F-LSTM,2 層かつ各層 2 グループの G-LSTM である.ただし {d = (x, \, h)}{d1 = (x^1, \, h^1)}{d2 = (x^2, \, h^2)} とする.

f:id:Ryobot:20171222185346p:plain:w700

F-LSTM (Factorized LSTM) は重み行列 {W} を小さな 2 個の行列 {W1}{W2}内積 {W \approx W2 * W1} に近似させる.ここで {W1 \in \mathbb{R}^{2p \times r}}{W2 \in \mathbb{R}^{r \times 4n}}{r} < {p} <= {n} である.パラメータ数は {(2p * 4n)} から {(2p * r + r * 4n)} に削減される.

G-LSTM (Group LSTM) は LSTM と入力 {x_t} と隠れ層 {h_t}{k} 個の独立なグループに分離する.つまり {h_t^i}{x_t^i}{h_{t-1}^i}{T^i} のメモリ状態にのみ依存するように,{k} 個のベクトルの連結 (concatenate) {x_t = (x_t,^1 \, \ldots, \, x_t^k)}{h_t = (h_t,^1 \, \ldots, \, h_t^k)} に分割し,次式のように独立して計算する.

f:id:Ryobot:20171222185425p:plain:w510

ここで {T^j : \mathbb{R}^{2p/k} \rightarrow \mathbb{R}^{4n/k}} は グループ {j} のアフィン変換である.パラメータ数は {1/k} に削減される.

データセットは単語数 {829} M,語彙数 {793471} のニュース記事 [Chelba, 2013] から成る Google Billion Word [Chelba, 2013] を使用する.

8 枚の Tesla P100 GPU が刺さった DGX-1 で 1 週間訓練した.

f:id:Ryobot:20171222185403p:plain:w700

ここで埋め込みサイズ {p} は 1024,メモリサイズ {n} は 8192,F512 の行列 {W} の中間サイズ {r} は 512,G-4 のグループ数 {k} は 4 (G-16 は {16}) とする.

ベースライン [Jozefowicz, 2016] の BIGLSTM が 31.0,行列分解を用いた F-LSTM が 28.11,グループ分離を用いた G-LSTM が 28.17 のパープレキシティを達成した.

f:id:Ryobot:20171222185408p:plain:w700

また G-LSTM を 3 週間訓練したところ,パープレキシティは SOTA である 23.36 を達成した (現在も SOTA).

SliceNet

SliceNet は巨大な ResNet の畳み込み層を Depthwise Convolution (空間方向の畳み込み) と Pointwise Convolution (チャネル方向の畳み込み) に分解する Depthwise Separable Convolution を用いたニューラル機械翻訳である.

著者は Keras 作者のショレー氏であり,SliceNet は同氏が開発した画像認識向けの Xception モデルを機械翻訳向けに移植したものと言える.

WMT'14 の BLEU スコアは英仏: -, 英独: 26.1 で第 2 位 (登場時 1 位)

f:id:Ryobot:20171222185339p:plain:w700

SliceNet は ConvModule を 6 層スタックしたエンコーダと 4 層スタックしたデコーダから成る.
ConvModule は 4 層の ConvStep から成り,ConvStep は Depthwise Separable Convolution に層正規化 (Layer Normalization) を適応したものである.
IOMixer と Decoder は一般的な注意 (ie, Source-Target-Attention) を使用する.

Depthwise Separable Convolution

SepConv (Depthwise Separable Convolution) は前述のとおり 2 つの畳み込みに分解してパラメータを削減する手法である.解説はこちらがわかりやすい.

f:id:Ryobot:20171223152513p:plain:w700

  • Depthwise Convolution : 空間方向の畳み込みは入力のすべてのチャネルに対して独立に計算する.
  • Pointwise Convolution : チャネル方向の畳み込みは一般的な 1x1 窓の畳み込みであり,Depthwise Convolution で計算したチャネルを新しいチャネルに投射する.

分解によって下図のようにカーネルのパラメータが削減される.

f:id:Ryobot:20171222185324p:plain:w700

SepConv は次式のように表される.

f:id:Ryobot:20171222185439p:plain:w770

ここで {\odot} は要素ごとの積である.

また前述の G-LSTM のように畳み込み層のカーネルと入力を {g} 個のグループに分解する SuperSC (Super-separable convolution) によって更にパラメータを削減できる.

f:id:Ryobot:20171222185446p:plain:w780

Deconvolution と Dilated Convolution

畳み込み (Convolution) の亜種にチェッカーボードのように間隔を空ける手法が 2 つある.

f:id:Ryobot:20171223152503p:plain:w700

  • 逆畳み込み (Deconvolution): 入力の各成分の間隔を空ける.入力を大きくできる.
  • 拡張畳み込み (Dilated Convolution): フィルタ (カーネル) の各成分の間隔を空ける.受容野を広くできる.

つまり両者は対局的な畳み込みである.

SliceNet では広い文脈の依存関係を参照することが可能な Dilated Convolution を使用する.

畳み込みモジュール

ConvStep は SepConv に層正規化 (Layer Normalization) を適応したもので次式によって表される.

f:id:Ryobot:20171222185456p:plain:w480

層正規化は名前のとおり層内の各成分の統計量を計算し,正規化する手法である.

f:id:Ryobot:20171222185452p:plain:w740

畳み込みモジュール (Convolution Module) は ConvStep を 4 層スタックし残差接続 (Residual Connection) を取り入れたモジュールである.

f:id:Ryobot:20171222185500p:plain:w580

注意モジュール

内積注意は 2 つのテンソル {source \in \mathbb{R}^{m \times depth}}{target \in \mathbb{R}^{n \times depth}} を入力として受け取り,{depth} でスケーリングした次式によって文脈ベクトルを得る.

f:id:Ryobot:20171222185506p:plain:w650

また注意が各単語の位置情報にアクセスできるように正弦波を埋め込み行列に加算する.詳しくは 論文解説 Attention Is All You Need (Transformer) を参照されたい.

f:id:Ryobot:20171222185511p:plain:w380

実用的に {target} には位置情報を付与して 2 度 SepConv を適応している.

f:id:Ryobot:20171222185516p:plain:w770

実験と結果

データセットは WMT'14 の英独 (5M 対訳文) を使用する.英仏は評価しておらず,かなり実験が雑である.

ともあれ結果は次のとおり.

f:id:Ryobot:20171222185430p:plain:w450

BLEU スコアは SuperSC を用いた SliceNet が 26.1 を達成し SOTA となった (この後すぐに Transformer が登場し三日天下となった).