// Keep only the pieces of the state tensors corresponding to the// ancestors created this iteration.for key, state_tensor in state.items():
if state_tensor is None:
continue _, *last_dims = state_tensor.size()
// shape: (batch_size, beam_size, *)
expanded_backpointer = backpointer.view(
batch_size, self.beam_size, *([1] * len(last_dims))