"""
第 12 章配套代码：表格型 Q-learning（gridworld）。
演示 Q-learning 更新 Q(s,a) <- Q + alpha*[r + gamma*max_a' Q(s',a') - Q]，
以及它如何在一个有墙、有陷阱、有目标的小网格里收敛出最优策略与价值。
Runnable with: numpy only.  python3 12_q_learning_gridworld.py
"""
import numpy as np

# 4x4 gridworld:
#  S . . .
#  . # . T   (# 墙不可进, T 陷阱 -1, G 目标 +1)
#  . # . .
#  . . . G
ROWS, COLS = 4, 4
START = (0, 0)
GOAL = (3, 3)
TRAP = (1, 3)
WALLS = {(1, 1), (2, 1)}
ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # 上 下 左 右
ARROW = {0: "↑", 1: "↓", 2: "←", 3: "→"}


def step(s, a):
    """环境转移：返回 (下一状态, 奖励, 是否终止)。"""
    if s == GOAL:
        return s, 0.0, True
    if s == TRAP:
        return s, 0.0, True
    dr, dc = ACTIONS[a]
    nr, nc = s[0] + dr, s[1] + dc
    ns = (nr, nc)
    # 撞墙或越界则原地不动
    if nr < 0 or nr >= ROWS or nc < 0 or nc >= COLS or ns in WALLS:
        ns = s
    if ns == GOAL:
        return ns, 1.0, True
    if ns == TRAP:
        return ns, -1.0, True
    return ns, -0.04, False  # 每步小负奖励，鼓励走捷径


def q_learning(episodes=4000, alpha=0.1, gamma=0.95, eps=0.1, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((ROWS, COLS, 4))
    returns = []
    for ep in range(episodes):
        s = START
        total, done, steps = 0.0, False, 0
        while not done and steps < 100:
            # epsilon-greedy 探索
            if rng.random() < eps:
                a = rng.integers(4)
            else:
                a = int(np.argmax(Q[s[0], s[1]]))
            ns, r, done = step(s, a)
            # Q-learning 更新（off-policy: 目标用 max_a' Q）
            best_next = 0.0 if done else np.max(Q[ns[0], ns[1]])
            td_target = r + gamma * best_next
            Q[s[0], s[1], a] += alpha * (td_target - Q[s[0], s[1], a])
            s = ns
            total += r
            steps += 1
        returns.append(total)
    return Q, returns


def show(Q, returns):
    V = np.max(Q, axis=2)
    print("=== Q-learning on 4x4 gridworld (S=起点 G=+1 T=-1 #=墙) ===")
    print("\n状态价值 V(s) = max_a Q(s,a):")
    for r in range(ROWS):
        row = []
        for c in range(COLS):
            if (r, c) in WALLS:
                row.append("  ###  ")
            else:
                row.append(f"{V[r, c]:+6.2f} ")
        print("  " + "".join(row))
    print("\n贪婪策略（每格最优动作箭头）:")
    for r in range(ROWS):
        row = []
        for c in range(COLS):
            if (r, c) in WALLS:
                row.append(" # ")
            elif (r, c) == GOAL:
                row.append(" G ")
            elif (r, c) == TRAP:
                row.append(" T ")
            else:
                row.append(f" {ARROW[int(np.argmax(Q[r, c]))]} ")
        print("  " + "".join(row))
    # 收敛：取末 200 局平均回报
    last = np.mean(returns[-200:])
    first = np.mean(returns[:200])
    print(f"\n平均回报: 前200局={first:+.3f} -> 末200局={last:+.3f}  (学到了避陷阱、走捷径)")


if __name__ == "__main__":
    Q, returns = q_learning()
    show(Q, returns)
