# NOTE(review): fragment of an augmentation-base forward-style method; the
# enclosing `def` is not visible in this chunk. Code unchanged, comments only.

# Fall back to the instance-level default when the caller did not pass
# `return_transform` explicitly.
if return_transform is None:
    return_transform = self.return_transform
# Lazily generate the per-call random parameters when none were supplied,
# via the private dunder helper with the probability/config knobs spelled out.
if params is None:
    params = self.__forward_parameters__(batch_shape, self.p, self.p_batch, self.same_on_batch)
# `batch_prob` marks which batch elements the augmentation applies to.
# When absent, default to "apply to every element" (batch_shape[0] presumably
# is the batch size — TODO confirm) and warn so callers notice the assumption.
if "batch_prob" not in params:
    params["batch_prob"] = torch.tensor([True] * batch_shape[0])
    warnings.warn("`batch_prob` is not found in params. Will assume applying on all data.")
# Cache the resolved params on the instance so they can be inspected later.
self._params = params
# Apply the transform. `apply_func` is defined elsewhere; presumably it returns
# either a tensor or a (tensor, transform) pair depending on `return_transform`
# — verify against the class definition.
output = self.apply_func(in_tensor, in_transform, self._params, return_transform)
# When `keepdim` is set, restore the caller's original input shape.
return _transform_output_shape(output, ori_shape) if self.keepdim else output
After change — parameter generation now goes through the public `forward_parameters(batch_shape)` API instead of the private `__forward_parameters__` helper; the rest of the logic is unchanged:
# NOTE(review): fragment of an augmentation-base forward-style method; the
# enclosing `def` is not visible in this chunk. Code unchanged, comments only.

# Fall back to the instance-level default when the caller did not pass
# `return_transform` explicitly.
if return_transform is None:
    return_transform = self.return_transform
# Lazily generate the per-call random parameters when none were supplied.
# Uses the public `forward_parameters` API, which presumably reads p/p_batch/
# same_on_batch from the instance itself — confirm against the class definition.
if params is None:
    params = self.forward_parameters(batch_shape)
# `batch_prob` marks which batch elements the augmentation applies to.
# When absent, default to "apply to every element" (batch_shape[0] presumably
# is the batch size — TODO confirm) and warn so callers notice the assumption.
if "batch_prob" not in params:
    params["batch_prob"] = torch.tensor([True] * batch_shape[0])
    warnings.warn("`batch_prob` is not found in params. Will assume applying on all data.")
# Cache the resolved params on the instance so they can be inspected later.
self._params = params
# Apply the transform. `apply_func` is defined elsewhere; presumably it returns
# either a tensor or a (tensor, transform) pair depending on `return_transform`
# — verify against the class definition.
output = self.apply_func(in_tensor, in_transform, self._params, return_transform)
# When `keepdim` is set, restore the caller's original input shape.
return _transform_output_shape(output, ori_shape) if self.keepdim else output