|
|
@@ -455,7 +455,8 @@ class PositionalEncoding(nn.Module):
|
|
|
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
|
|
|
|
|
# pe (positional encoding) 的大小为 (max_len, d_model)
|
|
|
- pe = torch.zeros(max_len, 1, d_model)
|
|
|
+ # pe = torch.zeros(max_len, 1, d_model)
|
|
|
+ pe = torch.zeros(1, max_len, d_model)
|
|
|
|
|
|
# 偶数维度使用 sin, 奇数维度使用 cos
|
|
|
pe[:, 0, 0::2] = torch.sin(position * div_term)
|