# Build the final output when skip connections are enabled and the
# per-core outputs should be concatenated.
# NOTE(review): this looks like the pre-change version of the logic (an
# "After Change" marker follows this fragment) -- confirm before relying on it.
if self._skip_connections and self._concat_final_output_if_skip:
# Flatten every entry of `outputs` into a flat list of its leaves.
flat_outputs = tuple(nest.flatten(output) for output in outputs)
# zip(*...) groups the leaves that sit at the same position in each
# flattened output; each group is concatenated along axis 1.
flat_outputs = [tf.concat(output, 1) for output in zip(*flat_outputs)]
# Re-pack the concatenated leaves using the first output as the
# structural template.
output = nest.pack_sequence_as(structure=outputs[0],
flat_sequence=flat_outputs)
else:
# NOTE(review): the body of this else branch is not visible in this excerpt.
After Change
# Run the input through every core in the stack, collecting the next state of
# each recurrent core and (when skip connections are on) every core's output.
# NOTE(review): `inputs`, `prev_state` and the initial `current_input` are
# defined before this excerpt -- confirm `current_input` starts as `inputs`.
next_states = []
outputs = []
# Index into `prev_state`; it only advances for recurrent cores, so it can lag
# behind the core index `i` when the stack contains non-recurrent cores.
recurrent_idx = 0
# Joins its arguments along the last axis; used with `nest.map_structure` so
# nested inputs/outputs are concatenated leaf by leaf.
concatenate = lambda *args: tf.concat(args, axis=-1)
for i, core in enumerate(self._cores):
    # With skip connections, every core after the first also sees the
    # original `inputs`, concatenated with the previous core's output.
    if self._skip_connections and i > 0:
        current_input = nest.map_structure(concatenate, inputs, current_input)

    # Determine if this core in the stack is recurrent or not and call
    # accordingly.
    if self._is_recurrent_list[i]:
        current_input, next_state = core(current_input,
                                         prev_state[recurrent_idx])
        next_states.append(next_state)
        recurrent_idx += 1
    else:
        current_input = core(current_input)

    # Keep each core's output so they can be concatenated at the end.
    if self._skip_connections:
        outputs.append(current_input)

if self._skip_connections and self._concat_final_output_if_skip:
    # Final output is the leaf-wise concatenation of all core outputs.
    output = nest.map_structure(concatenate, *outputs)
else:
    # Otherwise the final output is just the last core's output.
    output = current_input

# Record the output shape as computed by the module-level helper.
self._last_output_size = _get_shape_without_batch_dimension(output)