seq_len = input.size(0)
// pad the 0th dimension (T/sequence) with zeroes whose number = context
// Once PyTorch's padding functions have settled, we should switch to those.
padding = torch.zeros(self.context, *(input.size()[1:])).type_as(input)
x = torch.cat((input, padding), 0)
// add lookahead windows (with context+1 width) as a fourth dimension
// for each seq-batch-feature combination
x = [x[i:i + self.context + 1] for i in range(seq_len)] // TxLxNxH - sequence, context, batch, feature
x = torch.stack(x)
x = x.permute(0, 2, 3, 1) // TxNxHxL - sequence, batch, feature, context
After Change
x = x.transpose(0, 1).transpose(1, 2)
x = F.pad(x, pad=self.pad, value=0)
x = self.conv(x)
x = x.transpose(1, 2).transpose(0, 1).contiguous()
return x
def __repr__(self):
return self.__class__.__name__ + "(" \