@parameterized.parameters((False, False), (False, True), (True, False),
                          (True, True))
def testInitialState(self, trainable, use_custom_initial_value):
    """Check `DeepRNN.initial_state` against the wrapped cores' own states.

    Two cases are covered: a DeepRNN built from non-recurrent cores must
    yield a falsy (empty) initial state, and a DeepRNN built from recurrent
    cores must yield the same values as calling `initial_state` on each
    core individually.

    Args:
      trainable: Forwarded as `trainable` to every `initial_state` call.
      use_custom_initial_value: When true, one constant initializer per
        recurrent core (values 8 and 9) is passed as
        `trainable_initializers`; otherwise `None` is passed.
    """
    batch_size = 3
    hidden1_size = 4
    hidden2_size = 5
    output1_size = 6
    output2_size = 7

    initializer = None
    if use_custom_initial_value:
        initializer = [tf.constant_initializer(8),
                       tf.constant_initializer(9)]

    # Test that the initial state of a non-recurrent DeepRNN is an empty list.
    non_recurrent_cores = [snt.Linear(output_size=output1_size),
                           snt.Linear(output_size=output2_size)]
    dummy_deep_rnn = snt.DeepRNN(non_recurrent_cores, skip_connections=False)
    dummy_initial_state = dummy_deep_rnn.initial_state(
        batch_size, trainable=trainable)
    self.assertFalse(dummy_initial_state)

    # Test that the initial state of a recurrent DeepRNN is the same as
    # calling all cores' initial_state method.
    cores = [snt.VanillaRNN(hidden_size=hidden1_size),
             snt.VanillaRNN(hidden_size=hidden2_size)]
    deep_rnn = snt.DeepRNN(cores)
    initial_state = deep_rnn.initial_state(
        batch_size, trainable=trainable, trainable_initializers=initializer)

    expected_initial_state = []
    for i, core in enumerate(cores):
        # NOTE(review): the "core-%d" scope presumably mirrors the scope
        # DeepRNN uses for each core's state variables -- confirm.
        with tf.variable_scope("core-%d" % i):
            expected_initializer = initializer[i] if initializer else None
            expected_initial_state.append(
                core.initial_state(
                    batch_size, trainable=trainable,
                    trainable_initializers=expected_initializer))

    self.evaluate(tf.global_variables_initializer())
    initial_state_value = self.evaluate(initial_state)
    expected_initial_state_value = self.evaluate(expected_initial_state)

    # zip() stops at the shorter sequence, so guard against a silently
    # truncated comparison when a per-core state is missing.
    self.assertEqual(len(initial_state_value),
                     len(expected_initial_state_value))
    for expected_value, actual_value in zip(expected_initial_state_value,
                                            initial_state_value):
        self.assertAllEqual(actual_value, expected_value)