"""
第 15 章配套代码：MCTS + PUCT 最小实现示意（AlphaZero 风格）。
用一个最简单的确定性博弈（"抢到 21"：两人轮流报 1~3，谁报到 21 谁赢）演示
AlphaZero 的搜索内核——PUCT 选择 + 用一个（这里用先验均匀的）策略/价值"网络"
引导的蒙特卡洛树搜索。重点是 PUCT 公式与"选择→扩展→评估→回传"四步。
Runnable with: numpy only.  python3 15_mcts_puct.py
"""
import numpy as np

TARGET = 21
MOVES = [1, 2, 3]


def legal(total):
    return [m for m in MOVES if total + m <= TARGET]


class Node:
    def __init__(self, total, to_move):
        self.total = total          # 当前累计点数
        self.to_move = to_move      # 该谁走（+1 / -1）
        self.children = {}          # move -> Node
        self.N = {m: 0 for m in legal(total)}   # 访问次数
        self.W = {m: 0.0 for m in legal(total)} # 累计价值
        self.P = {}                 # 先验概率（策略网络输出，这里均匀）
        ms = legal(total)
        for m in ms:
            self.P[m] = 1.0 / len(ms) if ms else 0.0

    def is_terminal(self):
        return self.total == TARGET or not legal(self.total)


def puct_select(node, c=1.5):
    """PUCT: a* = argmax  Q(s,a) + c * P(s,a) * sqrt(ΣN) / (1+N(s,a))。"""
    total_N = sum(node.N.values())
    best, best_score = None, -1e9
    for m in node.N:
        q = node.W[m] / node.N[m] if node.N[m] > 0 else 0.0
        u = c * node.P[m] * np.sqrt(total_N + 1) / (1 + node.N[m])
        score = q + u
        if score > best_score:
            best_score, best = score, m
    return best


_solve_cache = {}


def solve(total):
    """完美评估器（代替价值网络）：返回'当前待走方'在最优对弈下的价值 ±1。
    这扮演 AlphaZero 中 value 网络的角色——给搜索一个准确的叶子评估。"""
    if total in _solve_cache:
        return _solve_cache[total]
    ms = legal(total)
    if not ms:                       # 走不了 => 对手刚报到 21 => 当前方输
        _solve_cache[total] = -1.0
        return -1.0
    best = -1.0
    for m in ms:
        if total + m == TARGET:      # 当前方报到 21，直接赢
            best = 1.0
            break
        v = -solve(total + m)        # 轮到对手，对手价值取负
        best = max(best, v)
    _solve_cache[total] = best
    return best


def rollout_value(total, to_move, rng):
    """叶子评估：调用完美评估器（demo 里把 value 网络理想化为精确解）。"""
    return solve(total)


def mcts(root_total, root_mover, sims=2000, seed=0):
    rng = np.random.default_rng(seed)
    root = Node(root_total, root_mover)
    for _ in range(sims):
        node = root
        path = []
        # 1) 选择：沿 PUCT 下行到未扩展或终止
        while not node.is_terminal() and all(node.N[m] > 0 for m in node.N):
            m = puct_select(node)
            path.append((node, m))
            node = node.children[m]
        # 2) 扩展 + 3) 评估
        if not node.is_terminal():
            m = puct_select(node)
            path.append((node, m))
            child = Node(node.total + m, -node.to_move)
            node.children[m] = child
            value = -rollout_value(child.total, child.to_move, rng)  # 站在父节点视角
        else:
            value = 0.0
        # 4) 回传（沿路径交替取负，因为零和博弈视角每层翻转）
        for parent, mv in reversed(path):
            parent.N[mv] += 1
            parent.W[mv] += value
            value = -value
    return root


def report():
    print("=== MCTS + PUCT 最小示意（'抢21'博弈，先手必胜）===")
    print("规则: 轮流报 1~3，累加，谁报到 21 谁赢。")
    print("已知最优: 抢到 17/13/9/5/1 的人必胜（先手报1即锁胜）。\n")
    # 从空局（先手待走）跑 MCTS，看它是否找到 '先报 1'
    root = mcts(0, +1, sims=4000)
    print("根节点（先手第一步）各动作的访问次数 N 与均值价值 Q:")
    for m in sorted(root.N):
        q = root.W[m] / root.N[m] if root.N[m] else 0.0
        print(f"  报 {m}: N={root.N[m]:5d}  Q={q:+.3f}")
    best = max(root.N, key=lambda m: root.N[m])
    print(f"\nMCTS 选择的第一步: 报 {best}  (最优解=报1，使对手面对 20 必败)")
    print("机制: PUCT = Q + c·P·√ΣN/(1+N)，前期靠先验P探索、后期靠Q利用；")
    print("AlphaZero 把这里的随机 rollout 换成 value 网络，把均匀 P 换成 policy 网络。")


if __name__ == "__main__":
    report()
