# Load the checkpoint stored at load_path, remapping it via map_location.
# If you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(load_path, map_location=str(self.device))
# NOTE: this delete is unconditional, so it raises AttributeError whenever
# the loaded object carries no _metadata attribute.
del state_dict._metadata
# Patch InstanceNorm checkpoints prior to 0.4.
for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
    self.__patch_instance_norm_state_dict(state_dict, net, key.split("."))
After Change
# Load the checkpoint stored at load_path, remapping it via map_location.
# If you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(load_path, map_location=str(self.device))
# Drop the _metadata attribute only when it is present, so objects that
# never had one do not trigger an AttributeError on the delete.
if hasattr(state_dict, "_metadata"):
    del state_dict._metadata
# Patch InstanceNorm checkpoints prior to 0.4.
for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
    self.__patch_instance_norm_state_dict(state_dict, net, key.split("."))
# Copy the (possibly patched) entries into the network.
net.load_state_dict(state_dict)