decode = self.decoder.get_output(train)
if self.tie_weights:
encoder_params = self.encoder.get_weights()
decoder_params = self.decoder.get_weights()
for dec_param, enc_param in zip(decoder_params, encoder_params):
if len(dec_param.shape) > 1:
enc_param = dec_param.T
return decode
def get_config(self):
After Change
if not train and not self.output_reconstruction:
return self._get_hidden(train)
decode = self.decoders[-1].get_output(train)
if self.tie_weights:
for e,d in zip(self.encoders, self.decoders):