Transformer.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import torch
  2. import torch.nn as nn
  3. import math
  4. import copy
  5. class MultiHeadAttention(nn.Module):
  6. """
  7. 多头注意力机制模块
  8. """
  9. def __init__(self, d_model, num_heads):
  10. super(MultiHeadAttention, self).__init__()
  11. assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
  12. self.d_model = d_model
  13. self.num_heads = num_heads
  14. self.d_k = d_model // num_heads
  15. # 定义 Q, K, V 和输出的线性变换层
  16. self.W_q = nn.Linear(d_model, d_model)
  17. self.W_k = nn.Linear(d_model, d_model)
  18. self.W_v = nn.Linear(d_model, d_model)
  19. self.W_o = nn.Linear(d_model, d_model)
  20. def scaled_dot_product_attention(self, Q, K, V, mask=None):
  21. # 1. 计算注意力得分 (QK^T)
  22. attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
  23. # 2. 应用掩码 (如果提供)
  24. if mask is not None:
  25. # 将掩码中为 0 的位置设置为一个非常小的负数,这样 softmax 后会接近 0
  26. attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
  27. # 3. 计算注意力权重 (Softmax)
  28. attn_probs = torch.softmax(attn_scores, dim=-1)
  29. # 4. 加权求和 (权重 * V)
  30. output = torch.matmul(attn_probs, V)
  31. return output
  32. def split_heads(self, x):
  33. # 将输入 x 的形状从 (batch_size, seq_length, d_model)
  34. # 变换为 (batch_size, num_heads, seq_length, d_k)
  35. batch_size, seq_length, d_model = x.size()
  36. return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
  37. def combine_heads(self, x):
  38. # 将输入 x 的形状从 (batch_size, num_heads, seq_length, d_k)
  39. # 变回 (batch_size, seq_length, d_model)
  40. batch_size, num_heads, seq_length, d_k = x.size()
  41. return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
  42. def forward(self, Q, K, V, mask=None):
  43. # 1. 对 Q, K, V 进行线性变换
  44. Q = self.split_heads(self.W_q(Q))
  45. K = self.split_heads(self.W_k(K))
  46. V = self.split_heads(self.W_v(V))
  47. # 2. 计算缩放点积注意力
  48. attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
  49. # 3. 合并多头输出并进行最终的线性变换
  50. output = self.W_o(self.combine_heads(attn_output))
  51. return output
  52. class PositionWiseFeedForward(nn.Module):
  53. """
  54. 位置前馈网络模块
  55. """
  56. def __init__(self, d_model, d_ff, dropout=0.1):
  57. super(PositionWiseFeedForward, self).__init__()
  58. self.linear1 = nn.Linear(d_model, d_ff)
  59. self.dropout = nn.Dropout(dropout)
  60. self.linear2 = nn.Linear(d_ff, d_model)
  61. self.relu = nn.ReLU()
  62. def forward(self, x):
  63. # x 形状: (batch_size, seq_len, d_model)
  64. x = self.linear1(x)
  65. x = self.relu(x)
  66. x = self.dropout(x)
  67. x = self.linear2(x)
  68. # 最终输出形状: (batch_size, seq_len, d_model)
  69. return x
  70. class PositionalEncoding(nn.Module):
  71. """
  72. 为输入序列的词嵌入向量添加位置编码。
  73. """
  74. def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
  75. super().__init__()
  76. self.dropout = nn.Dropout(p=dropout)
  77. # 创建一个足够长的位置编码矩阵
  78. position = torch.arange(max_len).unsqueeze(1)
  79. div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
  80. # pe (positional encoding) 的大小为 (max_len, d_model)
  81. pe = torch.zeros(max_len, d_model)
  82. # 偶数维度使用 sin, 奇数维度使用 cos
  83. pe[:, 0::2] = torch.sin(position * div_term)
  84. pe[:, 1::2] = torch.cos(position * div_term)
  85. # 将 pe 注册为 buffer,这样它就不会被视为模型参数,但会随模型移动(例如 to(device))
  86. self.register_buffer('pe', pe.unsqueeze(0))
  87. def forward(self, x: torch.Tensor) -> torch.Tensor:
  88. # x.size(1) 是当前输入的序列长度
  89. # 将位置编码加到输入向量上
  90. x = x + self.pe[:, :x.size(1)]
  91. return self.dropout(x)
  92. class EncoderLayer(nn.Module):
  93. """
  94. 编码器核心层
  95. """
  96. def __init__(self, d_model, num_heads, d_ff, dropout):
  97. super(EncoderLayer, self).__init__()
  98. self.self_attn = MultiHeadAttention(d_model, num_heads)
  99. self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
  100. self.norm1 = nn.LayerNorm(d_model)
  101. self.norm2 = nn.LayerNorm(d_model)
  102. self.dropout = nn.Dropout(dropout)
  103. def forward(self, x, mask):
  104. # 1. 多头自注意力
  105. attn_output = self.self_attn(x, x, x, mask)
  106. x = self.norm1(x + self.dropout(attn_output))
  107. # 2. 前馈网络
  108. ff_output = self.feed_forward(x)
  109. x = self.norm2(x + self.dropout(ff_output))
  110. return x
  111. class DecoderLayer(nn.Module):
  112. """
  113. 解码器核心层
  114. """
  115. def __init__(self, d_model, num_heads, d_ff, dropout):
  116. super(DecoderLayer, self).__init__()
  117. self.self_attn = MultiHeadAttention(d_model, num_heads)
  118. self.cross_attn = MultiHeadAttention(d_model, num_heads)
  119. self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
  120. self.norm1 = nn.LayerNorm(d_model)
  121. self.norm2 = nn.LayerNorm(d_model)
  122. self.norm3 = nn.LayerNorm(d_model)
  123. self.dropout = nn.Dropout(dropout)
  124. def forward(self, x, encoder_output, src_mask, tgt_mask):
  125. # 1. 掩码多头自注意力 (对自己)
  126. attn_output = self.self_attn(x, x, x, tgt_mask)
  127. x = self.norm1(x + self.dropout(attn_output))
  128. # 2. 交叉注意力 (对编码器输出)
  129. cross_attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
  130. x = self.norm2(x + self.dropout(cross_attn_output))
  131. # 3. 前馈网络
  132. ff_output = self.feed_forward(x)
  133. x = self.norm3(x + self.dropout(ff_output))
  134. return x
  135. class Encoder(nn.Module):
  136. def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
  137. super(Encoder, self).__init__()
  138. self.embedding = nn.Embedding(vocab_size, d_model)
  139. self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
  140. self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
  141. self.norm = nn.LayerNorm(d_model)
  142. def forward(self, x, mask):
  143. x = self.embedding(x)
  144. x = self.pos_encoder(x)
  145. for layer in self.layers:
  146. x = layer(x, mask)
  147. return self.norm(x)
  148. class Decoder(nn.Module):
  149. def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len):
  150. super(Decoder, self).__init__()
  151. self.embedding = nn.Embedding(vocab_size, d_model)
  152. self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)
  153. self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
  154. self.norm = nn.LayerNorm(d_model)
  155. def forward(self, x, encoder_output, src_mask, tgt_mask):
  156. x = self.embedding(x)
  157. x = self.pos_encoder(x)
  158. for layer in self.layers:
  159. x = layer(x, encoder_output, src_mask, tgt_mask)
  160. return self.norm(x)
  161. class Transformer(nn.Module):
  162. def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len=5000):
  163. super(Transformer, self).__init__()
  164. self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
  165. self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
  166. self.final_linear = nn.Linear(d_model, tgt_vocab_size)
  167. def generate_mask(self, src, tgt):
  168. # src_mask: (batch_size, 1, 1, src_len)
  169. src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
  170. # tgt_mask: (batch_size, 1, tgt_len, tgt_len)
  171. tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2) # (batch_size, 1, 1, tgt_len)
  172. tgt_len = tgt.size(1)
  173. # 下三角矩阵,用于防止看到未来的 token
  174. tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=src.device)).bool() # (tgt_len, tgt_len)
  175. tgt_mask = tgt_pad_mask & tgt_sub_mask
  176. return src_mask, tgt_mask
  177. def forward(self, src, tgt):
  178. src_mask, tgt_mask = self.generate_mask(src, tgt)
  179. encoder_output = self.encoder(src, src_mask)
  180. decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
  181. output = self.final_linear(decoder_output)
  182. return output
  183. # --- 演示如何使用模型 ---
  184. if __name__ == "__main__":
  185. # 1. 定义超参数
  186. src_vocab_size = 5000
  187. tgt_vocab_size = 5000
  188. d_model = 512
  189. num_layers = 6
  190. num_heads = 8
  191. d_ff = 2048
  192. dropout = 0.1
  193. max_len = 100
  194. # 2. 实例化模型
  195. model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_len)
  196. # 3. 创建模拟输入数据
  197. # 假设 batch_size=2, src_seq_len=10, tgt_seq_len=12
  198. src = torch.randint(1, src_vocab_size, (2, 10)) # (batch_size, seq_length)
  199. tgt = torch.randint(1, tgt_vocab_size, (2, 12)) # (batch_size, seq_length)
  200. # 4. 模型前向传播
  201. output = model(src, tgt)
  202. # 5. 打印输出形状
  203. print("模型输出的形状:", output.shape)
  204. # 预期输出: torch.Size([2, 12, 5000]) -> (batch_size, tgt_seq_len, tgt_vocab_size)