// Initialize variables
with self.graph.as_default():
self.session.run(tf.global_variables_initializer())
// Load saved checkpoint to populate trained parameters
with self.graph.as_default():
saver = tf.train.Saver(tf.global_variables())
After Change
warnings.warn("Unstable! Please extensively test this part of the code when time permits")
self.load_state_dict(
torch.load("{}/model.params".format(model_dir))
)
if verbose:
print("[{0}] Loaded model <{1}>".format(self.name, model_name))