for layer in self.layers:
if self.cell_type == "lstm":
hidden, final_memory_state, final_carry_state = \
layer(
hidden,
initial_state=initial_state,
training=training
)
initial_state = [final_memory_state, final_carry_state]
else:
hidden, initial_state = layer(
hidden,
After Change
for layer in self.layers:
outputs = layer(hidden, training=training)
hidden = outputs[0]
return hidden, outputs[1:]
def get_cell_fun(cell_type):
if cell_type == "rnn":
cell_fn = tf.nn.rnn_cell.BasicRNNCell // todo tf2: do we eventually need tf2.keras.layers.SimpleRNNCell