// for each seq-batch-feature combination
x = [x[i:i + self.context + 1] for i in range(seq_len)] // TxLxNxH - sequence, context, batch, feature
x = torch.stack(x)
x = x.permute(0, 2, 3, 1) // TxNxHxL - sequence, batch, feature, context
x = torch.mul(x, self.weight).sum(dim=3)
return x
After Change
groups=self.n_features, padding=0, bias=None)
def forward(self, x):
x = x.transpose(0, 1).transpose(1, 2)
x = F.pad(x, pad=self.pad, value=0)
x = self.conv(x)
x = x.transpose(1, 2).transpose(0, 1).contiguous()