# Add previous memory chunk (as const, w/o gradient) to input.
# Tau (number of (prev) time slices in each memory chunk).
Tau = list(memory.shape)[1] if memory is not None else 0
if memory is not None:
    memory.requires_grad_(False)
    inputs = torch.cat((memory, inputs), dim=1)
After Change
# Add previous memory chunk (as const, w/o gradient) to input.
# Tau (number of (prev) time slices in each memory chunk).
Tau = list(memory.shape)[1]
inputs = torch.cat((memory.detach(), inputs), dim=1)

# Apply the Layer-Norm.
if self._input_layernorm is not None:
    inputs = self._input_layernorm(inputs)
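
The sketch below is a minimal, self-contained illustration of the pattern after the change; the sizes, the stand-alone layernorm module, and the randomly generated tensors are placeholders, not the project's real model. It assumes memory is a non-leaf tensor produced by an earlier forward pass, in which case requires_grad_(False) raises a RuntimeError, while detach() cleanly treats the memory chunk as a constant.

import torch
import torch.nn as nn

# Placeholder sizes (batch, memory slices, time steps, feature dim).
B, Tau, T, D = 2, 4, 8, 16
layernorm = nn.LayerNorm(D)  # stands in for self._input_layernorm

inputs = torch.randn(B, T, D, requires_grad=True)
mem_param = torch.randn(B, Tau, D, requires_grad=True)  # leaf behind the memory
memory = mem_param * 2.0  # non-leaf, as memory from a previous forward pass would be

# Old approach: requires_grad_(False) only works on leaf tensors.
try:
    memory.requires_grad_(False)
except RuntimeError as err:
    print("requires_grad_(False) failed:", err)

# New approach: detach() returns a view cut off from the autograd graph,
# so the previous memory chunk is treated as a constant.
x = torch.cat((memory.detach(), inputs), dim=1)  # shape: (B, Tau + T, D)
x = layernorm(x)

x.sum().backward()
print(inputs.grad.shape)  # torch.Size([2, 8, 16]): gradients reach the new inputs
print(mem_param.grad)     # None: nothing flowed back into the memory chunk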