            collective_fn,
            postprocess_fn=postprocess_fn)

    def reducescatter(self,
                      tensor,
                      tensor_list,
                      reducescatter_options=ReduceScatterOptions()):
        """Reduce-scatter a list of tensors across the group.

        Every participating rank passes a tensor_list of identical shapes;
        the lists are reduced elementwise across ranks, and rank i receives
        the reduced tensor_list[i] in its output tensor.

        Args:
            tensor: the output tensor (holds this rank's reduced chunk;
                could be unspecified).
            tensor_list: the list of tensors to be reduced and then scattered.
            reducescatter_options: reduce-scatter options.

        Returns:
            None
        """

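        # collective_fn issues one NCCL ReduceScatter call: the flattened
        # send buffer is reduced elementwise across ranks, and each rank
        # receives only its own contiguous chunk, written into output_tensor.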
        def collective_fn(input_tensor, output_tensor, comm, stream):
            comm.reduceScatter(
                nccl_util.get_tensor_ptr(input_tensor),
                nccl_util.get_tensor_ptr(output_tensor),
                nccl_util.get_tensor_n_elements(output_tensor),
                nccl_util.get_nccl_tensor_dtype(output_tensor),
                nccl_util.get_nccl_reduce_op(reducescatter_options.reduceOp),
                stream.ptr)

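        # Validate that tensor and tensor_list are shape- and dtype-compatible
        # for a scatter/gather-style collective, then set up one contiguous
        # send buffer that holds every entry of tensor_list back to back;
        # copy=False defers filling it until preprocess_fn runs.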
        _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
        flattened_input_tensor = _flatten_for_scatter_gather(
            tensor_list, copy=False)

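        # Stage each input tensor into its slot of the flattened send buffer
        # just before the collective is launched on the stream.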
        def preprocess_fn(stream):
            for i, input_tensor in enumerate(tensor_list):
                nccl_util.copy_tensor(flattened_input_tensor[i], input_tensor)

        self._collective(
            flattened_input_tensor,
            tensor,
            collective_fn,
            preprocess_fn=preprocess_fn)

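        # Usage sketch (illustrative only; `group` is assumed to be an
        # initialized instance of this collective group class, `cp` is cupy,
        # there are two ranks, and the reduce op is assumed to default to SUM):
        #
        #   tensor_list = [cp.ones(4, dtype=cp.float32),
        #                  cp.ones(4, dtype=cp.float32) * 2]
        #   output = cp.zeros(4, dtype=cp.float32)
        #   group.reducescatter(output, tensor_list)
        #   # If both ranks pass the same tensor_list, rank 0's output becomes
        #   # [2, 2, 2, 2] and rank 1's output becomes [4, 4, 4, 4].
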
    def send(self, tensor, dst_rank):
        """Send a tensor to a destination process in the group.