def _restore_embed(embed_var, var_to_shape_map, reader):
    """Initialise ``embed_var`` from a transposed dense kernel in a checkpoint.

    If any checkpoint variable name contains "EmbeddingMatrix", nothing is
    done: the embedding is assumed to be stored under the same name and
    restored by the normal name-based path. Otherwise, the first variable
    whose name ends in "dense/kernel" and whose shape is the reverse of
    ``embed_var``'s shape is read from ``reader``, transposed, and assigned
    to ``embed_var``.

    Args:
        embed_var: Variable to restore. It must expose ``shape``, ``name``
            and ``assign`` (e.g. a ``tf.Variable``).
        var_to_shape_map: Mapping from checkpoint variable name to its shape.
            Presumably this is the output of a checkpoint reader's
            ``get_variable_to_shape_map()``; confirm against the caller.
        reader: Object whose ``get_tensor(name)`` returns an array that
            supports ``.T``.

    Returns:
        ``(assign_result, True)`` when a matching kernel was assigned,
        otherwise ``(None, False)``.
    """
    if any("EmbeddingMatrix" in name for name in var_to_shape_map):
        # Assume same name: the embedding is restored directly elsewhere.
        return None, False

    # A full transpose reverses the axis order, so the kernel we want has
    # the embedding's shape reversed. Computing this once avoids creating
    # a tf.transpose op for every candidate variable.
    transposed_shape = list(embed_var.shape)[::-1]
    for name in var_to_shape_map:
        if name.endswith("dense/kernel") and list(var_to_shape_map[name]) == transposed_shape:
            print("Assigning", name, "to", embed_var.name)
            return embed_var.assign(reader.get_tensor(name).T), True

    # No compatible kernel found. Return the same 2-tuple shape as the
    # other paths instead of an implicit None that callers cannot unpack.
    return None, False