y = layers.Input(shape=(n_class,))
masked = Mask()([digitcaps, y]) // The true label is used to mask the output of capsule layer.
x_recon = layers.Dense(512, activation="relu")(masked)
x_recon = layers.Dense(1024, activation="relu")(x_recon)
x_recon = layers.Dense(np.prod(input_shape), activation="sigmoid")(x_recon)
x_recon = layers.Reshape(target_shape=input_shape, name="out_recon")(x_recon)
// Two-input, two-output Keras model.
After Change
// Decoder network.
y = layers.Input(shape=(n_class,))
masked_by_y = Mask()([digitcaps, y]) // The true label is used to mask the output of capsule layer. For training
masked = Mask()(digitcaps) // Mask using the capsule with maximal length. For prediction
// Decoder model shared by the training and prediction models.
decoder = models.Sequential(name="decoder")
decoder.add(layers.Dense(512, activation="relu", input_dim=16*n_class))
decoder.add(layers.Dense(1024, activation="relu"))
decoder.add(layers.Dense(np.prod(input_shape), activation="sigmoid"))
decoder.add(layers.Reshape(target_shape=input_shape, name="out_recon"))
// Models for training and evaluation (prediction)
train_model = models.Model([x, y], [out_caps, decoder(masked_by_y)])
eval_model = models.Model(x, [out_caps, decoder(masked)])
return train_model, eval_model
def margin_loss(y_true, y_pred):