"""
第 06 章配套代码：迷你 skip-gram word2vec（含负采样思想的简化版）
Runnable with: numpy only.  python3 06_word2vec.py

在一个玩具语料上训练词向量，演示 (1) skip-gram 目标 (2) 线性类比 king-man+woman≈queen 的几何。
语料是人造的，刻意让 king/queen/man/woman 形成平行四边形关系。
"""
import numpy as np

rng = np.random.default_rng(0)

# 玩具语料：每句是一组共现的词（模拟 skip-gram 的中心词-上下文对）
sentences = [
    "king man royal palace", "queen woman royal palace",
    "man boy human person", "woman girl human person",
    "king queen royal crown", "prince king boy royal",
    "princess queen girl royal", "man king strong",
    "woman queen strong", "boy prince young", "girl princess young",
]
vocab = sorted(set(" ".join(sentences).split()))
w2i = {w: i for i, w in enumerate(vocab)}
V = len(vocab)
D = 8  # 向量维度

# 生成 (中心词, 上下文词) 训练对
pairs = []
for s in sentences:
    ws = s.split()
    for i, c in enumerate(ws):
        for j, o in enumerate(ws):
            if i != j:
                pairs.append((w2i[c], w2i[o]))
pairs = np.array(pairs)

# skip-gram: 中心词向量 Win，上下文向量 Wout
Win = rng.normal(0, 0.1, (V, D))
Wout = rng.normal(0, 0.1, (V, D))


def softmax(x):
    e = np.exp(x - x.max()); return e / e.sum()


lr = 0.01
for epoch in range(600):
    rng.shuffle(pairs)
    loss = 0
    for c, o in pairs:
        h = Win[c]                       # 中心词向量
        scores = Wout @ h                # 对每个词的得分
        p = softmax(scores)              # 预测上下文分布
        loss -= np.log(p[o] + 1e-9)
        # 梯度（交叉熵 + softmax）
        dscores = p.copy(); dscores[o] -= 1
        Win[c] -= lr * (Wout.T @ dscores)
        Wout -= lr * np.outer(dscores, h)
    if epoch % 160 == 0:
        print(f"epoch {epoch:3d}  loss={loss/len(pairs):.4f}")


def vec(w): return Win[w2i[w]]
def cos(a, b): return a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9)


def nearest(v, exclude=()):
    sims = [(w, cos(v, Win[i])) for w, i in w2i.items() if w not in exclude]
    return sorted(sims, key=lambda t: -t[1])[:3]


# 在玩具语料上，可靠演示的是"语义相关词聚在一起"（近邻结构）。
# 完整的 king-man+woman≈queen 平行四边形几何需要真实大语料（数十亿词）才稳定，
# 见正文说明；这里展示机制：相关词的向量彼此靠近。
print("\n近邻结构（语义相关词的向量彼此靠近）:")
for w in ["king", "woman", "royal"]:
    print(f"  与 '{w}' 最近: {[(x, round(s,2)) for x, s in nearest(vec(w), exclude={w})]}")

print("\n类比向量 king - man + woman 的近邻:")
analogy = vec("king") - vec("man") + vec("woman")
print("  ", [(x, round(s, 2)) for x, s in nearest(analogy, exclude={"king", "man", "woman"})])
print("  （玩具语料下结果有噪声；真实 word2vec 在大语料上此处稳定指向 'queen'）")
