""" Get an attention mask to avoid using the subsequent info."""
assert seq.dim() == 2
attn_shape = (seq.size(0), seq.size(1), seq.size(1))
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8")
subsequent_mask = torch.from_numpy(subsequent_mask)
if seq.is_cuda:
subsequent_mask = subsequent_mask.cuda()
return subsequent_mask
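
# Illustrative usage (a sketch, not part of the original file; it assumes the
# enclosing function above is named `get_subsequent_mask` and that torch is
# imported at module level): the returned mask is 1 at future positions, so
# those attention scores are filled with -inf before the softmax.
if __name__ == "__main__":
    _seq = torch.randint(1, 100, (2, 5))        # (batch, seq_len) token ids
    _mask = get_subsequent_mask(_seq)           # (2, 5, 5), 1 above the diagonal
    _scores = torch.zeros(2, 5, 5)              # dummy attention scores
    _scores = _scores.masked_fill(_mask.bool(), float("-inf"))
    print(torch.softmax(_scores, dim=-1))       # lower-triangular attention weights
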
class Encoder(nn.Module):
""" A encoder model with self attention mechanism. """