// and context vector
rnn_output = rnn_output.squeeze(0) // S(=1) x B x I -> B x I
context = context.squeeze(1) // B x S(=1) x I -> B x I
output = F.log_softmax(self.out(torch.cat((rnn_output, context), 1)))
// Return final output, hidden state, and attention weights (for
// visualization)
After Change
// and context vector
rnn_output = rnn_output.squeeze(0) // S(=1) x B x I -> B x I
context = context.squeeze(1) // B x S(=1) x I -> B x I
output = self.out(torch.cat((rnn_output, context), 1))
// Return the raw output scores (no log_softmax is applied here), the
// context vector, the hidden state, and the attention weights (for
// visualization)
return output, context, hidden, attn_weights