"""
第 04 章配套代码：(1) 演示 RNN 的梯度消失/爆炸  (2) 一个 LSTM cell 的前向
Runnable with: numpy only.  python3 04_lstm_cell.py
"""
import numpy as np


def vanishing_gradient_demo():
    """简化的 BPTT 梯度连乘：grad ~ prod_t (W * sigma'(.))。
    若有效乘子 |lambda| < 1，梯度随时间步指数衰减 -> 消失；>1 -> 爆炸。
    """
    print("=== 梯度消失/爆炸（BPTT 连乘）===")
    for lam, name in [(0.6, "衰减(消失)"), (1.0, "临界"), (1.3, "增长(爆炸)")]:
        for T in [5, 20, 50]:
            grad = lam ** T
            print(f"  乘子={lam} [{name}]  T={T:2d} 步后梯度 ~ {grad:.3e}")
        print()


def sigmoid(x): return 1 / (1 + np.exp(-x))


def lstm_step(x, h_prev, c_prev, Wf, Wi, Wc, Wo, bf, bi, bc, bo):
    """单个时间步的 LSTM。z = [h_prev; x] 拼接。
    遗忘门 f, 输入门 i, 候选 g, 输出门 o:
      f = σ(Wf z + bf)      i = σ(Wi z + bi)
      g = tanh(Wc z + bc)   o = σ(Wo z + bo)
      c = f ⊙ c_prev + i ⊙ g          (Constant Error Carousel: 关键的加法更新)
      h = o ⊙ tanh(c)
    """
    z = np.concatenate([h_prev, x])
    f = sigmoid(Wf @ z + bf)
    i = sigmoid(Wi @ z + bi)
    g = np.tanh(Wc @ z + bc)
    o = sigmoid(Wo @ z + bo)
    c = f * c_prev + i * g       # 细胞状态：门控的加法传递 => 误差可几乎无衰减地流过
    h = o * np.tanh(c)
    return h, c, dict(f=f, i=i, g=g, o=o)


if __name__ == "__main__":
    vanishing_gradient_demo()

    print("=== LSTM cell 前向（hidden=3, input=2）===")
    rng = np.random.default_rng(0)
    H, D = 3, 2
    Wf, Wi, Wc, Wo = [rng.normal(0, .3, (H, H + D)) for _ in range(4)]
    bf = np.ones(H)          # 遗忘门偏置初始化为正：默认"记住"
    bi, bc, bo = [np.zeros(H) for _ in range(3)]
    h, c = np.zeros(H), np.zeros(H)
    seq = [np.array([1., 0.]), np.array([0., 1.]), np.array([1., 1.])]
    for t, x in enumerate(seq):
        h, c, gates = lstm_step(x, h, c, Wf, Wi, Wc, Wo, bf, bi, bc, bo)
        print(f"  t={t} 遗忘门均值={gates['f'].mean():.2f} 输入门均值={gates['i'].mean():.2f} "
              f"c={np.round(c,3)} h={np.round(h,3)}")
