# Note that we _cannot_ use a reshape here, because this tensor was created
# with num_heads being the first dimension, so reshaping naively would not
# throw an error, but give an incorrect result.
# Cut dim 0 into consecutive chunks of `batch_size` rows (presumably one chunk
# per attention head -- confirm against how `outputs` was built), then place
# those chunks side by side along the last dimension.
head_chunks = torch.split(outputs, batch_size, dim=0)
outputs = torch.cat(head_chunks, dim=-1)
# Project back to original input size.
# shape (batch_size, timesteps, input_size)