機械学習周りのプログラミング中心。 イベント情報
ポケモンバトルAI本電子書籍通販中

【CodinGameオセロ】並列対局を想定したMCTSの実装

AlphaZero式の強化学習において、学習データを作成するために必要なMCTSの実装を行います。

現時点のコードのバージョンはこちらです。

https://github.com/select766/codingame-othello/tree/826d5aa02298d07ca088969bdd8f10c3ca2e8c3f

AlphaZeroの学習サイクル

AlphaZeroでは、以下のステップを繰り返すことでDNNモデルを強化学習します。

  1. AI同士の自己対局を行う。その際の手は、DNNモデルを評価関数としたMCTSにより決める。
  2. 自己対局で選んだ手およびその対局の勝敗を教師データとして、DNNモデルを教師あり学習する。

MCTSによる探索を行うことで、DNNモデルから出力されるpolicyよりも良い手が打てます。その手を直接policyの出力として学習させることにより、DNNがより強い手を選べるようになります。

MCTSの実装

今回は、自己対局で利用するMCTSを実装します。

MCTSの手続きは以下の通りです。MCTSでは局面をノード、手をエッジとする木構造を用います。

  1. ルート局面(打つ手を決めたい局面)に対応するルートノードを作成する。
  2. 以下に示す探索を一定回数または一定時間行う。
    1. ルートノードから所定の基準で子ノードを再帰的に選択し、葉ノードに到達する。
    2. 葉ノードの局面をDNNで評価する。
    3. 評価結果を、探索経路の各エッジに反映させる(バックアップ操作)。
  3. 手を決定する。
    1. ルートノードから子ノードにつながるエッジのうち、もっとも訪問回数が多いものを選択する。

自己対局では、現在の局面でMCTSにより手を決定し、その手で局面を一手進めます。その局面で再びMCTSにより決定し手を進めるということを終局まで行い、棋譜を出力します。

シンプルなMCTSの実装

まずは、手順がわかりやすいシンプルなMCTS実装を示します。上記のステップを素直に実装したものとなります(実装の時系列としては並行動作対応より後でした)。

https://github.com/select766/codingame-othello/blob/fa283be034455fddee3e97d13df89542e937bba9/src/search_mcts.hpp

この実装では、ゲーム木を再帰的にたどって到達した葉ノードにおいてDNN評価を dnn_evaluator->evaluate(b); のように呼び出しています。シングルコアCPUしか使用できない本番の対局ではここでDNNを実行すればよいのですが、学習時はGPUを用いて評価を行って高速化したいです。GPUでは1局面だけ評価するのは非効率なため、多数の局面をミニバッチにまとめて評価する仕組みが必要です。そのためには、バッチサイズNに対応したN対局を並行して進め、評価すべき局面をまとめます。対局を並行して行う実装は2通り考えられます。1つ目は、マルチスレッドを用いる方法です。対局ごとのスレッドと、DNNを評価するスレッドを用意します。対局スレッドは、DNNを実行する箇所に到達したらDNNスレッドに局面を送信し(全スレッドで共有されたメモリ・キューなどを利用)、評価が完了するまで待機します。DNNスレッドは、すべての対局スレッドから局面を受信したらDNNの評価を行い、結果を対局スレッドに返送します。2つ目は、シングルスレッドで全ての対局を順番に実行する方法です。この方法では、探索のコードを書き替えて、search関数の内部で評価を呼び出すのではなく、search関数の戻り値として評価すべき局面を返します。各対局のsearch関数を順番に呼び出して評価すべき局面を収集し、DNNの評価を行います。評価結果は、search関数の引数として探索部に返します。方法1のメリットは、マルチスレッドに起因する実装ミスがなければ見通しの良いコードになる点です。デメリットは、スレッドの切り替えが高頻度に発生する点です。毎秒数万回の切り替えが発生し、これがオーバーヘッドになる可能性があります。方法2のメリットは、実装ミスが起きやすいマルチスレッドの実装をしなくてよい点です。デメリットは、評価すべき局面に到達したら一度探索を中断して局面を戻り値として返し、評価結果を受け取ったら探索を再開するというコードの構造の書き換えが必要になる点です。今回は方法2を実装することとしました。

並行動作に対応したMCTSの実装

並行動作に対応したMCTSを実装します。外部から呼び出されるのはsearch_partial関数で、局面の評価が必要か、指し手が決定するまで探索を行います。この関数のシグネチャshared_ptr<SearchPartialResult> search_partial(const EvalResult *eval_result) となっているのがポイントです。SearchPartialResultは、探索完了(指し手決定時)か、局面の評価が必要になったときに返すデータで、unionのような構造になっています。呼び出し側は、dynamic_castでどちらの型が返ったかを判定します。SearchPartialResultMoveであればその指し手で対局を進めて再度search_partialを呼び出します。SearchPartialResultEvalRequestであれば局面を評価します。

class SearchPartialResult
{
public:
    virtual ~SearchPartialResult() = default;
};

class SearchPartialResultMove : public SearchPartialResult
{
public:
    Move move;
    float score;
};

class SearchPartialResultEvalRequest : public SearchPartialResult
{
public:
    Board board;                             // 評価対象局面
    vector<pair<TreeNode *, int>> tree_path; // 探索木の経路。TreeNodeと、その中のエッジのインデックスのペア。leafがルートの場合は空。
    TreeNode *leaf;                          // 末端ノードのTreeNode。
};

search_partial関数の中は以下のようになっています。メンバ変数next_taskに、次に実行すべき動作が入っており、これにより内部の関数を呼び分けます。各関数は、next_taskを書き換えたのち戻り値としてshared_ptr<SearchPartialResult>nullptrを返すようになっています。例えば、start_search関数は、探索に関するメンバ変数を初期化し、ルートノードを作成し、next_task = NextTask::ASSIGN_ROOT_EVALを実行し、ルートノードを評価するためのSearchPartialResultEvalRequestを返します。するとsearch_partial関数ではresult != nullptrとなるためループを抜け、呼び出し元にルートノードを評価するようリクエストが返ることになります。呼び出し元は局面を評価し、eval_resultに評価結果を代入した状態でsearch_partialを再度呼び出します。するとnext_task == NextTask::ASSIGN_ROOT_EVALとなっているため、次はルートノードの評価結果をルートノードに代入する処理が実行されます。

少々複雑な仕組みですが、単一スレッドで複数の処理を並行動作させる「コルーチン」の実装は似たようなものになるようです。C++20ではコルーチンをサポートする構文が導入されるため、これを利用すれば実装を簡略化できるかもしれません。

shared_ptr<SearchPartialResult> search_partial(const EvalResult *eval_result)
{
    shared_ptr<SearchPartialResult> result;
    do
    {
        switch (next_task)
        {
        case NextTask::START_SEARCH:
            result = start_search();
            break;
        case NextTask::ASSIGN_ROOT_EVAL:
            result = assign_root_eval(eval_result);
            break;
        case NextTask::SEARCH_TREE:
            result = search_tree();
            break;
        case NextTask::ASSIGN_LEAF_EVAL:
            result = assign_leaf_eval(eval_result);
            break;
        case NextTask::CHOOSE_MOVE:
            result = choose_move();
            break;
        }
    } while (!result);

    // 評価すべき局面として何を返したのか覚えておく
    prev_request = dynamic_pointer_cast<SearchPartialResultEvalRequest>(result);

    return result;
}

並行動作に関する部分だけ抽出しましたが、それでも長いです。

class SearchMCTSTrain : public SearchBase
{
public:
    // 探索完了(指し手決定時)か、局面の評価が必要になったときに返すデータ
    class SearchPartialResult
    {
    public:
        virtual ~SearchPartialResult() = default;
    };

    class SearchPartialResultMove : public SearchPartialResult
    {
    public:
        Move move;
        float score;
    };

    class SearchPartialResultEvalRequest : public SearchPartialResult
    {
    public:
        Board board;                             // 評価対象局面
        vector<pair<TreeNode *, int>> tree_path; // 探索木の経路。TreeNodeと、その中のエッジのインデックスのペア。leafがルートの場合は空。
        TreeNode *leaf;                          // 末端ノードのTreeNode。
    };

    class EvalResult
    {
    public:
        float value_logit;
        float policy_logits[BOARD_AREA];
    };

private:
    // 探索の次のタスクを表す
    enum NextTask
    {
        START_SEARCH,
        ASSIGN_ROOT_EVAL,
        SEARCH_TREE,
        ASSIGN_LEAF_EVAL,
        CHOOSE_MOVE,
    };
    NextTask next_task;
    int playout_count;

    TreeNode *root_node;

    shared_ptr<SearchPartialResultEvalRequest> prev_request;

public:
    void newgame()
    {
        tree_table->clear();
        root_node = nullptr;
        next_task = START_SEARCH;
    }

    // 局面の評価が必要か、指し手が決定するまで探索する
    shared_ptr<SearchPartialResult> search_partial(const EvalResult *eval_result)
    {
        shared_ptr<SearchPartialResult> result;
        do
        {
            switch (next_task)
            {
            case NextTask::START_SEARCH:
                result = start_search();
                break;
            case NextTask::ASSIGN_ROOT_EVAL:
                result = assign_root_eval(eval_result);
                break;
            case NextTask::SEARCH_TREE:
                result = search_tree();
                break;
            case NextTask::ASSIGN_LEAF_EVAL:
                result = assign_leaf_eval(eval_result);
                break;
            case NextTask::CHOOSE_MOVE:
                result = choose_move();
                break;
            }
        } while (!result);

        // 評価すべき局面として何を返したのか覚えておく
        prev_request = dynamic_pointer_cast<SearchPartialResultEvalRequest>(result);

        return result;
    }

private:
    shared_ptr<SearchPartialResult> start_search()
    {
        // 探索木の再利用をしないので、テーブルを初期化する。これによりテーブルのサイズは1手当たりのプレイアウト数+αだけで済む。
        tree_table->clear();
        playout_count = 0;
        return make_root(board);
    }

    shared_ptr<SearchPartialResult> make_root(const Board &b)
    {
        bool mate_found;
        // ルートノードを生成
        root_node = MCTSBase::make_node(b, tree_table.get(), config.mate_1ply, mate_found, mate_move);
        SearchPartialResultEvalRequest *req = new SearchPartialResultEvalRequest();
        req->board.set(b);
        req->leaf = root_node;
        next_task = NextTask::ASSIGN_ROOT_EVAL;
        // ルートノードを評価するようリクエストを返す
        return shared_ptr<SearchPartialResult>(req);
    }

    shared_ptr<SearchPartialResult> assign_root_eval(const EvalResult *eval_result)
    {
        assert(eval_result);
        assert(prev_request);
        // 前回のsearch_partialの戻り値で返したルートノードの評価結果がeval_resultに入っているので、それをルートノードに代入する。
        auto leaf = prev_request->leaf;
        assign_eval_result_to_leaf(leaf, eval_result);

        next_task = NextTask::SEARCH_TREE;
        prev_request = nullptr;
        return nullptr;
    }

    shared_ptr<SearchPartialResult> assign_leaf_eval(const EvalResult *eval_result)
    {
        // 前回のsearch_partialの戻り値で返した葉ノードの評価結果がeval_resultに入っているので、それを葉ノードに代入する。
        assert(eval_result);
        assert(prev_request);
        auto leaf = prev_request->leaf;
        assign_eval_result_to_leaf(leaf, eval_result);
        backup_path(prev_request->tree_path, leaf->score);
        prev_request = nullptr;
        next_task = NextTask::SEARCH_TREE;
        return nullptr;
    }

    float assign_eval_result_to_leaf(TreeNode *leaf, const EvalResult *eval_result)
    {
        // valueは、logitのためtanhで勝=1,負=-1に変換する
        leaf->score = tanh(eval_result->value_logit);
        assert(leaf->n_legal_moves);
        // policyは、logitが入っているためsoftmax計算が必要
        // 省略: leaf->value_pに各指し手のpolicyを代入

        return leaf->score;
    }

    shared_ptr<SearchPartialResult> search_tree()
    {
        if (playout_count >= config.playout_limit)
        {
            // playoutは終わり。指し手を決定する。
            next_task = NextTask::CHOOSE_MOVE;
            return nullptr;
        }

        playout_count++;
        auto result = search_root();
        if (result)
        {
            next_task = NextTask::ASSIGN_LEAF_EVAL;
        }
        else
        {
            next_task = NextTask::SEARCH_TREE;
        }
        return result;
    }

    shared_ptr<SearchPartialResult> search_root()
    {
        vector<pair<TreeNode *, int>> path;
        return search_recursive(board, root_node, path);
    }

    shared_ptr<SearchPartialResult> search_recursive(Board &b, TreeNode *node, vector<pair<TreeNode *, int>> &path)
    {
        if (node->terminal())
        {
            backup_path(path, node->score);
            return nullptr;
        }

        int edge = MCTSBase::select_edge(node, config.c_puct);
        UndoInfo undo_info;
        b.do_move(static_cast<Move>(node->move_list[edge]), undo_info);
        path.push_back({node, edge});
        int child_node_idx = node->children[edge];
        node->value_n[edge]++;
        shared_ptr<SearchPartialResult> result;
        if (child_node_idx)
        {
            result = search_recursive(b, tree_table->at(child_node_idx), path);
        }
        else
        {
            // 子ノードがまだ生成されていない
            bool mate_found;
            Move mate_move;
            TreeNode *child_node = MCTSBase::make_node(b, tree_table.get(), config.mate_1ply, mate_found, mate_move);
            node->children[edge] = tree_table->get_index(child_node);
            if (!child_node->terminal())
            {
                // この場でバックアップできず、局面評価が必要
                SearchPartialResultEvalRequest *req = new SearchPartialResultEvalRequest();
                req->board.set(b);
                req->leaf = child_node;
                req->tree_path = path;
                result = shared_ptr<SearchPartialResult>(req);
            }
            else
            {
                // 終局または詰みが見つかった場合
                backup_path(path, child_node->score);
            }
        }

        b.undo_move(undo_info);
        return result;
    }

    void backup_path(const vector<pair<TreeNode *, int>> &path, float leaf_score)
    {
        float score = leaf_score;
        for (int i = int(path.size()) - 1; i >= 0; i--)
        {
            score = -score;
            path[i].first->value_w[path[i].second] += score;
        }
    }

    shared_ptr<SearchPartialResult> choose_move()
    {
        Move move = MOVE_PASS;
        float score = 0.0F;
        // 省略: 訪問回数に比例した確率で指し手を選択する
        next_task = NextTask::START_SEARCH;
        auto result = new SearchPartialResultMove();
        result->move = move;
        result->score = score;
        return shared_ptr<SearchPartialResult>(result);
    }
};

MCTSを用いた対局実験

ここでは、実装したMCTSと既存の教師あり学習による評価関数を用いて対局した結果を示します。対局相手はランダムプレイヤーです。

MCTS探索ノード数 勝ち 引き分け 負け
1 78 1 21
2 82 2 16
4 79 3 18
8 91 2 7
16 93 2 5
32 95 0 5
64 98 0 2

探索ノード数を増やすと徐々に強くなることが確認できました。AlphaZero式の学習を行うにはさらにpython側での実装が必要なため、次回解説します。