|
@@ -0,0 +1,249 @@
|
|
|
|
|
+import torch
|
|
|
|
|
+import torch.nn as nn
|
|
|
|
|
+import math
|
|
|
|
|
+import copy
|
|
|
|
|
+
|
|
|
|
|
+class MultiHeadAttention(nn.Module):
|
|
|
|
|
+ """
|
|
|
|
|
+ 多头注意力机制模块
|
|
|
|
|
+ """
|
|
|
|
|
+ def __init__(self, d_model, num_heads):
|
|
|
|
|
+ super(MultiHeadAttention, self).__init__()
|
|
|
|
|
+ assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
|
|
|
|
|
+
|
|
|
|
|
+ self.d_model = d_model
|
|
|
|
|
+ self.num_heads = num_heads
|
|
|
|
|
+ self.d_k = d_model // num_heads
|
|
|
|
|
+
|
|
|
|
|
+ # 定义 Q, K, V 和输出的线性变换层
|
|
|
|
|
+ self.W_q = nn.Linear(d_model, d_model)
|
|
|
|
|
+ self.W_k = nn.Linear(d_model, d_model)
|
|
|
|
|
+ self.W_v = nn.Linear(d_model, d_model)
|
|
|
|
|
+ self.W_o = nn.Linear(d_model, d_model)
|
|
|
|
|
+
|
|
|
|
|
+ def scaled_dot_product_attention(self, Q, K, V, mask=None):
|
|
|
|
|
+ # 1. 计算注意力得分 (QK^T)
|
|
|
|
|
+ attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 应用掩码 (如果提供)
|
|
|
|
|
+ if mask is not None:
|
|
|
|
|
+ # 将掩码中为 0 的位置设置为一个非常小的负数,这样 softmax 后会接近 0
|
|
|
|
|
+ attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
|
|
|
|
|
+
|
|
|
|
|
+ # 3. 计算注意力权重 (Softmax)
|
|
|
|
|
+ attn_probs = torch.softmax(attn_scores, dim=-1)
|
|
|
|
|
+
|
|
|
|
|
+ # 4. 加权求和 (权重 * V)
|
|
|
|
|
+ output = torch.matmul(attn_probs, V)
|
|
|
|
|
+ return output
|
|
|
|
|
+
|
|
|
|
|
+ def split_heads(self, x):
|
|
|
|
|
+ # 将输入 x 的形状从 (batch_size, seq_length, d_model)
|
|
|
|
|
+ # 变换为 (batch_size, num_heads, seq_length, d_k)
|
|
|
|
|
+ batch_size, seq_length, d_model = x.size()
|
|
|
|
|
+ return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
|
|
|
|
|
+
|
|
|
|
|
+ def combine_heads(self, x):
|
|
|
|
|
+ # 将输入 x 的形状从 (batch_size, num_heads, seq_length, d_k)
|
|
|
|
|
+ # 变回 (batch_size, seq_length, d_model)
|
|
|
|
|
+ batch_size, num_heads, seq_length, d_k = x.size()
|
|
|
|
|
+ return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, Q, K, V, mask=None):
|
|
|
|
|
+ # 1. 对 Q, K, V 进行线性变换
|
|
|
|
|
+ Q = self.split_heads(self.W_q(Q))
|
|
|
|
|
+ K = self.split_heads(self.W_k(K))
|
|
|
|
|
+ V = self.split_heads(self.W_v(V))
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 计算缩放点积注意力
|
|
|
|
|
+ attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
|
|
|
|
|
+
|
|
|
|
|
+ # 3. 合并多头输出并进行最终的线性变换
|
|
|
|
|
+ output = self.W_o(self.combine_heads(attn_output))
|
|
|
|
|
+ return output
|
|
|
|
|
+
|
|
|
|
|
+class PositionWiseFeedForward(nn.Module):
|
|
|
|
|
+ """
|
|
|
|
|
+ 位置前馈网络模块
|
|
|
|
|
+ """
|
|
|
|
|
+ def __init__(self, d_model, d_ff, dropout=0.1):
|
|
|
|
|
+ super(PositionWiseFeedForward, self).__init__()
|
|
|
|
|
+ self.linear1 = nn.Linear(d_model, d_ff)
|
|
|
|
|
+ self.dropout = nn.Dropout(dropout)
|
|
|
|
|
+ self.linear2 = nn.Linear(d_ff, d_model)
|
|
|
|
|
+ self.relu = nn.ReLU()
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x):
|
|
|
|
|
+ # x 形状: (batch_size, seq_len, d_model)
|
|
|
|
|
+ x = self.linear1(x)
|
|
|
|
|
+ x = self.relu(x)
|
|
|
|
|
+ x = self.dropout(x)
|
|
|
|
|
+ x = self.linear2(x)
|
|
|
|
|
+ # 最终输出形状: (batch_size, seq_len, d_model)
|
|
|
|
|
+ return x
|
|
|
|
|
+
|
|
|
|
|
+class PositionalEncoding(nn.Module):
|
|
|
|
|
+ """
|
|
|
|
|
+ 为输入序列的词嵌入向量添加位置编码。
|
|
|
|
|
+ """
|
|
|
|
|
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.dropout = nn.Dropout(p=dropout)
|
|
|
|
|
+
|
|
|
|
|
+ # 创建一个足够长的位置编码矩阵
|
|
|
|
|
+ position = torch.arange(max_len).unsqueeze(1)
|
|
|
|
|
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
|
|
|
|
+
|
|
|
|
|
+ # pe (positional encoding) 的大小为 (max_len, d_model)
|
|
|
|
|
+ pe = torch.zeros(max_len, d_model)
|
|
|
|
|
+
|
|
|
|
|
+ # 偶数维度使用 sin, 奇数维度使用 cos
|
|
|
|
|
+ pe[:, 0::2] = torch.sin(position * div_term)
|
|
|
|
|
+ pe[:, 1::2] = torch.cos(position * div_term)
|
|
|
|
|
+
|
|
|
|
|
+ # 将 pe 注册为 buffer,这样它就不会被视为模型参数,但会随模型移动(例如 to(device))
|
|
|
|
|
+ self.register_buffer('pe', pe.unsqueeze(0))
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ # x.size(1) 是当前输入的序列长度
|
|
|
|
|
+ # 将位置编码加到输入向量上
|
|
|
|
|
+ x = x + self.pe[:, :x.size(1)]
|
|
|
|
|
+ return self.dropout(x)
|
|
|
|
|
+
|
|
|
|
|
+class EncoderLayer(nn.Module):
|
|
|
|
|
+ """
|
|
|
|
|
+ 编码器核心层
|
|
|
|
|
+ """
|
|
|
|
|
+ def __init__(self, d_model, num_heads, d_ff, dropout):
|
|
|
|
|
+ super(EncoderLayer, self).__init__()
|
|
|
|
|
+ self.self_attn = MultiHeadAttention(d_model, num_heads)
|
|
|
|
|
+ self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
|
|
|
|
|
+ self.norm1 = nn.LayerNorm(d_model)
|
|
|
|
|
+ self.norm2 = nn.LayerNorm(d_model)
|
|
|
|
|
+ self.dropout = nn.Dropout(dropout)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x, mask):
|
|
|
|
|
+ # 1. 多头自注意力
|
|
|
|
|
+ attn_output = self.self_attn(x, x, x, mask)
|
|
|
|
|
+ x = self.norm1(x + self.dropout(attn_output))
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 前馈网络
|
|
|
|
|
+ ff_output = self.feed_forward(x)
|
|
|
|
|
+ x = self.norm2(x + self.dropout(ff_output))
|
|
|
|
|
+
|
|
|
|
|
+ return x
|
|
|
|
|
+
|
|
|
|
|
+class DecoderLayer(nn.Module):
|
|
|
|
|
+ """
|
|
|
|
|
+ 解码器核心层
|
|
|
|
|
+ """
|
|
|
|
|
+ def __init__(self, d_model, num_heads, d_ff, dropout):
|
|
|
|
|
+ super(DecoderLayer, self).__init__()
|
|
|
|
|
+ self.self_attn = MultiHeadAttention(d_model, num_heads)
|
|
|
|
|
+ self.cross_attn = MultiHeadAttention(d_model, num_heads)
|
|
|
|
|
+ self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
|
|
|
|
|
+ self.norm1 = nn.LayerNorm(d_model)
|
|
|
|
|
+ self.norm2 = nn.LayerNorm(d_model)
|
|
|
|
|
+ self.norm3 = nn.LayerNorm(d_model)
|
|
|
|
|
+ self.dropout = nn.Dropout(dropout)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x, encoder_output, src_mask, tgt_mask):
|
|
|
|
|
+ # 1. 掩码多头自注意力 (对自己)
|
|
|
|
|
+ attn_output = self.self_attn(x, x, x, tgt_mask)
|
|
|
|
|
+ x = self.norm1(x + self.dropout(attn_output))
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 交叉注意力 (对编码器输出)
|
|
|
|
|
+ cross_attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
|
|
|
|
|
+ x = self.norm2(x + self.dropout(cross_attn_output))
|
|
|
|
|
+
|
|
|
|
|
+ # 3. 前馈网络
|
|
|
|
|
+ ff_output = self.feed_forward(x)
|
|
|
|
|
+ x = self.norm3(x + self.dropout(ff_output))
|
|
|
|
|
+
|
|
|
|
|
+ return x
|
|
|
|
|
+
|
|
|
|
|
+class Encoder(nn.Module):
|
|
|
|
|
+ def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
|
|
|
|
|
+ super(Encoder, self).__init__()
|
|
|
|
|
+ self.embedding = nn.Embedding(vocab_size, d_model)
|
|
|
|
|
+ self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
|
|
|
|
|
+ self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
|
|
|
|
|
+ self.norm = nn.LayerNorm(d_model)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x, mask):
|
|
|
|
|
+ x = self.embedding(x)
|
|
|
|
|
+ x = self.pos_encoder(x)
|
|
|
|
|
+ for layer in self.layers:
|
|
|
|
|
+ x = layer(x, mask)
|
|
|
|
|
+ return self.norm(x)
|
|
|
|
|
+
|
|
|
|
|
+class Decoder(nn.Module):
|
|
|
|
|
+ def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
|
|
|
|
|
+ super(Decoder, self).__init__()
|
|
|
|
|
+ self.embedding = nn.Embedding(vocab_size, d_model)
|
|
|
|
|
+ self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
|
|
|
|
|
+ self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
|
|
|
|
|
+ self.norm = nn.LayerNorm(d_model)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x, encoder_output, src_mask, tgt_mask):
|
|
|
|
|
+ x = self.embedding(x)
|
|
|
|
|
+ x = self.pos_encoder(x)
|
|
|
|
|
+ for layer in self.layers:
|
|
|
|
|
+ x = layer(x, encoder_output, src_mask, tgt_mask)
|
|
|
|
|
+ return self.norm(x)
|
|
|
|
|
+
|
|
|
|
|
+class Transformer(nn.Module):
|
|
|
|
|
+ def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len=5000):
|
|
|
|
|
+ super(Transformer, self).__init__()
|
|
|
|
|
+ self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
|
|
|
|
|
+ self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
|
|
|
|
|
+ self.final_linear = nn.Linear(d_model, tgt_vocab_size)
|
|
|
|
|
+
|
|
|
|
|
+ def generate_mask(self, src, tgt):
|
|
|
|
|
+ # src_mask: (batch_size, 1, 1, src_len)
|
|
|
|
|
+ src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
|
|
|
|
|
+
|
|
|
|
|
+ # tgt_mask: (batch_size, 1, tgt_len, tgt_len)
|
|
|
|
|
+ tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2) # (batch_size, 1, 1, tgt_len)
|
|
|
|
|
+ tgt_len = tgt.size(1)
|
|
|
|
|
+ # 下三角矩阵,用于防止看到未来的 token
|
|
|
|
|
+ tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=src.device)).bool() # (tgt_len, tgt_len)
|
|
|
|
|
+ tgt_mask = tgt_pad_mask & tgt_sub_mask
|
|
|
|
|
+
|
|
|
|
|
+ return src_mask, tgt_mask
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, src, tgt):
|
|
|
|
|
+ src_mask, tgt_mask = self.generate_mask(src, tgt)
|
|
|
|
|
+
|
|
|
|
|
+ encoder_output = self.encoder(src, src_mask)
|
|
|
|
|
+ decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
|
|
|
|
|
+
|
|
|
|
|
+ output = self.final_linear(decoder_output)
|
|
|
|
|
+ return output
|
|
|
|
|
+
|
|
|
|
|
+# --- 演示如何使用模型 ---
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ # 1. 定义超参数
|
|
|
|
|
+ src_vocab_size = 5000
|
|
|
|
|
+ tgt_vocab_size = 5000
|
|
|
|
|
+ d_model = 512
|
|
|
|
|
+ num_layers = 6
|
|
|
|
|
+ num_heads = 8
|
|
|
|
|
+ d_ff = 2048
|
|
|
|
|
+ dropout = 0.1
|
|
|
|
|
+ max_len = 100
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 实例化模型
|
|
|
|
|
+ model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
|
|
|
|
|
+
|
|
|
|
|
+ # 3. 创建模拟输入数据
|
|
|
|
|
+ # 假设 batch_size=2, src_seq_len=10, tgt_seq_len=12
|
|
|
|
|
+ src = torch.randint(1, src_vocab_size, (2, 10)) # (batch_size, seq_length)
|
|
|
|
|
+ tgt = torch.randint(1, tgt_vocab_size, (2, 12)) # (batch_size, seq_length)
|
|
|
|
|
+
|
|
|
|
|
+ # 4. 模型前向传播
|
|
|
|
|
+ output = model(src, tgt)
|
|
|
|
|
+
|
|
|
|
|
+ # 5. 打印输出形状
|
|
|
|
|
+ print("模型输出的形状:", output.shape)
|
|
|
|
|
+ # 预期输出: torch.Size([2, 12, 5000]) -> (batch_size, tgt_seq_len, tgt_vocab_size)
|