機械学習周りのプログラミング中心。 イベント情報
ポケモンバトルAI本電子書籍通販中

汎用行動選択モデルの学習 part04 教師あり学習【PokéAI】

前回作成した教師データを用いて、汎用行動選択モデルの学習を試みます。

select766.hatenablog.com

以下のようにバトルの状態を入力とし、適切な行動(技)を選択するモデルを学習することが目標です。

F マタドガス 187/187  
ころがる 10まんボルト でんじほう だいもんじ
O オムスター 193/193  
=> 10まんボルト

モデルはDeep Neural Network (DNN)を用います。DNNは畳み込み、リカレントなど構造の自由度が極めて高いですが、今回はもっとも単純にf(バトルの状態,選択肢iの情報) => 選択肢iの優先度という入出力の全結合feed-forward networkにすることとしました。

バトルの状態はパーティ固有モデルと同様で、以下のようになります*1

自分/相手は、どちら側のパーティの情報を入力として与えるか。両方の場合は次元数が倍となります。

特徴 次元数 自分/相手 説明
有効な行動 4 自分 現在どの行動がとれるのかを表す有効な行動に該当する次元に1を設定
生存ポケモン 1 両方 瀕死でないポケモン数/全ポケモン
ポケモンタイプ 17 相手 場に出ているポケモンポケモンのタイプ(ノーマル・水・…)に該当する次元に1を設定
HP残存率 1 両方 場に出ているポケモンの現在HP/最大HP
状態異常 6 両方 場に出ているポケモンの状態異常(どく・もうどく・まひ・やけど・ねむり・こおりのうち該当次元に1を設定)
ランク補正 6 両方 場に出ているポケモンのランク補正(こうげき・ぼうぎょ・とくこう・とくぼう・すばやさ・命中・回避それぞれ、ランク/12+0.5を設定)
天候 3 - 場の天候(はれ・あめ・すなあらし)に該当する次元に1を設定

この特徴量には自分のパーティの構成や選択肢の内容が入っていません。モデルに与える選択肢を表現するベクトルとして今回はもっとも単純に、ポケモンと技のone-hotベクトルを用いました。以下の表の各行が選択肢1つに対応するベクトルとなります。

自分のポケモンマタドガスの場合:

選択肢 ポケモン=アズマオウ ポケモン=マタドガス 技=10まんボルト 技=ころがる 技=だいもんじ 技=どくどく 技=でんじほう
技:ころがる 0 1 0 1 0 0 0
技:10まんボルト 0 1 1 0 0 0 0
技:でんじほう 0 1 0 0 0 0 1
技:だいもんじ 0 1 0 0 1 0 0

ポケモンは実際には251次元、技も251次元です。このベクトルとバトルの状態ベクトルを連結したもの、合計558次元をDNNへの入力とします。

教師あり学習をするにあたり、一般的な分類問題の定式化に落とし込みます。softmax(f(バトルの状態,選択肢0のベクトル),f(バトルの状態,選択肢1のベクトル),f(バトルの状態,選択肢2のベクトル),f(バトルの状態,選択肢3のベクトル))を各選択肢の選択確率として、正解データとのcross entropy lossを最小化するように学習します。

今回から、深層学習ライブラリはPyTorchに移行しました。

全結合feed-forward networkは次のようなものになります。

import torch.nn as nn
import torch.nn.functional as F


class MLPModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        h = x  # batch, feature_dim
        h = F.relu(self.fc1(h))
        h = self.fc2(h)
        return h

このモデルに選択肢ベクトルを変えて4回呼び出し、結果を連結して損失を計算することも可能ですが、実装上あまり効率が良くありません。 そこで、モデルの呼出しを1回ですべての計算を終えるテクニックがあります。それは、全選択肢に対応するベクトルを積み重ねた行列(558×4)を入力とし、全結合層をカーネルサイズ1の1D Convolutionに置き換えることです。これでも計算結果は等価になります。なぜそうなるかは「1x1 convolution」などで調べてみてください。

import torch.nn as nn
import torch.nn.functional as F


class MLPModel(nn.Module):
    def __init__(self, input_dim, n_layers=2, n_channels=64, bn=False):
        super().__init__()
        layers = []
        bn_layers = []
        cur_hidden_ch = input_dim
        for i in range(n_layers):
            layers.append(nn.Conv1d(cur_hidden_ch, n_channels, 1))  # in,out,ksize
            cur_hidden_ch = n_channels
            if bn:
                bn_layers.append(nn.BatchNorm1d(n_channels))
        self.layers = nn.ModuleList(layers)
        self.bn_layers = nn.ModuleList(bn_layers)
        self.output = nn.Conv1d(cur_hidden_ch, 1, 1)
        self.bn = bn

    def forward(self, x):
        h = x  # batch, feature_dim, 4
        for i in range(len(self.layers)):
            h = self.layers[i](h)
            if self.bn:
                h = self.bn_layers[i](h)
            h = F.relu(h)
        h = self.output(h)
        h = h.view(h.shape[0], -1)  # batch, 4
        return h

このように定義したモデルを学習させます。 パーティ1000個を1パーティ当たり100回他のパーティと対戦させ、10万バトル分の行動を得ました。同じバトルについて2つのパーティから見た状態を別のデータとして扱っています。900パーティ分のデータを学習データ、残り100パーティ分を評価(validation)データとして使用します。1回のバトルで複数のターンがあるので、学習データは34万サンプルとなりました。

上記のモデルのパラメータを変更して正解率を測定しました。学習は10エポック、OptimizerはAdam、lr=0.01、バッチサイズ256としました。層数は出力層以外のConvolutionの数です。チャンネル数は隠れ層の出力チャンネル数です。バッチ正規化は、各Convolutionの後にBatch Normalizationレイヤーを付加するか否かです。

層数 チャンネル数 バッチ正規化 Training正解率[%] Validation正解率[%]
1 16 False 55.2 52.5
1 16 True 67.9 55.6
1 64 False 58.2 54.6
1 64 True 72.5 55.2
1 256 False 63.8 53.8
1 256 True 75.1 53.1
2 16 False 65.6 54.2
2 16 True 68.6 56.4
2 64 False 70.5 55.9
2 64 True 73.2 54.6
2 256 False 70.1 55.4
2 256 True 76.2 53.2
3 16 False 66.2 58.5
3 16 True 67.9 56.0
3 64 False 68.2 57.8
3 64 True 73.6 54.4
3 256 False 68.6 54.8
3 256 True 75.9 53.6

Validation正解率が最大となるのは3層、チャンネル数16、バッチ正規化なしという結果になりました。パラメータ間に極端な差はないですが、層の数は多いほうが良い一方で、チャンネル数が多いとTraining正解率は高くなる一方でValidation正解率は下がってしまい過学習していることがわかります。バッチ正規化についても、過学習を誘発しているようです。今回の学習データ生成には計算時間が1日ほどかかるため、量を大幅に増加させることは難しいです。強化学習に移行するか、過学習しづらいようモデル構造を工夫することが将来課題です。

*1:3vs3バトルと同等のものを用いていますが、1vs1のため生存ポケモン数などは意味がありません。また有効な行動の数は技4つのみなので次元数が減っています