if tags is None:
tags = r_arg.tags.copy()
else:
tags.extend([tag for tag in r_arg.tags if tag not in tags])
(
upcasted_args,
upcasted_kwargs,
) = lib.python.util.upcast_args_and_kwargs(resolved_args, resolved_kwargs)
if self.is_static:
result = method(*upcasted_args, **upcasted_kwargs)
else:
// In Opacus the step method in torch gets monkey-patched on .attach,
// which means we can't use the original AST method reference and need to
// get it again from the actual object, so for now let's allow the following
// two methods to be resolved at execution time.
method_name = self.path.split(".")[-1]
if method_name in ["step", "zero_grad"]:
// TODO: Remove this Opacus workaround
try:
method = getattr(resolved_self.data, method_name, None)
result = method(*upcasted_args, **upcasted_kwargs)
except Exception as e:
critical(
f"Unable to resolve method {self.path} on {resolved_self}. {e}"
)
result = method(
resolved_self.data, *upcasted_args, **upcasted_kwargs
)
else:
result = method(resolved_self.data, *upcasted_args, **upcasted_kwargs)
// TODO: replace with proper tuple support
if type(result) is tuple:
// convert to list until we support tuples
result = list(result)
result = method(resolved_self.data, *upcasted_args, **upcasted_kwargs)
if lib.python.primitive_factory.isprimitive(value=result):
// Wrap in a SyPrimitive
result = lib.python.primitive_factory.PrimitiveFactory.generate_primitive(
value=result, id=self.id_at_location
)
else:
// TODO: overload all methods to incorporate this automatically
if hasattr(result, "id"):
try:
if hasattr(result, "_id"):
// set the underlying id
result._id = self.id_at_location
else:
result.id = self.id_at_location
assert result.id == self.id_at_location
except AttributeError as e:
err = f"Unable to set id on result {type(result)}. {e}"
traceback_and_raise(Exception(err))
if mutating_internal:
resolved_self.read_permissions = result_read_permissions
if not isinstance(result, StorableObject):
result = StorableObject(
id=self.id_at_location,
data=result,
read_permissions=result_read_permissions,
)
if tags is not None:
result.tags = tags
result.tags.append(self.path.split(".")[-1])
node.store[self.id_at_location] = result
After Change
tag_args.append(r_arg)
resolved_kwargs = {}
tag_kwargs = {}
for arg_name, arg in self.kwargs.items():
r_arg = node.store[arg.id_at_location]
result_read_permissions = self.intersect_keys(
result_read_permissions, r_arg.read_permissions