機械学習周りのプログラミング中心。 イベント情報
ポケモンバトルAI本電子書籍通販中

【CodinGameオセロ】棋譜生成と学習のループ

準備が整ったので、AlphaZero方式の強化学習のコアを実装します。

ソースコードはこのバージョンです。

https://github.com/select766/codingame-othello/tree/826d5aa02298d07ca088969bdd8f10c3ca2e8c3f

必要なモジュールは2つです。(1)自己対戦により棋譜を生成します。モデルは固定して探索を行います。(2)棋譜を用いた教師あり学習により、モデルを更新します。

棋譜の生成

棋譜の生成モジュールは、現在のモデルを入力とし、一定数の自己対局を行い、棋譜(局面、指し手、手番側の勝敗)を出力します。

モデルの推論はPythonからTensorflowを呼び出して行う一方、探索はC++で実装しています。暫定的な探索部ではプロセス間通信でこれらをつないでいましたが、1つのプロセス内にまとめることで通常の関数呼び出しを行えるようにし、実装の見通しをよくします。C++で実装したモジュールを、Pythonから呼び出せる形式にビルドします。pybind11を用いることで、C++の関数・クラスをPythonモジュールとしてビルドできます。特に、numpy配列を関数の引数として用いることができるため、Tensorflowとのデータのやり取りが容易になります。

Pythonから呼び出すC++の機能は、以下の関数に集約しました。

  • int init_playout(const string &record_path, int parallel, int playout_limit): 自己対戦機構を初期化する。保存する棋譜ファイルのパスなどを指定する。
  • void proceed_playout(py::array_t<float> batch_board_repr, py::array_t<float> batch_policy_logits, py::array_t<float> batch_value_logit): DNNの評価結果を与えて自己対局を進める。
    • py::array_t<float> batch_policy_logits, py::array_t<float> batch_value_logit: 前回のproceed_playoutの結果として得られた、評価すべき局面をDNNで評価した結果をnumpy配列形式で与える。
    • py::array_t<float> batch_board_repr: numpy配列のプレースホルダで、次にDNNで評価すべき局面をC++側から書きこむ。
  • void end_playout(): 自己対戦機構を終了し、棋譜ファイルを閉じる。
  • int games_completed(): 終了した対局の数を取得する。自己対局を終了するタイミングを測るための機能。

pybind11を用いてPython側に見せる関数を提供するコードは以下のようになります。

https://github.com/select766/codingame-othello/blob/826d5aa02298d07ca088969bdd8f10c3ca2e8c3f/src/lib_pybind11.cpp

Python側は以下のコードになります。

https://github.com/select766/codingame-othello/blob/826d5aa02298d07ca088969bdd8f10c3ca2e8c3f/othello_train/playout_v1.py

Python側では、モデルを読み込んで初期化を終えた後は(1) proceed_playoutを呼び出し、次に評価すべき局面を受け取る、(2)Tensorflowで評価を行う、というループを行うだけです。pybind11で作られたモジュールがothello_train_cppという名前で読めるように(Makefileにより)配置されているので、Pythonで書かれたモジュールと同様にimportできます。

from othello_train import othello_train_cpp

C++側では、proceed_playoutが呼ばれた時に複数の対局を順番に進めます。1つの対局に対応するSinglePlayoutクラスは以下のようになっています。

class PlayoutBuffer
{
public:
    float *board_repr;    // Playoutが評価を求めたい盤面表現をこのアドレスに書き込む
    const float *policy_logits; // 前回評価を求められた局面の評価結果をPlayoutに渡す
    const float *value_logit;   // 前回評価を求められた局面の評価結果をPlayoutに渡す
};

class SinglePlayout
{
    Board board;
    FeatureExtractor extractor;
    vector<MoveRecord> records;
    shared_ptr<SearchMCTSTrain::SearchPartialResultEvalRequest> last_eval_request;
    int _games_completed;

public:
    shared_ptr<ofstream> fout;
    SearchMCTSTrain engine;

    SinglePlayout(shared_ptr<ofstream> fout, SearchMCTSTrain::SearchMCTSConfig mcts_config) : fout(fout), engine(mcts_config), extractor(), _games_completed(0)
    {
        board.set_hirate();
        engine.board.set(board);
    }

    int games_completed() const
    {
        return _games_completed;
    }

    void proceed(PlayoutBuffer &playout_buffer)
    {
        SearchMCTSTrain::EvalResult eval_result;
        if (playout_buffer.policy_logits)
        {
            memcpy(eval_result.policy_logits, playout_buffer.policy_logits, sizeof(eval_result.policy_logits));
        }
        if (playout_buffer.value_logit)
        {
            memcpy(&eval_result.value_logit, playout_buffer.value_logit, sizeof(eval_result.value_logit));
        }
        eval_result.request = last_eval_request;
        while (true)
        {
            auto search_partial_result = engine.search_partial(&eval_result);
            auto result_move = dynamic_pointer_cast<SearchMCTSTrain::SearchPartialResultMove>(search_partial_result);
            if (result_move)
            {
                // 指し手を進める
                proceed_game(result_move->move);
            }
            auto result_eval = dynamic_pointer_cast<SearchMCTSTrain::SearchPartialResultEvalRequest>(search_partial_result);
            if (result_eval)
            {
                // 評価が必要
                DNNInputFeature feat = extractor.extract(result_eval->board);
                memcpy(playout_buffer.board_repr, feat.board_repr, sizeof(feat.board_repr));
                last_eval_request = result_eval;
                return;
            }
        }
    }

private:
    void do_move_with_record(Move move)
    {
        // boardを進めるとともに指し手を記録
        BoardPlane lm;
        board.legal_moves_bb(lm);
        auto n_legal_moves = __builtin_popcountll(lm);
        MoveRecord record;
        record.move = static_cast<decltype(record.move)>(move);
        record.planes[0] = board.plane(0);
        record.planes[1] = board.plane(1);
        record.turn = static_cast<decltype(record.turn)>(board.turn());
        record.n_legal_moves = static_cast<decltype(record.turn)>(n_legal_moves);
        memset(record.pad, 0, sizeof(record.pad));

        records.push_back(record);

        UndoInfo undo_info;
        board.do_move(move, undo_info);
    }

    void flush_record_with_game_result()
    {
        // gameoverの時に呼び出す。recordsにゲームの結果を書きこんだうえでファイルに出力する。
        
        int8_t stone_diff_black = static_cast<int8_t>(board.piece_num(BLACK) - board.piece_num(WHITE));
        for (auto &record : records)
        {
            record.game_result = record.turn == BLACK ? stone_diff_black : -stone_diff_black;
        }
        
        fout->write((char*)&records[0], records.size() * sizeof(MoveRecord));
        records.clear();
        _games_completed++;
    }

    void proceed_game(Move move)
    {
        // 指定された指し手でゲームを進め、次に指し手選択が必要な状態まで進行する。最新の局面をengineにセットする。
        do_move_with_record(move);

        while (true)
        {
            if (board.is_gameover())
            {
                flush_record_with_game_result();
                board.set_hirate();
                engine.newgame();
                engine.board.set(board);
            }
            
            vector<Move> move_list;
            board.legal_moves(move_list);
            if (move_list.empty())
            {
                do_move_with_record(MOVE_PASS);
            }
            else if (move_list.size() == 1)
            {
                do_move_with_record(move_list[0]);
            }
            else
            {
                break;
            }
        }

        engine.board.set(board);
        return;
    }
};

行数は多いですが、やっていることは単純です。

  • DNN評価結果を受け取り、MCTSのゲーム木を更新する
  • 探索回数が一定値に達した場合
    • 指し手を決定し、局面を一手進める
    • 局面を進めた結果、終局した場合
      • 棋譜をファイルに書き込む
      • 新しい対局を開始する
  • ゲーム木の探索を再開/新しい局面で開始し、評価すべき局面を返す

これらの分岐を処理することで、インターフェースとしては前回返した局面に対する評価結果を受け取り、次に評価すべき局面を返すという形式にまとまります。さらに、ParallelPlayoutクラスではSinglePlayoutオブジェクトをバッチサイズ(例えば256)個生成し、順番に呼び出して結果をまとめることで、Tensorflow側でミニバッチ処理できるnumpy配列を得る実装になっています。

このコードに表れていない工夫点として、局面をキーとし、DNNの評価結果をキャッシュしています。序盤は別々の対局で同じ局面が出現しますし、1つの対局の中で、ある局面に対する手を決める過程で評価した局面の一部が、次の局面の時にも再度出現します。この機構により、DNNの実行回数が20%程度低減しました。

モデルの更新

モデルの更新モジュールは、現時点のモデルとそれを用いて生成した棋譜を入力とし、教師あり学習を行って、更新されたモデルを出力します。

https://github.com/select766/codingame-othello/blob/826d5aa02298d07ca088969bdd8f10c3ca2e8c3f/othello_train/rl_train_v1.py

処理は通常の教師あり学習と何ら変わらないため、特別な実装はありません。

ループでつなげる

最後に、棋譜の生成とモデルの更新を交互に行うためのループを実装します。各プロセスを、引数を変えながら起動することで実現します。

https://github.com/select766/codingame-othello/blob/826d5aa02298d07ca088969bdd8f10c3ca2e8c3f/othello_train/rl_loop.py

import argparse
import subprocess
from pathlib import Path
from typing import Optional


def check_call(args, skip_if_exists: Optional[Path] = None):
    if skip_if_exists is not None and skip_if_exists.exists():
        print("#skip: " + " ".join(args))
        return
    print(" ".join(args))
    subprocess.check_call(args)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("work_dir")
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--games", type=int, default=10000)
    args = parser.parse_args()

    work_dir = Path(args.work_dir)
    records_dir = work_dir / "records"
    records_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(args.epoch):
        if epoch == 0:
            check_call(["python", "-m", "othello_train.make_empty_model_v1",
                       f"{work_dir}/cp_{epoch}/cp"], work_dir / f"cp_{epoch}")
            check_call(["python", "-m", "othello_train.checkpoint_to_savedmodel_v1",
                        f"{work_dir}/cp_{epoch}/cp", f"{work_dir}/sm_{epoch}"], work_dir / f"sm_{epoch}")
        check_call(["python", "-m", "othello_train.playout_v1",
                   f"{work_dir}/sm_{epoch}", f"{records_dir}/records_{epoch}.bin", "--games", f"{args.games}"], records_dir / f"records_{epoch}.bin")
        check_call(["python", "-m", "othello_train.rl_train_v1", f"{work_dir}/cp_{epoch}/cp", f"{work_dir}/cp_{epoch+1}/cp",
                   f"{records_dir}/records_{epoch}.bin"], work_dir / f"cp_{epoch+1}")
        check_call(["python", "-m", "othello_train.checkpoint_to_savedmodel_v1",
                   f"{work_dir}/cp_{epoch+1}/cp", f"{work_dir}/sm_{epoch+1}"], work_dir / f"sm_{epoch+1}")


if __name__ == "__main__":
    main()

checkpoint_to_savedmodel_v1は、学習に用いるモデル形式であるcheckpoint形式から、推論に便利な(モデル構造を内包した)savedmodel形式に変換するコードです。学習は1日以上かかるため、途中で中断・再開できるように工夫しています。check_call関数にskip_if_existsという引数があり、ここに指定したパスのファイルが既に存在する(そのコマンドの成果物が既に存在している)場合はそのコマンドの実行をスキップするという処理になっています。ただし、棋譜生成は中断すると、その時点までの中途半端な棋譜ファイルが残ってしまうため、再開前に手動で削除する必要があります。

このスクリプトを実行することで、AlphaZero方式の強化学習が実現できます。考え方は単純なものの、実装はかなり長くなりました。ミニバッチでの処理が可能となるよう、探索を並行処理可能にする機構が特に長くなりました。

このスクリプトを用いてモデルを強化学習しました。1手の思考に64局面の評価を行い、1epochあたりの自己対局数は10000としました。

各epochのモデルを(棋譜生成ではなく本番用に近い)対局エンジンに読み込ませ、ランダムプレイヤーと強さを比較しました。MCTSの探索ノード数は16です。

epoch 勝ち 引き分け 負け
0 58 5 37
1 83 1 16
2 99 0 1
3 99 0 1

epoch=0は、ランダムに初期化したモデルです。これがランダムより良いのは、終盤ではゲーム木の末端に勝敗が決定した局面が出現するため、評価関数の内容によらず勝てる手を指せるためであると考えられます。epoch=1は、ランダムなモデルで指した手と勝敗を学習した状態です。学習した手の品質は非常に低いと思われますが、勝敗については局面の良さと相関があると考えられます。そのため評価関数として成立し、強くなったと考えられます。epoch=2以降は意味のある学習データが利用できており、さらに強くなりました。これ以上はランダムプレイヤーを相手に測定することが難しいというところまで、強化学習により強くすることに成功しました。

ここまでの実装で、AlphaZero方式の強化学習が実現できました。しかし、モデルの評価でTensorflowを利用しているためCodinGameに投稿することができません。今後はCodinGameへ投稿できるエンジンの実装に入っていきます。