"""
第 14 章配套代码：最小 DQN 思想 demo（经验回放 + 目标网络）。
纯 numpy 实现一个单隐层 Q 网络，在前面同一个 gridworld 上学 Q(s,a)。
重点不是性能，而是把 DQN 的两个关键工程机制写清楚：
  (1) 经验回放 replay buffer：存 (s,a,r,s',done)，训练时随机小批量采样，打破时间相关性；
  (2) 目标网络 target net：用周期性冻结的参数算 TD 目标 r + gamma*max Q_target(s')，稳住训练。
Runnable with: numpy only.  python3 14_dqn_min.py
"""
import numpy as np

ROWS, COLS = 4, 4
START, GOAL, TRAP = (0, 0), (3, 3), (1, 3)
WALLS = {(1, 1), (2, 1)}
ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]


def step(s, a):
    if s in (GOAL, TRAP):
        return s, 0.0, True
    dr, dc = ACTIONS[a]
    ns = (s[0] + dr, s[1] + dc)
    if ns[0] < 0 or ns[0] >= ROWS or ns[1] < 0 or ns[1] >= COLS or ns in WALLS:
        ns = s
    if ns == GOAL:
        return ns, 1.0, True
    if ns == TRAP:
        return ns, -1.0, True
    return ns, -0.04, False


def encode(s):
    """状态 one-hot 编码成长度 16 的向量（作为网络输入）。"""
    v = np.zeros(ROWS * COLS)
    v[s[0] * COLS + s[1]] = 1.0
    return v


class QNet:
    """单隐层 MLP: 16 -> H -> 4，输出每个动作的 Q 值。纯 numpy + 手写反传。"""
    def __init__(self, H=32, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(0, 0.1, (16, H))
        self.b1 = np.zeros(H)
        self.W2 = rng.normal(0, 0.1, (H, 4))
        self.b2 = np.zeros(4)

    def forward(self, X):
        self.X = X
        self.z1 = X @ self.W1 + self.b1
        self.h = np.maximum(0, self.z1)          # ReLU
        return self.h @ self.W2 + self.b2        # Q 值（线性输出）

    def params(self):
        return [self.W1, self.b1, self.W2, self.b2]

    def copy_from(self, other):
        self.W1, self.b1 = other.W1.copy(), other.b1.copy()
        self.W2, self.b2 = other.W2.copy(), other.b2.copy()

    def train_step(self, X, a_idx, target, lr=0.01):
        """对 Q(s,a) 做一步均方误差回归到 target，只在所选动作 a 上回传。"""
        Q = self.forward(X)
        pred = Q[np.arange(len(X)), a_idx]
        dQ = np.zeros_like(Q)
        dQ[np.arange(len(X)), a_idx] = (pred - target) / len(X)   # MSE 梯度
        dW2 = self.h.T @ dQ
        db2 = dQ.sum(0)
        dh = dQ @ self.W2.T
        dz1 = dh * (self.z1 > 0)
        dW1 = self.X.T @ dz1
        db1 = dz1.sum(0)
        self.W2 -= lr * dW2; self.b2 -= lr * db2
        self.W1 -= lr * dW1; self.b1 -= lr * db1
        return float(np.mean((pred - target) ** 2))


def train(episodes=1500, gamma=0.95, eps=0.2, batch=32,
          target_sync=50, seed=0):
    rng = np.random.default_rng(seed)
    online, target = QNet(seed=seed), QNet(seed=seed)
    target.copy_from(online)
    buffer = []                      # 经验回放缓冲
    states = [(r, c) for r in range(ROWS) for c in range(COLS)
              if (r, c) not in WALLS]
    for ep in range(episodes):
        s = START
        for _ in range(50):
            if rng.random() < eps:
                a = rng.integers(4)
            else:
                a = int(np.argmax(online.forward(encode(s)[None])[0]))
            ns, r, done = step(s, a)
            buffer.append((s, a, r, ns, done))
            if len(buffer) > 5000:
                buffer.pop(0)
            s = ns
            # 从回放缓冲随机采样一个 minibatch 训练
            if len(buffer) >= batch:
                idx = rng.choice(len(buffer), batch, replace=False)
                bs = [buffer[i] for i in idx]
                X = np.array([encode(t[0]) for t in bs])
                A = np.array([t[1] for t in bs])
                R = np.array([t[2] for t in bs])
                NS = np.array([encode(t[3]) for t in bs])
                D = np.array([t[4] for t in bs], dtype=float)
                # TD 目标用【目标网络】算，稳住训练
                Qn = target.forward(NS)
                td_target = R + gamma * np.max(Qn, axis=1) * (1 - D)
                online.train_step(X, A, td_target)
            if done:
                break
        if ep % target_sync == 0:        # 周期性同步目标网络
            target.copy_from(online)
    return online, states


def show(net, states):
    print("=== 最小 DQN (经验回放 + 目标网络) on 4x4 gridworld ===")
    arrow = {0: "↑", 1: "↓", 2: "←", 3: "→"}
    print("\n学到的贪婪策略:")
    for r in range(ROWS):
        row = []
        for c in range(COLS):
            if (r, c) in WALLS:
                row.append(" # ")
            elif (r, c) == GOAL:
                row.append(" G ")
            elif (r, c) == TRAP:
                row.append(" T ")
            else:
                a = int(np.argmax(net.forward(encode((r, c))[None])[0]))
                row.append(f" {arrow[a]} ")
        print("  " + "".join(row))
    print("\n机制要点: replay 打破样本相关性; target net 让 TD 目标不随每步抖动。")


if __name__ == "__main__":
    net, states = train()
    show(net, states)
