if name != "":
Reparameterization.apply(module, name, dim, reparameterization, hook_child)
else:
names = list(module.state_dict().keys())
for name in names:
apply_reparameterization(module, reparameterization, name, dim, hook_child)
return module
After Change
if name2use != "":
Reparameterization.apply(module, name, dim, reparameterization, hook_child)
else:
names = [n for n,_ in module2use.named_parameters()]
if name2use != "":
names = [name2use+"."+n for n in names]
if name != "":
names = [name+"."+n for n in names]
for name in names:
apply_reparameterization(module, reparameterization, name, dim, hook_child)
return module