📘 前言

《The Annotated Transformer》是哈佛大学 NLP 团队对经典论文《Attention is All You Need》的逐行讲解式实现,
使用 PyTorch 编写,目的是帮助开发者深入理解 Transformer 模型的内部机制与实现细节。本文将对这篇文章进行系统性的解读
适合对深度学习与自然语言处理(NLP)感兴趣的技术人员与研究者阅读。


🧠 一、项目背景与目标

Transformer 是一种完全基于注意力机制的神经网络架构,由 Google 于 2017 年提出。它彻底摆脱了 RNN 和 CNN 对序列处理的限制,
实现了高效并行和优秀的长期依赖建模能力。

《The Annotated Transformer》的目标是:

  • 用简洁、明了、逐行注释的 PyTorch 代码复现 Transformer
  • 帮助读者理解 Transformer 的关键模块和实现原理
  • 作为学习、研究和二次开发的良好起点

🧱 二、Transformer 模型结构概览

Transformer 的整体结构分为:

  1. 编码器(Encoder)
  2. 解码器(Decoder)
  3. 多头注意力机制(Multi-Head Attention)
  4. 位置编码(Positional Encoding)
  5. 前馈神经网络(Feed-Forward Network)
  6. 残差连接 + 层归一化(Residual Connection + LayerNorm)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传


🔍 三、核心模块解析与代码说明

1️⃣ 模块克隆工具 clones

该函数用于生成多个重复子层结构(例如 6 层编码器):

def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

2️⃣ 位置编码(Positional Encoding)

由于 Transformer 不具备序列结构,需要人为加入位置感知:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

PyTorch 实现:

class PositionalEncoding(nn.Module):
    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

3️⃣ 注意力机制(Scaled Dot-Product Attention)

计算 Query 与 Key 的点积、归一化、Softmax 后加权 Value:

def attention(query, key, value, mask=None, dropout=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    p_attn = F.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn

4️⃣ 多头注意力机制(Multi-Head Attention)

多头注意力将注意力过程并行化:

class MultiHeadedAttention(nn.Module):
    def forward(self, query, key, value, mask=None):
        # 分头,注意力计算,拼接输出

5️⃣ 子层连接(残差 + 层归一化)

这是 Transformer 稳定训练的关键技巧之一。

class SublayerConnection(nn.Module):
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

这种方式被称为 Post-Norm(先归一化再残差连接)。


6️⃣ 编码器与解码器层结构

编码器层 EncoderLayer

class EncoderLayer(nn.Module):
    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

解码器层 DecoderLayer

class DecoderLayer(nn.Module):
    def forward(self, x, memory, src_mask, tgt_mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, memory, memory, src_mask))
        return self.sublayer[2](x, self.feed_forward)

🔁 四、训练流程与优化器

损失函数:交叉熵 + 标签平滑

class LabelSmoothing(nn.Module):
    def forward(self, x, target):
        ... # 使用平滑标签而非 one-hot

学习率调度器(NoamOpt)

Transformer 使用了论文中自定义的优化器调度器:

class NoamOpt:
    def step(self):
        lrate = factor * d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

🧪 五、推理与可视化

推理过程(Greedy Decode)

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    ...  # 每步生成一个词直到终止符

可视化注意力权重

项目支持使用 matplotlib 绘制每一层每个头的注意力热力图,帮助观察模型是如何捕捉输入中不同词之间关系的。


🛠 六、简单使用案例(简化)

# 输入句子:I am a student
src = torch.LongTensor([[0, 5, 8, 14, 2]])  # 模拟词表索引
model.eval()
output = greedy_decode(model, src, src_mask, max_len=10, start_symbol=1)
print("模型输出:", output)

🛠 七、注意力可视化(Attention Visualization )

  tgt_sent = trans.split()
def draw(data, x, y, ax):
   seaborn.heatmap(data, 
                   xticklabels=x, square=True, yticklabels=y, vmin=0.0, vmax=1.0, 
                   cbar=False, ax=ax)
   
for layer in range(1, 6, 2):
   fig, axs = plt.subplots(1,4, figsize=(20, 10))
   print("Encoder Layer", layer+1)
   for h in range(4):
       draw(model.encoder.layers[layer].self_attn.attn[0, h].data, 
           sent, sent if h ==0 else [], ax=axs[h])
   plt.show()
   
for layer in range(1, 6, 2):
   fig, axs = plt.subplots(1,4, figsize=(20, 10))
   print("Decoder Self Layer", layer+1)
   for h in range(4):
       draw(model.decoder.layers[layer].self_attn.attn[0, h].data[:len(tgt_sent), :len(tgt_sent)], 
           tgt_sent, tgt_sent if h ==0 else [], ax=axs[h])
   plt.show()
   print("Decoder Src Layer", layer+1)
   fig, axs = plt.subplots(1,4, figsize=(20, 10))
   for h in range(4):
       draw(model.decoder.layers[layer].self_attn.attn[0, h].data[:len(tgt_sent), :len(sent)], 
           sent, tgt_sent if h ==0 else [], ax=axs[h])
   plt.show()



📝 八、总结与启示

模块 作用 特点
编码器 提取输入特征 多头注意力 + 前馈网络
解码器 逐步生成输出 Masked Attention + 编码器联动
注意力机制 聚焦关键信息 Query-Key-Value 模型
残差连接 防梯度消失 提高深层网络稳定性
位置编码 加入顺序信息 sin/cos 周期函数

理解 Transformer 的三个关键:

  1. 一切皆注意力(All Attention)
  2. 模块化、组合式设计
  3. 高度并行训练能力

🔗 项目资源与延伸阅读

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐