|
|
@@ -453,19 +453,19 @@ class PositionalEncoding(nn.Module):
|
|
|
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
|
|
|
|
|
# pe (positional encoding) size is (max_len, d_model)
|
|
|
- pe = torch.zeros(1, max_len, d_model)
|
|
|
+ pe = torch.zeros(max_len, d_model)
|
|
|
|
|
|
# Even dimensions use sin, odd dimensions use cos
|
|
|
- pe[:, 0, 0::2] = torch.sin(position * div_term)
|
|
|
- pe[:, 0, 1::2] = torch.cos(position * div_term)
|
|
|
+ pe[:, 0::2] = torch.sin(position * div_term)
|
|
|
+ pe[:, 1::2] = torch.cos(position * div_term)
|
|
|
|
|
|
# Register pe as buffer, so it won't be treated as model parameter but will move with the model (e.g., to(device))
|
|
|
- self.register_buffer('pe', pe)
|
|
|
+ self.register_buffer('pe', pe.unsqueeze(0))
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
- # x.size(0) is the current input sequence length
|
|
|
+ # x.size(1) is the current input sequence length
|
|
|
# Add positional encoding to input vector
|
|
|
- x = x + self.pe[:x.size(0)]
|
|
|
+ x = x + self.pe[:, :x.size(1)]
|
|
|
return self.dropout(x)
|
|
|
```
|
|
|
|