機械学習周りのプログラミング中心。 イベント情報
ポケモンバトルAI本電子書籍通販中

【CodinGameオセロ】DNNの方策だけで打つAIの実装

前回、DNNを教師あり学習しました。今回は学習したモデルをAIの探索に組み込みます。まずは、局面の先読みをせずに方策(policy)の出力をそのまま打つAIを実装します。

今回の時点のコードです(前回記事と同じ)

https://github.com/select766/codingame-othello/tree/9cf25088aa2756f68bd4b3a096978d7842a0e0ce

プロセス間通信を用いた盤面評価

DNNの学習はPythonとTensorflowの組み合わせにより行いました。学習したモデルはTensorflow固有のSavedmodel形式で保存されています。これをC++で実装する探索部から呼び出すことが目的となります。TensorflowにはC++APIがあり、これを用いることが考えられますが、大きなランタイムライブラリが必要となるためCodinGameで動作させることは難しいと考えられます。まずはローカルで簡単に動作させることを優先し、新たに環境構築が必要となるC++APIは用いないことにしました。代わりに、TensorflowのモデルをPythonのコードでロードし、Pythonの動作しているプロセスと、探索部を持つC++のプロセスをプロセス間通信で連携させてDNNの実行を行うことにしました。

プロセス間通信の方式はTCPソケットです。Pythonプロセスがサーバとなり、事前にDNNをロードした状態で接続を待ちます。C++プロセスが起動するとクライアントとして接続します。C++プロセスが評価すべき盤面をPythonプロセスに送信し、PythonプロセスはTensorflowを用いてDNNを実行し、その結果をC++プロセスに送信します。

サーバのコードの主要部分を掲載します。エラー処理は省略していますので実際のコードを参照してください。Savedmodel形式にはモデルの構造が記録されているため、モデル定義クラスのロードは必要ありません。

import argparse
import socket
import numpy as np
import tensorflow as tf

def request_loop(model, sock):
    try:
        while True:
            # 評価対象盤面をTCPソケットから受信する
            board_array = read_input_array(sock)
            # DNNで評価
            predicted = model(board_array)
            # 結果をバイト列に変換してTCPソケットに送信
            policy_data = predicted[0].numpy()
            value_data = predicted[1].numpy()
            send_data = policy_data.tobytes() + value_data.tobytes()
            assert len(send_data) == OUTPUT_BYTE_LENGTH
            sock.sendall(send_data)
    except DisconnectedError: # 切断された(独自の例外クラス)
        pass


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("savedmodel_dir")
    # バッチサイズ1のため、CPU (/CPU:0)のほうがGPU (/GPU:0)より速いと予想される
    parser.add_argument("--device", default="/CPU:0")
    parser.add_argument("--port", type=int, default=8099)
    parser.add_argument("--host", default="")
    args = parser.parse_args()
    with tf.device(args.device):
        # モデルをロード
        model = tf.keras.models.load_model(args.savedmodel_dir)
        # TCPサーバを起動
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((args.host, args.port))
        sock.listen(1)
        print("listening on port", args.port)
        while True:
            # クライアントの接続を待ち受ける
            conn, addr = sock.accept()
            print("connected by", addr)
            request_loop(model, conn)
            print("disconnected from", addr)

C++側ではクライアントを実装します。近い将来、C++プロセス内でDNNの評価を行う機構に切り替えることを想定していますのでインターフェースクラスDNNEvaluatorTCPソケットを用いた実装を提供するクラスDNNEvaluatorSocketに分離しています。C++側でビットボードから入力テンソルに変換する処理を実装しています。

class DNNEvaluatorResult
{
public:
    float policy_logits[BOARD_AREA];
    float value_logit;
};

class DNNEvaluatorRequest
{
public:
    float board_repr[BOARD_AREA * 3];
};

class DNNEvaluator
{
public:
    virtual DNNEvaluatorResult evaluate(const Board &board) = 0;

protected:
    // 入力テンソルの生成
    DNNEvaluatorRequest make_request(const Board &board)
    {
        DNNEvaluatorRequest req;
        memset(&req, 0, sizeof(req));
        // board_repr: [pos(y,x),3] (NHWC)
        const int n_ch = 3;
        for (int i = 0; i < N_PLAYER; i++)
        {
            int turn = i == 0 ? board.turn() : 1 - board.turn();
            BoardPlane bb = board.plane(turn);
            for (int pos = 0; pos < BOARD_AREA; pos++)
            {
                if (bb & (1ULL << pos)) // 当時はバグがあった
                {
                    req.board_repr[pos * n_ch + i] = 1.0F;
                }
            }
        }
        // fill 1
        for (int pos = 0; pos < BOARD_AREA; pos++)
        {
            req.board_repr[pos * n_ch + 2] = 1.0F;
        }

        return req;
    }
};

class DNNEvaluatorSocket : public DNNEvaluator
{
    int sock;

public:
    DNNEvaluatorSocket(const char *ip_addr, int port)
    {
        struct sockaddr_in addr;
        memset(&addr, 0, sizeof(addr));

        addr.sin_family = AF_INET;
        addr.sin_port = htons((unsigned short)port);
        addr.sin_addr.s_addr = inet_addr(ip_addr);

        sock = socket(AF_INET, SOCK_STREAM, 0);
    }

    ~DNNEvaluatorSocket()
    {
        if (sock >= 0)
        {
            close(sock);
            sock = -1;
        }
    }

    DNNEvaluatorResult evaluate(const Board &board)
    {
        DNNEvaluatorRequest req = make_request(board);
        auto send_size = send(sock, &req, sizeof(req), 0);

        DNNEvaluatorResult res;
        unsigned char *p = reinterpret_cast<unsigned char *>(&res);
        ssize_t remain_size = sizeof(res);
        while (remain_size > 0)
        {
            ssize_t recv_size = recv(sock, p, remain_size, 0);
            remain_size -= recv_size;
            p += recv_size;
        }

        return res;
    }
};

DNNEvaluatorは、以下のようにサーバのIPアドレスとポート番号を指定してインスタンス化します。

shared_ptr<DNNEvaluator> evaluator(new DNNEvaluatorSocket("127.0.0.1", 8099));

以下の行は、実装当時バグがありました。(1 << pos)と実装していましたが、32bit整数として処理されており64bitのビットボードを表現できていませんでした。なお、次節の実験結果では修正後の値を示しています。

if (bb & (1ULL << pos)) // 当時はバグがあった

Policy AIの実装

DNNEvaluatorを用いた探索部を実装します。今回はもっとも単純に、合法手のうちDNNのpolicyの出力が最大となる手を選択するPolicy AIを実装しました。ここまでの準備ができていれば実装は簡単です。

class SearchPolicy : public SearchBase
{
    shared_ptr<DNNEvaluator> dnn_evaluator;

public:
    SearchPolicy(shared_ptr<DNNEvaluator> dnn_evaluator) : dnn_evaluator(dnn_evaluator) {}
    Move search(string &msg)
    {
        vector<Move> move_list;
        board.legal_moves(move_list);
        // パス時の処理省略
            auto eval_result = dnn_evaluator->evaluate(board);
        // 合法手の中で最もpolicyの出力が大きかったものを選ぶ
            Move bestmove = 0;
            float bestlogit = -1000.0F;
            for (auto move : move_list)
            {
                // eval_result.policy_logitsには64マス分のpolicyの値が含まれる
                float policy_logit = eval_result.policy_logits[move];
                if (bestlogit < policy_logit)
                {
                    bestlogit = policy_logit;
                    bestmove = move;
                }
            }

            return bestmove;
        }
    }
};

ランダムAIとPolicy AIを自己対戦させたところ、Policy AIの勝率が78%となりました。値は決して高くありませんが、学習データの生成、学習、探索部での利用の一連の流れが実現できたことを示しています。現在のところモデルの評価にTensorflowが必要なため、CodinGame上では動作させることができません。しばらくの間はローカルでの動作確認のみで実装を進めることになります。次回は、MCTSによる探索部を実装します。