# Start from an empty list of losses, then visit every (key, val) entry of
# loss_dictionary to resolve a loss class and its settings.
# NOTE(review): indentation is missing in this excerpt, so the nesting
# described in the comments below is inferred — confirm against the file.
self.losses = []
for key, val in loss_dictionary.items():
# An explicit "class" entry in val names the attribute to fetch from `loss`.
if "class" in val:
loss_class = getattr(loss, val["class"])
else:
# Without a "class" entry, the dictionary key itself is the attribute name.
loss_class = getattr(loss, key)
# Optional settings: each one falls back to a default when absent from val.
weight = 1 if "weight" not in val else val["weight"]
# The fallback for "keys" is the DEFAULT_KEYS attribute of the resolved class.
keys = loss_class.DEFAULT_KEYS if "keys" not in val else val["keys"]
# args/kwargs are taken from val without copying, so they are the very
# objects stored in loss_dictionary.
args = [] if "args" not in val else val["args"]
kwargs = {} if "kwargs" not in val else val["kwargs"]
After Change
# NOTE(review): `_loss_name` is assigned outside this excerpt — presumably
# from val["class"] or the dictionary key; confirm against the full file.
loss_class = getattr(loss, _loss_name)
# Optional settings: each one falls back to a default when absent from val.
weight = 1 if "weight" not in val else val["weight"]
keys = loss_class.DEFAULT_KEYS if "keys" not in val else val["keys"]
# Deep copies are taken so the in-place change to args below does not reach
# the objects held in val.
args = [] if "args" not in val else copy.deepcopy(val["args"])
kwargs = {} if "kwargs" not in val else copy.deepcopy(val["kwargs"])
# For these two loss names, args[0] is used as an attribute name on `loss`:
# the attribute is fetched, called with no arguments, and the result replaces
# the name in args.
if _loss_name in ["CombinationInvariantLoss", "PermutationInvariantLoss"]:
# NOTE(review): args[0] is indexed unconditionally; with the empty-list
# default this raises IndexError — confirm these entries always supply "args".
args[0] = getattr(loss, args[0])()