ディープラーニングブログ

Mine is deeper than yours!

対話モデルの訓練/評価フレームワーク ParlAI がすごい

こんにちは,Ryobot です.

本稿では ParlAI の基本的な使用法やユーザーが独自に定義できるエージェントの実装方法を解説しました.
以下,PyTorch, Chainer, TensorFlow によるエージェントの実装例を GitHub にて公開したので適宜ご覧ください.

目次

ParlAI とは?

目次に戻る ↩︎

ParlAI (パーレイ) は Facebook AI Research が今年 5月に公開した対話モデルの訓練/評価およびデータ管理のフレームワークである.

f:id:Ryobot:20171019191900p:plain

対話モデルの実装ではコードの大半をデータの収集と前処理・辞書の作成・ミニバッチの作成・訓練/評価のイテレーションの記述が占めることが往々にあり,モデルの実装以前で疲弊してしまう.

ParlAI はこれらの面倒な処理を全部やってくれる Python フレームワークである.データ管理を画一化しており 20 種類以上の有名なデータセットを引数指定のみで訓練/評価データに使用できる.マルチタスク学習 (合同訓練) も可能で,ParlAI 上だけで Amazon Mechanical Turk で自前のデータを収集して実験したり人手評価も実施できる (MTurk は最近日本からも依頼可能になった).勿論,バッチ学習や GPU にも対応している.

対話の研究では以下のようなタスクが存在する.

  • 質疑応答 (QA): 質問に対して適切な回答を予測する
  • 穴埋め (Cloze): 文中の空白を補完する適切な単語や文章を予測する
  • ゴール指向 (Goal): 道案内や映画館/レストランの予約のような明確な目標を遂行する
  • 雑談 (ChitChat): 文脈依存的な会話を破綻なく継続する
  • ビジュアル質疑応答 (Visual): 画像や動画の内容に関する質疑応答

ParlAI は主要な対話タスクを網羅している.8月時点で QA は 14 タスク (12 データセット) ,他はすべて 4 タスクずつ用意されている.各タスクの詳細は ParlAI で扱えるタスク一覧 を参照されたい.

  • 質疑応答 (QA): bAbI, MCTest, Movie Dialog, MTurk WikiMovies, Simple Questions, SQuAD, TriviaQA, Web Questions, WikiMovies, WikiQA, InsuranceQA, MS_MARCO
  • 穴埋め (Cloze): BookTest, Children’s Book Test, QA CNN, QA Daily Mail
  • ゴール指向 (Goal): Dialog Based Language Learning, Dialog bAbI, Movie Dialog QA Recommendations, Personalized Dialog
  • 雑談 (ChitChat): Cornell Movie, Movie Dialog Reddit, Open Subtitles, Ubuntu
  • ビジュアル質疑応答 (Visual): VQAv1, VQAv2, VisDial, MNIST_QA

難点を挙げると,チュートリアルが貧相で使い方を覚えるのに苦労した.自前の対話モデル (ユーザーエージェント) を実装するには,引数オプション・データの流れ・バッチ処理の仕方など多くのお約束を理解する必要がある.

世界とエージェント

ParlAI のコンセプトは強化学習の訓練プラットフォームにやや似ている.OpenAI Gym や Universe は環境とエージェントの間で行動・報酬・観測状態をやり取りする訓練プラットフォームであるが,ParlAI では世界内の複数のエージェント がお互いにデータを受け渡しする.

f:id:Ryobot:20171019191903p:plain

  • 世界 (World): 複数のエージェントが互いに対話する環境
  • エージェント (Agent): 教師 (データを供給) と学生 (データから学習) の 2 タイプ

世界内に複数のエージェントが存在し,各エージェントは学生か教師のいずれかである.

ParlAI では世界とエージェントを定義後にメインループにて parley() を呼んで訓練/評価を実行する.

teacher = SquadTeacher(opt)
agent = MyAgent(opt)
world = World(opt, [teacher, agent])
for i in range(num_exs):
    world.parley()
    print(world.display())

世界は必ず parley() メソッドを持って定義される.上記の world.parley() を実行すると,教師と学生は単一のフォーマットでお互いに対話する.具体的には各エージェントが観測 (observation)行動 (action) を実行し,Python の辞書型オブジェクトを介してテキスト・ラベル・報酬等をやり取りする.

def parley(self):
    for agent in self.agents:
        act = agent.act()
        for other_agent in self.agents:
            if other_agent != agent:
                other_agent.observe(act)

実際のコードでも DialogPartnerWorld クラスなどの世界で parley() メソッドを実行し,教師 Task10kTeacher と学生 Seq2seqAgent が交互に観測と行動を行っていることが窺える.(バッチ学習は BatchWorldparley() を実行する.)

class DialogPartnerWorld(World):
    def parley(self):
        """2 つのエージェント間で交互に対話する."""
        acts = self.acts
        agents = self.agents # [Task10kTeacher, Seq2seqAgent]
        acts[0] = agents[0].act()
        agents[1].observe(validate(acts[0]))
        acts[1] = agents[1].act()
        agents[0].observe(validate(acts[1]))

このように全てのエージェントは必ず observe()act() メソッドを定義する.それぞれ親クラス (Agent) をオーバーライドする形で関数を定義していく.

class Agent(object):
    def observe(self, observation)
        ...
    def act(self)
        ...
  • observe(): 別のエージェントが取った行動を記憶する
  • act(): エージェントから行動を生成する

学生 (ユーザーエージェント) の場合,observe() メソッドで 1 つのエピソードが終了するまでの標本のテキストを記憶しておき,act() メソッドで PyTorch 等の学習可能なモデルを学習して予測を返すという流れ.ちなみに学習モデルは PyTorch, Chainer, TensorFlow 等どの深層学習ライブラリで実装してもよい.

教師 (タスクエージェント) の場合,observe() メソッドでメトリックのためのデータを収集し,act() メソッドで訓練データ等を供給している.Teacher クラスは Agent クラスを継承しているので act, observe 以外にも質問数や回答に費やした時間等のメトリックを Metrics クラスのインスタンスに保持し report メソッドで返すことができる.

f:id:Ryobot:20171019191840p:plain

例えば,上図のように 2 つのエージェントが対話する世界 (DialogPartnerWorld) において bAbI の訓練データを供給する教師 (Task10kTeacher) と seq2seq の学習モデルを有した学生 (Seq2seqAgent) との対話 (辞書型オブジェクトの交換) は次のように行われる.

Task10kTeacher: {
                'text': 'Sam went to the hallway\nPat went to the bathroom\nWhere is the milk?',
                'labels': ['hallway'],
                'label_candidates': ['hallway', 'kitchen', 'bathroom'],
                'episode_done': True
}
Seq2seqAgent: {
                'text': 'hallway'
}
...

インストールから example 実行まで

目次に戻る ↩︎

百見は一動に如かず.是非手元のマシンにインストールして example を実行してみてほしい.

インストール

# レポジトリをクローン
git clone https://github.com/facebookresearch/ParlAI.git ~/ParlAI
cd ~/ParlAI
# 必須ツールのインストール
pip install -r requirements.txt
# pytorch のインストールは http://pytorch.org/ に従う
# pyzmq, regex, spacy のインストール (requirements_ext.txt がなければ個別に)
pip install -r requirements_ext.txt  
# ParlAI のインストール
sudo python setup.py develop
# ParlAI 実行
python examples/display_data.py -t babi:task1k:1
# Lua Torch インストール (必要なら)
git clone https://github.com/torch/distro.git ~/torch --recursive
cd ~/torch; bash install-deps;
./install.sh
source ~/.bashrc

ParlAI は pip 等でパッケージ管理できないので定期的に GitHub の本家レポジトリに追随する必要がある.以下の方法より楽な方法があれば教えて頂きたい.

upstream.sh として保存し ./upstream.sh で実行する.

#!/bin/sh
# ブランチの確認
git branch -a
# リモートリポジトリとして本家リポジトリを upstream という名前で設定
git remote add upstream https://github.com/facebookresearch/ParlAI.git
# 本家リポジトリの変更を取り出して統合する
git fetch upstream
git merge upstream/master

example の実行

bAbI データセットを表示

display_data.py を実行し,bAbI ($1000$ 標本) の task 1 から $1$ 訓練データを表示する.-t {データセット名}-n {表示数}
メインループの world.parley()repeat_label.pyRepeatLabelAgent が行動を起こし,world.display() でエージェントの行動を表示している.RepeatLabelAgent はラベル (解答) を復唱する単純なエージェントである.

python examples/display_data.py -t babi:task1k:1 -n 1
# 結果
[babi:task1k:1]: Mary travelled to the garden.
Daniel went to the office.
Where is Daniel?
[labels: office]
[cands: garden|bathroom|kitchen|office|bedroom|hallway]
   [RepeatLabelAgent]: office

seq2seq を bAbI で訓練/評価

train_model.py を実行し,bAbI の task 1 から $9000$ 標本で Seq2seqAgent を訓練し,$1000$ 標本で検証する.PyTorch の仕様で CUDA が利用可能であれば GPU で回してくれる.--gpu 0 でもOK.
-m {モデルがあるディレクトリ名} でモデルを指定する.他に -m {ディレクトリ名}/{ディレクトリ名}-m {モジュール名}:{クラス名} でも指定できる. agents.pyget_agent_module(dir_name) を参照されたい.-mf {モデルの保存や読込み用のファイル名}, -e {エポック数}, -lr {学習率}, -bs {バッチサイズ}, -hs {隠れ層サイズ}, -ltim {ログを表示する秒間}, -vtim {検証をする秒間}

mkdir -p ./parlai/agents/seq2seq/model_file
python examples/train_model.py -m seq2seq -t babi:Task10k:1 -mf './parlai/agents/seq2seq/model_file/babi' -e 20 -lr 0.5 -bs 32 -hs 128 -ltim 2 -vtim 20
# `-m seq2seq` は `-m seq2seq/seq2seq` や `-m parlai.agents.seq2seq.seq2seq:Seq2seqAgent` でもよい
# 結果
[ Using CUDA ]
[ training... ]
[ time:2s parleys:21 total_exs:672 time_left:5439977s ] {'hits@k': {1: 0.03869, 10: 0.03869, 50: 0.03869, 100: 0.03869, 5: 0.03869}, 'total': 672, 'accuracy': 0.03869, 'f1': 0.1002}
[ time:4s parleys:47 total_exs:1504 time_left:4832701s ] {'hits@k': {1: 0.1526, 10: 0.1526, 50: 0.1526, 100: 0.1526, 5: 0.1526}, 'total': 832, 'accuracy': 0.1526, 'f1': 0.1687}
...

訓練後 ./parlai/agents/seq2seq/model_file ディレクトリ内に 4 つのファイルが生成される.

  • babi: モデルを保存している.world.save_agents() を実行し,学生のクラスの save(model, open()) メソッドで書き込んでいる.
  • babi.dict: 語彙と出現回数を保存している.build_dict.build_dictDictionaryAgent.save() を実行し open().write() で書き込んでいる.
__NULL__        1000000002
__EOS__ 1000000001
__UNK__ 1000000000
.       18000
the     18000
to      18000
?       9000
...
  • babi.test, babi.valid: テストと検証それぞれで標本数・精度・F1スコア・hits@k スコアを保存している.run_eval を実行し open().write() で書き込んでいる.
test:{'accuracy': 0.999, 'total': 1000, 'hits@k': {1: 0.999, 10: 0.999, 50: 0.999, 100: 0.999, 5: 0.999}, 'f1': 0.999}
valid:{'accuracy': 1.0, 'total': 1000, 'hits@k': {1: 1.0, 10: 1.0, 50: 1.0, 100: 1.0, 5: 1.0}, 'f1': 1.0}

DrQA を SQuAD で訓練/評価

mkdir -p ./parlai/agents/drqa/model_file
python examples/train_model.py -m drqa -t squad -mf './parlai/agents/drqa/model_file/squad' -bs 32 -ltim 2 -vtim 1000
# 結果
[ training... ]
[ time:78s parleys:1 ] [train] updates = 1 | train loss = 9.84 | exs = 32
...

ディレクト

データ類は ~/ParlAI/data にダウンロードされ,データ類以外 (e.g. MemNN のコード) は ~/ParlAI/downloads にダウンロードされる.

コード類は ParlAI/parlai/ParlAI/examples/ ディレクトリに入っている.

  • core: フレームワークを構成する主要なファイル類
  • agents: タスクを処理するエージェント (対話モデル等)
  • examples: 基本的な実行ファイル (e.g. ディレクトリ構築,訓練・評価の実行,データの表示)
  • tasks: ParlAI で利用可能な様々なタスクのためのコード
  • mturk: Mechanical Turk の設定のためのコード

ParlAI Agent の実装方法

1. 実装の流れ

目次に戻る ↩︎

ParlAI で学習可能なエージェント (学生) のクラスを実装する流れを紹介する.全ての実装例は以下のレポジトリを参照されたい.

git clone https://github.com/ryonakamura/parlai_agents.git
mv parlai_agents ~/ParlAI/parlai/

で即使用できる.

バッチ学習

最初に留意すべきこととして,ParlAI はバッチ学習がやや変則的である.オンライン学習 (バッチサイズ 1) では,DialogPartnerWorld という世界をまず用意するが,バッチ学習では BatchWorld という世界を作成し,その中で雛形DialogPartnerWorld と雛形の共有データ shared から作成したバッチサイズ分の DialogPartnerWorld を用意する.ここで全ての DialogPartnerWorld には教師 (e.g. Task10kTeacher) と学生 (e.g. Seq2seqAgent) が存在する.

雛形世界の学生エージェントのみ PyTorch 等の学習可能なモデルを有し,共有データから作成した世界の学生エージェントはモデルを有さない.つまり,学生は雛形世界のエージェントのみ学習し,共有データ世界のエージェントは学習に関与しない.

class Seq2seqAgent(Agent):

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        if not shared:
            # PyTorch 等の学習モデルの初期化
            # `shared` から作成した世界のエージェントはここを実行しない

訓練は BatchWorld.parley() を呼んで (i) と (ii) を繰り返すことで行われる.

(i) 共有データ世界内の Task10kTeacher が観測なしで個別に「行動」を実行し,共有データ世界内の Seq2seqAgent が個別に「観測」を実行し,観測結果を格納する.
(ii) 雛形世界内の Seq2seqAgent が格納した観測結果を得て代表して「バッチ行動」を実行し,共有データ世界内の Task10kTeacher が個別に「観測」を実行し,観測結果を格納する.

学生エージェントのクラス

(教師を含む) 全ての ParlAI エージェントのクラスは __init__(), act(), observe() メソッドを必ず持っている.また,学生エージェントのクラスは batch_act(), batchify(), save(), load() メソッドも持っている.

  • __init__(): クラスの初期化を行う.前述通り,バッチ学習では雛形世界のエージェントのみ PyTorch 等の学習可能なモデルを初期化する.
  • act(): オンライン学習での行動である.
  • observe(): 訓練データ等を観測し,エピソード分の標本を連結して返す.
  • batch_act(): バッチ学習での行動である.バッチサイズ分の訓練データ等を引数として入力し,バッチサイズ分の出力を返す.
  • batchify(): 入力された訓練を学習モデルで処理できる変数に変換する.
  • save(): 学習モデル等を保存する.
  • load(): 保存した学習モデル等をロードする.

observe() メソッドは全ての学生エージェントで共通である.1 つのエピソードが終わるまで,過去の標本を連結して返している.

エピソードとは一連の出来事を陳述した物語である.bAbI タスク等は複数のエピソードから成り,各エピソードは複数の標本から成る.

def observe(self, observation): # エピソードの最後の標本は `{'episode_done': True}` が提供される.
    observation = copy.deepcopy(observation)
    
    if not self.episode_done: # 現時点で `self.episode_done` は前回の標本
        # 前回の標本がエピソードの終わり (つまり `self.episode_done = True`) でなければ,
        # その前回の標本で述べられた内容 `text` も想起する必要がある.
        prev_dialogue = self.observation['text'] # 現時点で `self.observation` は前回の標本
        observation['text'] = prev_dialogue + '\n' + observation['text']
        # 前回と今回の標本の `text` を足し,今回の標本を更新.`text` はエピソードの既出分になる.
    
    self.observation = observation # 今回の標本で上書き
    self.episode_done = observation['episode_done'] # 今回の標本で上書き
    return observation

batch_act() メソッドも全ての学生エージェントで概ね共通である.バッチサイズの入力を受け取り,バッチサイズの出力を返している.

def batch_act(self, observations):
    # observations:
    #       [{'label_candidates': {'office', ...},
    #       'episode_done': False, 'text': 'Daniel ... \nWhere is Mary?',
    #       'labels': ('office',), 'id': 'babi:Task10k:1'}, ...]
    batchsize = len(observations)
    batch_reply = [{'id': self.getID()} for _ in range(batchsize)]
    # [{'id': 'Seq2Seq'}, {'id': 'Seq2Seq'}, ...]

    xs, ys, valid_inds = self.batchify(observations)
    # 入力された訓練を学習モデルで処理できる変数に変換する.
    # xs は入力データの変数,ys は正解ラベルの変数,valid_inds はインデックス
    
    # ys (ラベル) があれば訓練,なければ推測
    if ys is not None:
        preds = self.train(xs, ys) # ['bedroom', ...]
    else:
        preds = self.predict(xs)

    for i in range(len(preds)):
        batch_reply[valid_inds[i]]['text'] = preds[i]

    return batch_reply # [{'text': 'bedroom', 'id': 'Seq2Seq'}, ...]

RNNAgent はすべて以下のような共通のコマンドライン引数を有して実装している.

  • -m: モデルのクラス名.e.g. parlai.parlai_agents.<directory_name>.<file_name>:<class_name>
  • -t: ParlAI のタスク名.e.g. babi:Task1k:1 or babi,cbt
  • -mf: モデルの保存やロード用のファイル名.
  • -e: エポック数.
  • -rnn: GRU ないし LSTM を選択.
  • -bs: バッチ学習時のバッチサイズ.
  • -hs: 隠れ層ベクトルや埋め込みベクトルの次元数.
  • -nl: 隠れ層数.
  • -lr: 学習率.
  • -dr: ドロップアウト率.
  • -ltim: n 秒毎にログを表示.
  • -vtim: n 秒毎に検証を実行.

bAbI タスクでの比較

bAbI タスク [Weston, 2015] はテキストベースの簡単な質疑応答 (Question Answering) である.bAbI タスクはプログラムによる自動生成によって作成されたデータセットで,20 種類の異なる形式のタスク (各 10000 標本) が用意され,各タスクは推論・演繹・帰納・勘定などの能力を測ることができる.全てのタスクを合わせても文体の種類は数十に限られ,語彙数は 200 単語以下と少ない.また,人手評価では理論上 100 パーセントの精度で正解することができ,95 パーセント以上の精度で正しい回答を行うことでタスクに成功したと評価する.

bAbI タスクの出題例を下図に示す.黒字が事実文,青字が質問文,赤字が解答である.

f:id:Ryobot:20171019191852p:plain

例えば,「りんごはどこですか?」という質問文に対し,複数の事実文の中から回答に必要な情報「サムはベッドルームに行った」と「サムはりんごを落とした」を参照することで,モデルは「ベッドルーム」と回答することができる.

ParlAI は bAbI タスクに対応している.bAbI タスクでは 20 種類のタスクを個別に学習することを「個別訓練 (Single Training)」,全てのタスクを1つのモデルで同時に学習することを「合同訓練 (Joint Training)」と呼び,本稿では合同訓練における再現実験の結果を掲載している.詳細は 論文解説 Memory Networks (MemNN) - bAbI タスク でも解説している.

先行研究 [Sukhbaatar, NIPS 2015] では,LSTM によるテスト精度の平均は 63.6 % である.本稿で実装した RNNAgent は PyTorch, Chainer, TensorFlow 共に同程度の精度を達成している.

f:id:Ryobot:20171019192114p:plain

タスク 19 (移動パスの発見) は正解ラベルが 2 単語であるが,本稿の RNNAgent は1単語しか生成しない実装である.
また,タスク 8 (リスト,セット) は正解ラベルが 2 単語以上の場合があるが,多くは 1 単語である.

1 GPU で 1000 イテレーションを実行した時の速度比較は次のとおり.

PyTorch Chainer TensorFlow
合同訓練 96 sec 157 sec 320 sec
個別訓練 (Task 1) 36 sec 71 sec 64 sec

TensorFlow は truncated BPTT を使用していないので遅いっぽい.

2. PyTorch で RNNAgent を実装する

目次に戻る ↩︎

PyTorch は ParlAI エージェントを実装するのに最適な深層学習ライブラリであると思う.
PyTorch は GPU 使用時 1.5 ~ 2 倍くらい Chainer より速い.

from torch.autograd import Variable
from torch import optim
import torch.nn as nn
import torch

class RNNAgent(Agent):

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        if not shared:
            # 略
            self.embedding = nn.Embedding(vs, hs, padding_idx=0,
                            scale_grad_by_freq=True)
            if self.rnn_type == 'GRU':
                self.rnn = nn.GRU(hs, hs, nl, dropout=dr)
            elif self.rnn_type == 'LSTM':
                self.rnn = nn.LSTM(hs, hs, nl, dropout=dr)
            self.dropout = nn.Dropout(dr)
            self.projection = nn.Linear(hs, vs)
            self.softmax = nn.LogSoftmax()
            self.loss = nn.NLLLoss()
            # 略

    def observe(self, observation):
        # 略

    def init_zeros(self, bs=1):
        # 略
        return Variable(h0), Variable(c0)

    def forward(self, xs, drop=False):
        # 入力層
        out = self.embedding(xs)
        out = torch.transpose(out, 0, 1) # out: time x batch x hidden
        if drop:
            out = self.dropout(out)
        # RNN
        h0, c0 = self.init_zeros(len(xs)) # h0, c0: layer x batch x hidden
        if self.rnn_type == 'GRU':
            out, hn = self.rnn(out, h0) # out: time x batch x hidden
        elif self.rnn_type == 'LSTM':
            out, (hn, cn) = self.rnn(out, (h0, c0)) # Same as above
        out = out[-1] # out: batch x hidden
        # 出力層
        if drop:
            out = self.dropout(out)
        out = self.projection(out)
        out = self.softmax(out) # out: batch x vocab

        preds = []
        _, idx = out.max(1) # idx: batch x 1
        for i in idx:
            token = self.vec2txt([i.data[0]])
            preds.append(token)
        return out, preds

    def train(self, xs, ys):
        self.rnn.train()
        out, preds = self.forward(xs, drop=True)
        y = ys.select(1, 0) # y: batch
        loss = self.loss(out, y)
        self.zero_grad()
        loss.backward()
        self.update_params()
        return preds

    def predict(self, xs):
        # 略
        return preds

    def batchify(self, obs):
        """バッチサイズの観測 `text`, `label` をランク 2 のテンソル `xs`, `ys` に変換."""
        # 略
        return xs, ys, valid_inds

    def batch_act(self, observations):
        # 略
        return batch_reply

3. Chainer で RNNAgent を実装する

目次に戻る ↩︎

ParlAI の Agent クラスと Chainer の Chain クラスの多重継承を用いて実装している.
RNN は links.NStepGRU ないし links.NStepLSTM から選択できる.

import numpy as np
import chainer
from chainer import cuda
import chainer.functions as F
import chainer.links as L

class RNNAgent(Agent, chainer.Chain):

    def __init__(self, opt, shared=None):
        super(RNNAgent, self).__init__(opt, shared)
        if not shared:
            # 略
            super(Agent, self).__init__(
                            embedding = L.EmbedID(vs, hs),
                            projection = L.Linear(hs, vs))
            if self.rnn_type == 'GRU':
                super(Agent, self).add_link('rnn', L.NStepGRU(nl, hs, hs, dr))
            elif self.rnn_type == 'LSTM':
                super(Agent, self).add_link('rnn', L.NStepLSTM(nl, hs, hs, dr))
            self.dropout = F.dropout
            self.softmax = F.softmax
            self.loss = F.softmax_cross_entropy
            # 略

    def observe(self, observation):
        # 略

    def forward(self, xs, drop=False):
        # 入力層
        out = self.embedding(xs) # out: batch x time x hidden
        if drop:
            out = self.dropout(out, ratio=self.dropout_rate)
        # RNN
        out = [out[i] for i in range(len(out.data))] # out: [*(time x hidden,) * batch]
        if self.rnn_type == 'GRU':
            hy, out = self.rnn(hx=None, xs=out) # out: [*(time x hidden,) * batch]
        elif self.rnn_type == 'LSTM':
            hy, cy, out = self.rnn(hx=None, cx=None, xs=out) # Same as above
        out = [o[-1].reshape(1, -1) for o in out] # out: [*(1 x hidden,) * batch]
        out = F.concat(out, axis=0) # out: batch x hidden
        # 出力層
        if drop:
            out = self.dropout(out, ratio=self.dropout_rate)
        out = self.projection(out)

        preds = []
        idx = F.argmax(out, axis=1) # idx: batch x 1
        for i in idx:
            token = self.vec2txt([i.data])
            preds.append(token)
        return out, preds

    def train(self, xs, ys):
        with chainer.using_config('train', True):
            out, preds = self.forward(xs, drop=True)
            y = F.transpose(ys, axes=(1, 0))[0] # y: batch
            loss = self.loss(out, y)
            self.zero_grad()
            loss.backward()
            loss.unchain_backward()
            self.update_params()
        return preds

    def predict(self, xs):
        # 略
        return preds

    def batchify(self, obs):
        """バッチサイズの観測 `text`, `label` をランク 2 のテンソル `xs`, `ys` に変換."""
        # 略
        return xs, ys, valid_inds

    def batch_act(self, observations):
        # 略
        return batch_reply

links.NStepLSTM は便利だが,入出力がリスト ([*(time x hidden,) * batch]) なのがやや扱いづらい.自動的に 0 パディングする機能をなくして,PyTorch のように入出力を 1 つの変数 (batch x time x hidden) にまとめた関数があると便利かもしれない.

4. TensorFlow で RNNAgent を実装する

目次に戻る ↩︎

TensorFlow は静的計算グラフだと思われがちだが tf.nn.dynamic_rnn を用いて可変長のデータを処理できる.
dynamic_rnn は内部で control_flow_ops.while_loop (tf.while_loop と同様) を用いて時系列展開によるループ処理を実行している.

import numpy as np
import tensorflow as tf
from tensorflow.python.client import device_lib

class RNNAgent(Agent):

    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        if not shared:
            # 略
            self.reuse = None if opt['datatype'] == 'train' else True
            self.create_model()
            self.saver = tf.train.Saver()
            init = tf.global_variables_initializer()
            self.sess = tf.Session()
            self.sess.run(init)
            # 略

    def observe(self, observation):
        # 略

    def create_model(self):
        reuse = reuse=self.reuse
        self.drop = tf.placeholder(tf.bool)
        
        self.xs = tf.placeholder(tf.int32, [None, None])
        # 入力層
        init = tf.contrib.layers.xavier_initializer()
        with tf.variable_scope("embeddings", reuse=reuse):
            embeddings = tf.get_variable("var", shape=[vs, hs], initializer=init)
        out = tf.nn.embedding_lookup(embeddings, self.xs) # out: batch x time x hidden
        out = tf.cond(self.drop, lambda: tf.layers.dropout(out, rate=dr), lambda: out)
        # RNN
        if self.rnn_type == 'GRU':
            rnn_cell = tf.nn.rnn_cell.GRUCell(hs, reuse=reuse)
        elif self.rnn_type == 'LSTM':
            rnn_cell = tf.nn.rnn_cell.LSTMCell(hs, reuse=reuse)
        prob = tf.cond(self.drop, lambda: 1.-dr, lambda: 1.)
        rnn_cell = tf.contrib.rnn.DropoutWrapper(rnn_cell, output_keep_prob=prob)
        multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell] * nl)
        out, _ = tf.nn.dynamic_rnn(cell=multi_rnn_cell, inputs=out, dtype=tf.float32)
        out = tf.transpose(out, perm=[1, 0, 2]) # out: time x batch x hidden
        out = out[-1] # out: batch x hidden
        # 出力層
        logits = tf.layers.dense(out, vs, activation=None, reuse=reuse)
        out = tf.nn.softmax(logits, dim=-1) # out: batch x vocab

        preds = []
        self.idx = tf.argmax(out, axis=1) # idx: batch
        
        self.ys = tf.placeholder(tf.int32, [None, None])
        ys = tf.transpose(self.ys)[0] # y: batch
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=tf.one_hot(
                        ys, depth=vs, dtype=tf.float32), logits=logits,)
        loss = tf.reduce_mean(loss)
        optimizer = tf.train.GradientDescentOptimizer(lr)
        self.update_params = optimizer.minimize(loss)
        
    def train(self, xs, ys):
        idx, _ = self.sess.run([self.idx, self.update_params],
                        feed_dict={self.xs: xs, self.ys: ys, self.drop: True})
        preds = [self.vec2txt([i]) for i in idx]
        return preds

    def predict(self, xs):
        # 略
        return preds

    def batchify(self, obs):
        """バッチサイズの観測 `text`, `label` をランク 2 のテンソル `xs`, `ys` に変換."""
        # 略
        return xs, ys, valid_inds

    def batch_act(self, observations):
        # 略
        return batch_reply

TensorFlow は名前空間に気をつけなければならない.
ParlAI は訓練後に検証とテストを行う (train_model.py の main の最後) ので,保存したモデルから検証とテスト用のユーザーエージェントを作り直すのだが,このとき学習モデルの全ての変数は reuse=True にする必要がある.
なぜなら訓練時のエージェントが作成した TensorFlow の名前空間は維持されたままなので,もし reuse=None のままユーザーエージェントを作り直してしまうと,同じ変数に別の名前が割当られてしまい,モデルを保存したときの変数の名前と不一致になってリロードに失敗する.
よって以下のように訓練とそれ以外で reuse の値を切り替えるとよい.

reuse = None if opt['datatype'] == 'train' else True
...
with tf.variable_scope("embeddings", reuse=reuse):
    embeddings = tf.get_variable(...)
...
logits = tf.layers.dense(..., reuse=reuse)

5. PyTorch で AttentionAgent (seq2seq with Attention) を実装する

目次に戻る ↩︎

ソフトな注意 (Soft Attention) とは行列 (ベクトルの配列) に対して注意の重みベクトルを求め,行列と重みベクトルを内積して文脈ベクトルを得ることである.

メモリ {M} の各ベクトルを $\boldsymbol{m_i}$,クエリを $\boldsymbol{q}$ とおく.ニューラル機械翻訳の場合,$\boldsymbol{m_i}$ は Encoder の各中間層であり,$\boldsymbol{q}$ は Decoder の中間層に相当する.

注意の重みベクトル $\boldsymbol{a}$ (Alignment Weight Vector) は次式のスコア関数の出力をソフトマックス関数で正規化することで計算できる.これは各 $\boldsymbol{m_i}$ と $\boldsymbol{q}$ の関連性を評価する関数である.

{\mathrm{score} \left( \boldsymbol{q} , \boldsymbol{m}_i \right) = \left \{\begin{array}{ll}
{\boldsymbol{q}}^{\mathrm{T}}\boldsymbol{m_i} & dot\\
{\boldsymbol{q}}^{\mathrm{T}}\boldsymbol{W}\boldsymbol{m}_i & general \\
{\boldsymbol{v}}^{\mathrm{T}} \tanh \left( \boldsymbol{W} [\boldsymbol{q} ; \boldsymbol{m}_i] \right) & concat
\end{array}\right.}

メモリ {M} 内の事実文の数を $N$ とすると,次式で重みベクトル $\boldsymbol{a} (\boldsymbol{q}, M)$ が求まる.

$\begin{eqnarray}\displaystyle a(\boldsymbol{q}, \boldsymbol{m_i}) &=& \mathrm{softmax} \left( \mathrm{score} \left(\boldsymbol{q} , \boldsymbol{m_i} \right) \right) \ & & \ &=& \frac{ \exp \left( \mathrm{score} \left( \boldsymbol{q} , \boldsymbol{m_i} \right) \right) }{ \sum_{j=1}^{N} \exp \left( \mathrm{score} \left( \boldsymbol{q} , \boldsymbol{m_j} \right) \right) } \end{eqnarray}$

重みベクトル $\boldsymbol{a} (\boldsymbol{q}, M)$ の各成分 $a(\boldsymbol{q}, \boldsymbol{m_i})$ (スカラー) と対応する $\boldsymbol{m_i}$ をそれぞれ掛けたベクトルの総和が文脈ベクトル (Context Vector) $\boldsymbol{c}$ であり,ソフトな注意の出力である.

$\boldsymbol{c} = \sum_{i=1}^{N} a(\boldsymbol{q}, \boldsymbol{m_i}) \boldsymbol{m_i}$

ソフトな注意では重みベクトル $\boldsymbol{a}$ が $\boldsymbol{m_s}$ のアライメントないしアノテーションの役割を果たすように学習する.

つまり,選択したいベクトル {\boldsymbol{m_s}} にかける重み {a(\boldsymbol{q}, \boldsymbol{m_s})} が $1$,選択しなくないベクトル {\boldsymbol{m_{s' \neq s}}} にかける重み {a(\boldsymbol{q}, \boldsymbol{m_{s' \neq s}})} が $0$ になるようにモデルパラメータを学習する.したがって自動的にアライメントないしアノテーションが実行される.詳細は 深層学習による自然言語処理 - ニューラル機械翻訳への理論 - 注意 (Attention) でも解説している.

ニューラル機械翻訳における注意の全体像は下図の通り.

f:id:Ryobot:20171019192147p:plain

本稿における各モデルのテスト精度の比較は次の通り.

f:id:Ryobot:20171019192015p:plain

seq2seq with Attention 固有の追加引数は次の通り.

  • -bi: True なら,第 1 層に双方向エンコーダを使用する.
  • -atte: True なら,NMT [Luong, 2015] の注意を使用する.
  • -cont: True なら,デコーダの最終出力にデコーダのクエリを連結せず文脈ベクトルのみ使用する.
  • -sf: dot (内積), general (線形変換), concat (連結) から注意用のスコア関数を選択する.
  • -tf: 教師強制 (teacher forcing) 率.
class AttentionAgent(Agent):

    def _zeros_gen(self, bs=1):
        # Encoder hidden and memory cell start with 0 filled tensor
        # 略
        return bi_h0, bi_c0, h0, c0

    def _encode(self, x):
        x = self.embedding(x)
        x = torch.transpose(x, 0, 1) # x: time x batch x hidden
        if self.train_step:
            x = self.dropout(x)
        bi_h0, bi_c0, h0, c0 = self._zeros_gen(x.size(1)) # layer x batch x hidden      
        # 略
        
        if self.rnn_type == 'GRU':
            # 略
        
        elif self.rnn_type == 'LSTM':
            hn, cn = None, None
            if self.use_bi_encoder:
                x, (hn, cn) = self.bi_encoder(x, (bi_h0, bi_c0))
                hn = transform(hn)
                cn = transform(cn)
            if self.use_encoder:
                bi_hn, bi_cn = hn, cn
                x, (hn, cn) = self.encoder(x, (h0, c0))
            if self.use_bi_encoder and self.use_encoder:
                hn = torch.cat((bi_hn, hn), dim=0)
                cn = torch.cat((bi_cn, cn), dim=0)

        if self.train_step:
            x = self.dropout(x)
        return x, hn, cn

    def _sos_gen(self, bs=1):
        # デコーダは SOS テンソルから開始する.
        # 略
        return x # x: 1 x batch x hidden

    def _attention(self, query, memory):
        # query: batch x hidden
        # memory: time x batch x hidden
        value = memory.transpose(0, 1) # value: batch x time x hidden
        key = value.transpose(1, 2) # key: batch x hidden x time 

        def dot(q, k):
            return torch.bmm(q.unsqueeze(1), k)

        def general(q, k):
            return torch.bmm(self.W_atte(q).unsqueeze(1), k)

        def concat(q, k):
            # 略

        func = {'dot': dot, 'general': general, 'concat': concat}
        score = func[self.score_func](query, key).squeeze(1) # batch x time
        attn_weight = self.softmax_atte(score).unsqueeze(1) # batch x 1 x time
        context = torch.bmm(attn_weight, value).squeeze(1) # batch x hidden
        x = torch.cat((context, query), dim=1) # batch x hidden*2
        x = self.tanh_cont(self.W_cont(x)) # batch x hidden
        return x

    def _decode_step(self, x, hn, cn, memory):
        # エンコーダの出力を注意用のメモリに使用する.
        if self.rnn_type == 'GRU':
            x, hn = self.decoder(x, hn) # x: 1 x batch x hidden
        elif self.rnn_type == 'LSTM':
            x, (hn, cn) = self.decoder(x, (hn, cn))
        x = x.squeeze(0) # x: batch x hidden
        if self.use_attention:
            x = self._attention(x, memory)
        if self.train_step:
            x = self.dropout(x)
        x = self.projection(x)
        x = self.softmax(x) # x: batch x vocab
        return x, hn, cn

    def _decode_and_train(self, memory, hn, cn, y):
        bs = memory.size(1)
        x = self._sos_gen(bs)
        preds = [[] for _ in range(bs)]
        self.longest_label = max(self.longest_label, y.size(1))
        self.attn_weight = []
        loss = 0

        for i in range(y.size(1)):
            x, hn, cn = self._decode_step(x, hn, cn, memory)
            t = y.select(1, i) # t: batch, select(dim, index)
            loss += self.loss(x, t)
            _, x = x.max(1) # x: batch x 1
            x = x.view(-1)
            # 予測を格納する.
            for j in range(bs):
                token = self.vec2txt([x.data[j]])
                preds[j].append(token)
            # 次のタイムステップの入力を準備する.
            if random.random() < self.teacher_forcing_rate:
                x = self.embedding(t).unsqueeze(0) # 1 x batch x hidden
            else:
                x = self.embedding(x).unsqueeze(0) # 1 x batch x hidden

        self.zero_grad()
        loss.backward()
        self.update_params()
        return preds

    def _decode_only(self, memory, hn, cn):
        # 略
        return preds

    def train(self, x, y):
        self.train_step = True
        if self.use_encoder:
            self.encoder.train()
        self.decoder.train()
        
        x, hn, cn = self._encode(x)
        preds = self._decode_and_train(x, hn, cn, y)
        return preds

    def batchify(self, obs):
        """バッチサイズの観測 `text`, `label` をランク 2 のテンソル `xs`, `ys` に変換."""
        # 略
        return x, y, ids, valid_inds

いろいろ比較してみた.

f:id:Ryobot:20171019192133p:plain

f:id:Ryobot:20171019192122p:plain

f:id:Ryobot:20171019192238p:plain

f:id:Ryobot:20171019192204p:plain

注意を可視化した.横軸はターゲット単語,縦軸はソース単語である.ちなみに注意には失敗している.bAbI タスクを seq2seq with Attention で解くのは困難っぽい.

f:id:Ryobot:20171019191948p:plain

得られた知見は次の通り.

  • Dropout はなくてもよい.
  • seq2seq を使う場合,Bidirectional Encoder は必須である.また 2, 3 層がよい.
  • LSTM より GRU の方が学習が早く進む.しかし,最終的に到達する精度は LSTM の方が高い.
  • 合同訓練 (Joint Training) では,簡単なタスクから学習が進み,難しいタスクは学習が遅いか,見捨てられる (タスク 19 で顕著).
  • seq2seq に注意 (Attention) を追加すると性能が悪くなる.また Luong のスコア関数は dot, general, concat によって注意に当たり方は全く異なる.

6. Chainer で MemN2NAgent (End-To-End Memory Networks) を実装する

目次に戻る ↩︎

メモリネットワーク (Memory Network) は,「情報の保持と処理を行うメモリ」という機能を持つニューラルネットワークである.メモリネットワークは質疑応答のような対話タスクにおいて回答に必要な情報を選択するために使われる.

End-To-End Memory Networks (MemN2N) [Sukhbaatar, NIPS 2015] の内部では以下のような動作をする.

  1. Bag of Words 表現の質問文 q を行列 B で埋め込みベクトルに変換し,クエリ u を求める.
  2. Bag of Words 表現の事実文の配列 {x} を行列 A 及び行列 C で埋め込みベクトルの配列 {m} 及び {c} に変換し,Input メモリと Output メモリに格納する.
  3. ベクトル u と行列 {m} の内積を Softmax 関数に通し,注意の重み p を求める.
  4. ベクトル p と行列 {c} の内積によってメモリの出力 o を求める.
  5. ベクトル u とベクトル o を加算して新しいクエリ u を求める.
  6. 2〜5 の操作を 1 Hop と呼び,必要な Hops 数だけ繰り返す.(デフォルトは 3 Hops)
  7. 最後のベクトル u と行列 W の内積を Softmax 関数に通し,回答単語の確率分布 a を求める.確率が最大の単語を Augmax 関数で選択し,MemN2N の予測とする.

MemN2N の全体像は下図の通り.

f:id:Ryobot:20171019192139p:plain

とどのつまり,MemN2N は複数回の Hop によって与えられた事実文の中から質問に関連する情報 (サポート文と呼ぶ) を適切に選択し解答を予測するモデルである.

MemN2N 固有の追加引数は次の通り.

  • -ms: メモリ層 (key と value) のサイズ (メモリに格納する文の数).
  • -nl: メモリの層数 (ホップ数)
  • -wt: Adjacent (隣接), Layer-wise (RNN ライクな層方向), Nothing (なし) から重み共有のタイプを選択する.
  • -pe: True なら,単語の埋め込みに位置エンコーディング (Position Encoding) を使用する.
  • -te: True なら,メモリへの文の格納に時間エンコーディング (Temporal Encoding) を使用する.
  • -rn: True なら,時間エンコーディング正則化にランダムノイズ (Random Noise) を使用する.
  • -ls: True なら,線形開始 (Linear Start) を使用する (メモリ層からソフトマックス関数を取り除く).
  • -opt: SGD, AdaGrad, Adam から最適化アルゴリズムを選択する.
class MemN2NAgent(Agent, chainer.Chain):

    def _position_encoding(self, sentence):
        # 位置エンコーディングのためにテンソル l を用意する.
        # 略
        return l

    def _embedding(self, embed, sentence, l=None): # batch x memory x sequence
        # `sentence` が q (質問) 由来のとき,メモリサイズのランクがない.
        e = embed(sentence) # batch x memory x sequence x hidden
        
        # 位置エンコーディング
        if self.use_position_encoding:
            # 略

        # 負の `axis` を指定すると,`sentence` が x 由来でも q 由来でも同じランクを指定できる.
        e = F.sum(e, axis=-2) # batch x memory x hidden
        return e

    def _attention(self, u, x, l, i): # batch x hidden
        # A と C 埋め込み行列のために使用する.
        # TA と TC 時間エンコーディングのために使用する.

        # 重み共有なし
        if self.weight_tying == 'Nothing':
            # 略

        # 隣接重み共有
        elif self.weight_tying == 'Adjacent':
            # 略
        
        # RNN ライクな層方向重み共有
        elif self.weight_tying == 'Layer-wise':
            A = self.embeddings[0][1]
            C = self.embeddings[1][1]
            TA = self.temp_encs[0][1]
            TC = self.temp_encs[1][1]  

        m = self._embedding(A, x, l) # m: batch x memory x hidden
        c = self._embedding(C, x, l)

        # 時間エンコーディング
        if self.use_temporal_encoding:
            # 略
        
        c = F.swapaxes(c, 2, 1) # c: batch x hidden x memory
        p = F.batch_matmul(m, u) # p: batch x memory x 1

        # 線形開始
        if not self.use_linear_start:
            p = F.softmax(p)

        o = F.batch_matmul(c, p) # o: batch x hidden x 1
        o = o[:, :, 0] # o: batch x hidden

        # RNN ライクな層方向重み共有
        if self.weight_tying == 'Layer-wise':
            u = self.H(u)

        u = o + u
        return u # batch x hidden

    def _forward(self, x, q, drop=False):
        bs = len(x.data)
        nl = self.num_layers
        l = None

        if self.use_position_encoding:
            l = self._position_encoding(q)
        u = self._embedding(self.B, q, l)

        if self.use_position_encoding:
            l = self._position_encoding(x)
        for i in range(nl):
            u = self._attention(u, x, l, i)
        
        u = self.double(u)
        us = F.split_axis(u, 2, axis=1)
        xs = [F.linear(u, self.W.W) for u in us] # xs: [batch x vocab, batch x vocab]

        preds = [[] for _ in range(bs)]
        ids = [F.argmax(x, axis=1) for x in xs] # ids: [batch x 1, batch x 1]
        for i in range(2):
            for j in range(bs):
                token = self.vec2txt([ids[i][j].data])
                preds[j].append(token)

        return xs, preds

    def train(self, x, q, y):
        with chainer.using_config('train', True):
            xs, preds = self._forward(x, q, drop=True)
            loss = 0
            y = F.transpose(y, axes=(1, 0)) # y: 2 x batch
            for i in range(2):
                loss += self.loss(xs[i], y[i])
            self.zero_grad()
            loss.backward()
            self.update_params()
        return preds

    def batchify(self, obs):
        """バッチサイズの観測 `text`, `label` を
        ランク 3 のテンソル `x` とランク 2 のテンソル `q`, `y` に変換."""
        # 略
        return x, q, y, ids, valid_inds

いろいろ比較してみた.各テクニックの詳細は別途記事にて紹介する.

f:id:Ryobot:20171019192038p:plain

f:id:Ryobot:20171019192009p:plain

f:id:Ryobot:20171019192023p:plain

f:id:Ryobot:20171019192002p:plain

f:id:Ryobot:20171019192233p:plain

損失の減少をプロットした.2 枚のうち上図は位置エンコーディング (PE) と時間エンコーディング (TE) を使用せず,下図は使用している.PE と TE を使用することで損失が小さくなっていることが窺える.

f:id:Ryobot:20171019192210p:plain

f:id:Ryobot:20171019192032p:plain

メモリ層の注意を可視化した.横軸はメモリの層方向 (下図では 5 層),縦軸はメモリに格納した文方向 (最後の入力文から 8 文目まで表示) である.概ね注意に成功していることが窺える.  

f:id:Ryobot:20171019192045p:plain

f:id:Ryobot:20171019192216p:plain

位置エンコーディング (Position Encoding) を可視化するとグラデーションが確認できる.

f:id:Ryobot:20171019192244p:plain

論文には記載されていない興味深い現象をいくつか発見した.

  • 初期値ガチャが強い.悪い初期値を引くと全く学習が進まないので見切りをつける.訓練精度 0 パーセント付近と 10~20 パーセント付近で停滞することが多い.
  • SGD より Adam が良い.個別訓練 (Single Training) では SGD でも学習可能だったが,Adam の方が収束が早く到達精度も良かった.また,合同訓練 (Joint Training) では SGD での学習に失敗した.
  • Adam の学習率 (alpha) は個別訓練では $0.005$,合同訓練 では $0.001$ が良い.また,Adam はアニーリング (検証精度の改善が止まる度に学習率を半減して訓練を再開する手法) が有効に働いた.
  • SGD は 1 Hop 目だけでサポート文を発見しようとする.2, 3 Hop 目は 0 ベクトルのサポート文 ('__NULL__ __NULL__ ...') を選択しやすい.Adam は 2 Hop 目以降も基づいてサポート文を発見している.
  • 入力文に過去の質問文を残したままにすると注意に失敗することがある.
  • タスク 19 のみ正解ラベルが 2 単語 (チャンスレートは 1/16) で精度が著しく悪い.

論文の結果を超えてゆけ!

本稿の実装と著者の オリジナルの Matlab コード と論文記載のテスト精度を比較した.

デフォルト設定: 3 ホップ, 位置エンコーディング (PE), 時間エンコーディング (TE), 線形開始 (LS), ランダムノイズ (RN), 隣接重み共有.

  • 本稿 (this repo): 重み共有なし, 線形開始なし, 5 ホップ, 隠れ層サイズ 128, Adam (アニーリング付き). 線形開始をすると訓練に失敗した. (1 試行のみ)
  • Matlab: 参考
  • 論文: 異なるランダムな初期値で 10 試行繰り返し,訓練エラーが最も低かったもでるを選択しているので精度が良い.参考

f:id:Ryobot:20171019192105p:plain

bAbI タスクではテスト精度が 95 %以上で成功と見做される.

論文記載の最高スコアでは 14/20 タスクに成功している.
本稿の実装 (最高の設定) では 15/20 タスクに成功した.

ParlAI で扱えるタスク一覧

目次に戻る ↩︎

ParlAI で提供されるデータセットのうち標本数が多く語彙数が少ないお勧めデータセット

  • bAbI Task (標本数: 200k, 語彙数: 0.2k): QA 界の MNIST (簡単で扱いやすい).出題形式が画一的で語彙数が極端に少ない.
  • Children’s Book Test (標本数: 670k, 語彙数: 60k): 児童書を使用した穴埋め問題.語彙数が少なく標本数が多いので扱いやすい.
  • WikiMovies dataset (標本数: 100k, 語彙数: 30k): 映画ドメインの QA.解答には外部資源が必要.
  • SQuAD (標本数: 90k, 語彙数: 110k): 入力文に解答が必ず含まれる.F1 スコアの上位は Pointer Network を使用.

質疑応答 (QA)

タスク一覧のトップに戻る ↩︎

質疑応答 (Question Answering) は解答が一単語ないし少数単語である.

text が短い標本が多く,その場合,感覚的に 1 標本 100 バイト前後 (1 文字 1 バイト).
PyTorch では weight の各成分は 4 バイトのメモリを消費する.gradient も同様.参照
語彙数は --dict-maxexs 1000000000 で算出した.

bAbI

QA 界の MNIST (簡単で扱いやすい).出題形式が画一的で語彙数が極端に少ない.

プログラムで自動生成された 20 種類のタスクから成る.以下,標本数・語彙数はシングルタスクの場合.

1k

特徴: 標本数が少ない.

訓練データの標本数 $900$, 語彙数 $21$ (bAbI/tasks_1-20_v1-2/en-valid-nosf/qa4_train.txt, 104.0 kB)
検証データの標本数 $100$ (bAbI/tasks_1-20_v1-2/en-valid-nosf/qa4_valid.txt, 11.5 kB)
テストデータの標本数 $1000$ (bAbI/tasks_1-20_v1-2/en-valid-nosf/qa4_test.txt, 115.4 kB)

[babi:Task1k:4]: The office is north of the kitchen.
The bathroom is north of the office.
What is north of the kitchen?
[labels: office]
[cands: office|bedroom|hallway|garden|bathroom|kitchen]

10k

特徴: 語彙数が少ない.標本の text が短い.

訓練データの標本数 $9000$, 語彙数 $21$ (bAbI/tasks_1-20_v1-2/en-valid-10k-nosf/qa4_train.txt, 1.0 MB)
検証データの標本数 $1000$ (bAbI/tasks_1-20_v1-2/en-valid-10k-nosf/qa4_valid.txt, 115.6 kB)
テストデータの標本数 $1000$ (bAbI/tasks_1-20_v1-2/en-valid-10k-nosf/qa4_test.txt, 115.4 kB)

[babi:Task10k:4]: The hallway is west of the bedroom.
The bedroom is west of the kitchen.
What is the bedroom east of?
[labels: hallway]
[cands: hallway|garden|kitchen|bedroom|office|bathroom]

MCTest

特徴: 標本数が少ない.標本の text が長い.

訓練データの標本数 $1200$, 語彙数 $4219$ (MCTest/train500.txt, 1.5 MB)
検証データの標本数 $200$ (MCTest/valid500.txt, 251.9 kB)
テストデータの標本数 $600$ (MCTest/test500.txt, 733.5 kB)

[mctest]: Tom was the best baseball player in his neighborhood. He also enjoyed playing basketball, but he wasn't very good at it.
# 中略
This made Tom happy again, and as he left to be with his team, he knew that he had finally made it. multiple: What sports did Tom play?
[labels: Baseball and basketball]
[cands: Baseball and basketball|Baseball and soccer|Baseball only|Baseball and football]

Movie Dialog

QA

特徴: 標本の text が短い.

訓練データの標本数 $96185$, 語彙数 $31879$ (MovieDialog/movie_dialog_dataset/task1_qa/task1_qa_train.txt, 6.7 MB)
検証データの標本数 $9968$ (MovieDialog/movie_dialog_dataset/task1_qa/task1_qa_dev.txt, 693.2 kB)
テストデータの標本数 $9952$ (MovieDialog/movie_dialog_dataset/task1_qa/task1_qa_test.txt, 695.5 kB)

[moviedialog:Task:1]: who is the director for The First Texan?
[labels: Byron Haskin]
[cands: Alyson Court|Ann Rutherford|Nathaniel Kahn|John Huddles|Noam Murro| ...and 75536 more]

Recommendations

特徴: 標本数が多い.標本の text が短い.語彙数がとても少ない

訓練データの標本数 $1000000$ (1m), 語彙数 $8015$ (MovieDialog/movie_dialog_dataset/task2_recs/task2_recs_train.txt, 185.9 MB)
検証データの標本数 $10000$ (MovieDialog/movie_dialog_dataset/task2_recs/task2_recs_dev.txt, 1.9 MB)
テストデータの標本数 $10000$ (MovieDialog/movie_dialog_dataset/task2_recs/task2_recs_test.txt, 1.9 MB)

[moviedialog:Task:2]: The Usual Suspects, Oldboy, The Shawshank Redemption, Amélie, Downfall, Pulp Fiction, and Schindler's List are movies I like. Would you suggest a film?
[labels: Hotel Rwanda]
[cands: Jonathan Smith|Brie Howard|Week-End in Havana|Jiseon Kim|The Great McGinty| ...and 75536 more]

MTurk WikiMovies

特徴: 標本の text が短い.

訓練データの標本数 $66307$, 語彙数 $37503$ (MTurkWikiMovies/mturkwikimovies/qa-train.txt, 5.7 MB)
検証データの標本数 $9173$ (MTurkWikiMovies/mturkwikimovies/qa-train.txt, 832.0 kB)
テストデータの標本数 $7848$ (MTurkWikiMovies/mturkwikimovies/qa-train.txt, 707.7 kB)

[mturkwikimovies]: The Black Camel aired in what year?
[labels: 1931]
[cands: James Naughton|Viviana Aliberti|Robert J. Wilke|outsiders|Vittorio Manunta| ...and 75536 more]

Simple Questions

特徴: 標本の text が短い.語彙数がやや多い.

訓練データの標本数 $75910$, 語彙数 $63734$ (SimpleQuestions/sq/train.txt, 4.4 MB)
検証データの標本数 $10845$ (SimpleQuestions/sq/valid.txt, 632.5 kB)
テストデータの標本数 $21687$ (SimpleQuestions/sq/test.txt, 1.3 MB)

[simplequestions]: which record label does slade belong to?
[labels: universal music group]

SQuAD

入力文に解答が必ず含まれる.F1 スコアの上位は Pointer Network を使用.

特徴: 標本の text はやや長い.語彙数がやや多い.データは json

訓練データの標本数 $87599$, 語彙数 $115948$ (SQuAD/train-v1.1.json, 30.3 MB)
検証データの標本数 $10570$ (SQuAD/dev-v1.1.json, 4.9 MB)
テストデータはない.

[squad]: Throughout its history, the city has been a major port of entry for immigrants into the United States; more than 12 million European immigrants were received at Ellis Island between 1892 and 1924. The term "melting pot" was first coined to describe densely populated immigrant neighborhoods on the Lower East Side. By 1900, Germans constituted the largest immigrant group, followed by the Irish, Jews, and Italians. In 1940, whites represented 92% of the city's population.
What percentage of the population was Caucasian in 1940?
[labels: 92%]

TriviaQA

証拠文書 (evidence document) を参照して回答する QA.

特徴: 標本数が多い (650K question-answerevidence triples), データは json でロードに少なくとも 1 時間かかる (中断ゆえ不明).

TriviaQA ディレクトリで 7.4 GB
訓練データ (TriviaQA/qa/web-train.json, 365.4 MB)

Web Questions

特徴: 標本数が少ない.標本の text が短い.語彙数がやや多い.

訓練データの標本数 $3778$, 語彙数 $10192$ (WebQuestions/train.txt, 303.0 kB)
検証データの標本数 $3778$ (WebQuestions/valid.txt)
テストデータの標本数 $2032$ (WebQuestions/test.txt)

[webquestions]: who is lamar odom married too?
[labels: Khloé Kardashian]

WikiMovies

映画ドメインの QA.解答には外部資源が必要.

データセットの中身は Movie Dialog の QA と全く同じ

特徴: 標本の text が短い.

訓練データの標本数 $96185$, 語彙数 $31879$ (WikiMovies/movieqa/questions/wiki-entities/wiki-entities_qa_train.txt, 6.7 MB)
検証データの標本数 $10000$ (WikiMovies/movieqa/questions/wiki-entities/wiki-entities_qa_dev.txt, 694.7 kB)
テストデータの標本数 $9952$ (WikiMovies/movieqa/questions/wiki-entities/wiki-entities_qa_test.txt, 695.5 kB)

[wikimovies]: who is the director for The First Texan?
[labels: Byron Haskin]
[cands: Billy Magnussen|John Connon|Ordinary Decent Criminal|Andre Gregory|Martina García| ...and 75536 more]

WikiQA

特徴: 標本数が少ない.標本の text が長い.語彙数は多い.

訓練データの標本数 $873$, 語彙数 $8646$ (WikiQA/train-filtered.txt, 1.4 MB)
検証データの標本数 $126$ (WikiQA/valid-filtered.txt, 177.3 kB)
テストデータの標本数 $243$ (WikiQA/test-filtered.txt, 374.2 kB)

[wikiqa]: what is metal music about
[labels: With roots in blues rock and psychedelic rock , the bands that created heavy metal developed a thick, massive sound, characterized by highly amplified distortion , extended guitar solos, emphatic beats, and overall loudness.|Heavy metal lyrics and performance styles are generally associated with masculinity and machismo .]
[cands: Heavy metal (often referred to as metal) is a genre of rock music that developed in the late 1960s and early 1970s, largely in the United Kingdom and in the United States.|
# 中略
|Since the mid-1990s, popular styles such as nu metal , which often incorporates elements of grunge and hip hop ; and metalcore , which blends extreme metal with hardcore punk , have further expanded the definition of the genre.]

InsuranceQA

保険の QA データセット.検証データとテストデータのファイルサイズがなぜか大きくて開けない...

特徴: 標本の text が短く,labels が長い.

訓練データの標本数 $12887$, 語彙数 $18697$ (InsuranceQA/V1/train.txt, 10.4 MB)
検証データの標本数 $1000$ (InsuranceQA/V1/valid.txt, 264.5 MB)
テストデータの標本数 $?$ (InsuranceQA/V1/test.txt, 475.9 MB)

[insuranceqa]: where can I get a free car insurance quote
[labels: the good place get a free car insurance quote be from an experienced broker 20 year of experience or more
# 中略
if you need help please feel free contact me thanks for ask|all quote be free no matter who quote you for insurance or do it online no carrier or agent shall charge just quote you if this have happen to you then most likely the be not legitimate if this have happento you need contact your state insurance commissioner and report them]

MS_MARCO

普通に実行すると語彙数が 62 万を超えて埋め込み層やソフトマックス層が 1GPU のメモリに乗り切らない.

特徴: 標本の text が長い.語彙数がとても多い.

訓練データの標本数 $82326$, 語彙数 $620434$ (MS_MARCO/train.txt, 380.7 MB)
検証データの標本数 $10047$ (MS_MARCO/valid.txt, 46.3 MB)
テストデータの標本数 $9676$ (MS_MARCO/test.txt, 42.0 MB)

[ms_marco]: The cost of sales, also referred to as the cost of goods sold, is a measure of how much it costs a company to sell its products. The cost of sales is also a necessary step when a business is trying to determine the amount of gross profit made in a given period. Not every company calculates cost of sales the same way.
# 中略
what is cost of sales
[labels: The accumulated total of all costs used to create a product or service, which has been sold.]

穴埋め (Cloze)

タスク一覧のトップに戻る ↩︎

穴埋め (Cloze) を一言でいうなら言語理解が必要な Continuous Bag of Words である.周辺の文からターゲットの単語を文脈的に予測するが,単純な周辺単語から尤度で予測できないようにタスクが設計されている.

BookTest

-t booktest:Stream

[creating task(s): booktest]
_ POEMS AND LYRICS OF THE JOY OF EARTH .
# 中略
V. Archduchess Anne the Council ruled , Count XXXXX his great dame ; And woe to both when one had cooled !
[labels: Louis]

Children’s Book Test (CBT)

児童書を使用した穴埋め問題.語彙数が少なく標本数が多いので扱いやすい.

特徴: 標本数が多い.標本の text が長い.語彙数がとても少ない

訓練データの標本数 $669343$, 語彙数 $60835$ (CBT/CBTest/data/cbtest_NE_train.txt, 248.3 MB, cbtest_CN_train.txt, 295.9 MB, cbtest_V_train.txt, 247.1 MB, cbtest_ P_train.txt, 836.8 MB)
検証データの標本数 $8000$ (CBT/CBTest/data/cbtest_NE_valid_2000ex.txt, 4.3 MB)
テストデータの標本数 $10000$ (CBT/CBTest/data/cbtest_NE_test_2500ex.txt, 5.6 MB)

[cbt:NE]: Fill in the blank in the last sentence.
`` But last of all , '' they said , `` came one in silver armor , and he had a silver bridle on his horse , and a silver saddle , and oh , but he could ride !
He took his horse two-thirds of the way up the hill , but then he turned back .
# 中略
`` Oh , how I should have liked to see him too ! ''
said XXXXX .
[labels: Cinderlad]
[cands: Cinderlad|Princess|day|days|earth|fellow|night|second|sight|third]

QA CNN

CNN の記事の Cloze データセット

特徴: 標本数が多い.標本の text が長い.

訓練データの標本数 $380298$, 語彙数 $118331$ (QACNN/train.txt, 1.7 GB)
検証データの標本数 $3924$ (QACNN/valid.txt, 17.5 MB)
テストデータの標本数 $3198$ (QACNN/test.txt, 13.4 MB)

[qacnn]: ( @entity0 ) -- here is what the election next year is about : the fence - sitters , the independent voters . at this point , there is not much president @entity7 can say that will win over conservatives , and given the current @entity11 field , he does n't have to worry too much about losing liberals .
# 中略
the opinions expressed in this commentary are solely those of @entity180 . @placeholder : an @entity7 - @entity38 re-election ticket would give new life to campaign
[labels: @entity180]
[cands: @entity15|@entity0|@entity7|@entity164|@entity11| ...and 15 more]

QA Daily Mail

Daily Mail の記事の Cloze データセット.論文は QA CNN と同様.

訓練データの標本数 $879450$, 語彙数 $207420$ (QADailyMail/train.txt, )

[qadailymail]: by @entity0 published : 14:17 est , 16 may 2012 updated : 14:34 est , 16 may 2012 guilt : @entity2 was convicted of war crimes and crimes against humanity former @entity6 president @entity2 begged for leniency ahead of his sentencing for a catalogue of brutal war crimes saying he has sympathy for @entity12 ’s civil war .
# 中略
while the @entity12 court is formally based in that country 's capital , @entity2 's trial is being staged in @entity215 , a suburb of @entity24 , @entity216 , for fear holding it in @entity70 could destabilize the region . former @entity6 president @placeholder begs for leniency before sentencing at the @entity24 for war crimes
[labels: @entity2]
[cands: @entity22|@entity21|@entity24|@entity137|@entity150| ...and 21 more]
   [RepeatLabelAgent]: @entity2

ゴール指向 (Goal)

タスク一覧のトップに戻る ↩︎

Dialog Based Language Learning

bAbI Task

[reward: 0]
[dbll_babi]: Mary went back to the kitchen.
Sandra travelled to the bathroom.
Where is Mary?
[labels: kitchen]
[cands: kitchen|bathroom|office|garden|bedroom|hallway]
   [RepeatLabelAgent]: kitchen
~~
[reward: 0]
[dbll_babi]: Yes, that is correct!

WikiMovies Task

[reward: 0]
[dbll_movie]: what does America Ferrera appear in?
[labels: One Last Thing...]
[cands: Frank Giering|Near Dark|Dylan Schaffer|Valentin Vinogradov|Cristina Banegas| ...and 75537 more]
   [RepeatLabelAgent]: One Last Thing...
~~
[reward: 0]
[dbll_movie]: No, that's wrong.

Dialog bAbI

[dialog_babi:Task:6]: can i have a restaurant in the south part of town
[labels: What kind of food would you like?]
[cands: here it is resto_seoul_moderate_cantonese_7stars_phone|what do you think of this option: resto_seoul_cheap_vietnamese_1stars|here it is resto_beijing_expensive_thai_8stars_address|what do you think of this option: resto_paris_expensive_italian_5stars|here it is resto_seoul_moderate_korean_6stars_phone| ...and 4208 more]

Movie Dialog QA Recommendations

論文は Movie Dialog と同様.

[moviedialog:Task:3]: I loved Rising Sun, Glory, The Mummy Returns, Radio, The Day After Tomorrow, The One, and Indiana Jones and the Last Crusade. I'm looking for an interesting movie.
[labels: The Fast and the Furious]
[cands: death of a spouse|James Ellison|Horton Hears a Who!|Faycal Attougui|Julie O'Hora| ...and 75536 more]

Personalized Dialog

Full Set

[personalized_dialog:FullTask:1]: male young
hello
[labels: hey dude what is up]
[cands: here is the information you asked for resto_london_moderate_british_4stars_1_social_media|is this one cool: resto_beijing_expensive_thai_2stars_2|here it is resto_beijing_expensive_vietnamese_8stars_1_address resto_beijing_expensive_vietnamese_8stars_1_public_transport|here you go resto_tokyo_moderate_vietnamese_1stars_1_phone|here is the information you asked for resto_london_cheap_indian_6stars_2_address resto_london_cheap_indian_6stars_2_public_transport| ...and 43858 more]

Small Set

[personalized_dialog:SmallTask:1]: female middle-aged
hi
[labels: hello maam how can i help you]
[cands: what do you think of this option: resto_rome_moderate_indian_7stars_2|here is the information you asked for resto_london_moderate_indian_2stars_2_address resto_london_moderate_indian_2stars_2_parking|here it is resto_seoul_expensive_korean_1stars_1_address resto_seoul_expensive_korean_1stars_1_public_transport|is this a good option: resto_london_moderate_british_2stars_1|here is the information you asked for resto_tokyo_moderate_vietnamese_5stars_1_address resto_tokyo_moderate_vietnamese_5stars_1_public_transport| ...and 43858 more]

雑談 (ChitChat)

タスク一覧のトップに戻る ↩︎

Cornell Movie

[cornell_movie]: The question is, who are you? You are in the darkness, but it's not your fault. Elijah Muhammad can bring you into the light.
[labels: Elijah who?]
[cands: Yes, you both think William Bloom is a very smart man.  The problem is, you only see me as your mother, and not as someone's wife. And I've been his wife longer than I've been your mother. You can't discount that.|Very simple, Earl.|Barely.|NI!   NI!|You call them...| ...and 97145 more]

Movie Dialog Reddit

論文は Movie Dialog と同様.

[moviedialog:Task:4]: I remember I enjoyed the aesthetic as well , but I think more of it had to do with the set / costume / makeup design than the cinematography
[labels: And that Easy like a Sunday morning scene . Haha .]

Open Subtitles

[opensubtitles]: [ Thunder rumbles , crashes ]
[labels: [ Sea gulls crying ]]
[cands:  Ladies and gentlemen , please step to the rear .| This is comfy .| Oh , they don\xc2\xb4t sell condoms that big .| - I have a message for Robert Charles Ryan ... ... soon to be ex- owner of the Standard Hotel .| Iay about . or my fist will put you out| ...and 745301 more]

Ubuntu

[ubuntu]: the smb.conf is for smb/cifs shares that it is controlling, it has no effect on cifs shares that you are a client for .
 That's what I figured. So that means, speeding up cifs mounts will need to be done elsewhere. .
 correct.  chances the problem is your network or disk IO .
 I don't know if that's the case, since I can copy to both of these systems at the exact same rate. .
[labels: so what is slow?  I just got back so I missed most of the background .]

ビジュアル質疑応答 (Visual)

タスク一覧のトップに戻る ↩︎

ロードに極めて時間がかかる.

VQAv1

VQAv2

VisDial

MNIST_QA

[mnist_qa]: Which number is in the image?
[labels: 9|nine]
[cands: two|three|zero|7|eight| ...and 15 more]