"""
第 07/08 章配套代码：
(1) Bahdanau 风格 additive attention（2014）
(2) Transformer 的 scaled dot-product attention + multi-head（2017）
Runnable with: numpy only.  python3 07_attention.py
"""
import numpy as np
# Module-wide generator with a fixed seed so every run of the demos (and the
# random head projections in multi_head_attention) is reproducible.
rng = np.random.default_rng(seed=0)


def softmax(x, axis=-1):
    """Return the softmax of the ndarray ``x`` taken along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so large
    logits cannot overflow; softmax is invariant to adding a constant along
    the normalized axis, so the result is unchanged.
    """
    peak = x.max(axis=axis, keepdims=True)
    unnormalized = np.exp(x - peak)
    total = unnormalized.sum(axis=axis, keepdims=True)
    return unnormalized / total


def additive_attention(query, keys, values, Wq, Wk, v):
    """Bahdanau (2014) additive attention: score(q, k) = v^T tanh(Wq q + Wk k).

    Args:
        query: query vector, shape (dq,).
        keys: key matrix, shape (T, dk); one row per source position.
        values: value matrix, shape (T, dv); row t is weighted by alpha[t].
        Wq: query projection, shape (da, dq).
        Wk: key projection, shape (da, dk).
        v: scoring vector, shape (da,).

    Returns:
        (context, alpha): ``context`` is the alpha-weighted sum of the rows of
        ``values`` (shape (dv,)); ``alpha`` is the (T,) attention / alignment
        distribution produced by ``softmax`` over the scores.
    """
    keys = np.asarray(keys)
    # Project all keys in one matmul: (T, dk) @ (dk, da) -> (T, da). The
    # projected query (da,) broadcasts across the T rows, replacing the
    # per-key Python loop with a single vectorized expression.
    hidden = np.tanh(keys @ Wk.T + Wq @ query)
    scores = hidden @ v                      # (T,) alignment scores
    alpha = softmax(scores)                  # attention weights (alignment distribution)
    context = alpha @ values                 # weighted sum -> context vector
    return context, alpha


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Transformer (2017): Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V.

    Works on a single sequence -- Q: (Lq, d), K: (Lk, d), V: (Lk, dv) -- and
    also on inputs with extra leading (e.g. batch/head) axes such as
    Q: (..., Lq, d), because K is transposed over its last two axes only.

    Args:
        Q, K, V: query, key and value arrays with the shapes above.
        mask: optional array broadcastable to the (..., Lq, Lk) score matrix.
            Truthy entries keep a query->key score; falsy entries replace it
            with -1e9 (e.g. a lower-triangular causal mask blocks future keys).

    Returns:
        (output, A): ``output`` has shape (..., Lq, dv); ``A`` is the
        (..., Lq, Lk) attention matrix, each row summing to 1.
    """
    d_k = Q.shape[-1]
    # Swap only the last two axes. ``K.T`` reverses *all* axes, which is the
    # same for 2-D K but silently mis-pairs leading axes when K.ndim > 2.
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)   # (..., Lq, Lk) scaled dot products
    if mask is not None:
        # A large finite negative (rather than -inf) means a fully masked row
        # becomes a uniform distribution after softmax instead of NaN.
        scores = np.where(mask, scores, -1e9)
    A = softmax(scores, axis=-1)             # each query's distribution over the keys
    return A @ V, A


def multi_head_attention(X, d_model=8, n_heads=2):
    """Split d_model into n_heads subspaces, attend in each, then concatenate.

    Each head draws fresh random projections Wq, Wk, Wv of shape
    (d_model, d_head) from the module-level ``rng``, so successive calls
    produce different outputs. Head outputs are only concatenated; there is
    no output projection W_O here.

    Args:
        X: input sequence, shape (L, d_model).
        d_model: model width; must equal X.shape[-1] and be divisible by n_heads.
        n_heads: number of attention heads (positive).

    Returns:
        (out, attns): ``out`` has shape (L, d_model), head h filling columns
        [h * d_head, (h + 1) * d_head); ``attns`` is the list of per-head
        (L, L) attention matrices.

    Raises:
        ValueError: if n_heads is not positive, d_model is not divisible by
            n_heads, or X's last dimension differs from d_model.
    """
    # Without this check a non-divisible d_model floors d_head and the last
    # d_model % n_heads columns of ``out`` silently stay zero.
    if n_heads <= 0 or d_model % n_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
        )
    if X.shape[-1] != d_model:
        raise ValueError(f"X has width {X.shape[-1]}, expected d_model={d_model}")

    L = X.shape[0]
    d_head = d_model // n_heads
    out = np.zeros((L, d_model))
    attns = []
    for h in range(n_heads):
        # Draw order (Wq, Wk, Wv per head) is kept stable for reproducibility.
        Wq, Wk, Wv = [rng.normal(0, .5, (d_model, d_head)) for _ in range(3)]
        o, A = scaled_dot_product_attention(X @ Wq, X @ Wk, X @ Wv)
        out[:, h * d_head:(h + 1) * d_head] = o
        attns.append(A)
    return out, attns


if __name__ == "__main__":
    print("=== Bahdanau additive attention（2014）===")
    T, dk = 4, 5
    keys = rng.normal(0, 1, (T, dk))
    values = keys.copy()
    query = keys[2] + rng.normal(0, .1, dk)   # query is built near the 3rd key
    Wq = np.eye(dk)
    Wk = np.eye(dk)
    v = np.ones(dk)
    ctx, alpha = additive_attention(query, keys, values, Wq, Wk, v)
    # With identity projections and v = 1 the score is sum(tanh(q + k)),
    # which is not a similarity measure, so the largest weight is NOT
    # guaranteed to fall on the key nearest the query. Report both indices
    # rather than claiming they coincide.
    nearest = int(np.argmin(np.linalg.norm(keys - query, axis=1)))
    print("  注意力权重 alpha =", np.round(alpha, 3))
    print("  argmax(alpha) =", int(np.argmax(alpha)), " 最接近 query 的 key =", nearest,
          "(未训练的加性打分不保证二者一致)")

    print("\n=== Scaled dot-product attention（2017）===")
    L, d = 4, 6
    X = rng.normal(0, 1, (L, d))
    Q = K = V = X                              # self-attention: all three come from X
    o, A = scaled_dot_product_attention(Q, K, V)
    print("  注意力矩阵 A (每行和为1):\n", np.round(A, 2))

    print("\n=== 因果掩码（GPT 式自回归，不能看未来）===")
    # Lower-triangular boolean mask (diagonal included): position i sees j <= i.
    mask = np.tri(L, dtype=bool)
    o, A = scaled_dot_product_attention(Q, K, V, mask=mask)
    print("  下三角注意力(位置 i 只能看 <=i):\n", np.round(A, 2))

    print("\n=== Multi-head attention ===")
    out, attns = multi_head_attention(X, d_model=6, n_heads=2)
    print("  输出形状:", out.shape, " head 数:", len(attns))
