step_size = shape["in_shape"][2] * shape["in_shape"][3]
for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) // pylint: disable=not-callable
module_masks.set_output_mask(output_cmask)
return output_cmask
After Change
step_size = shape["in_shape"][2] * shape["in_shape"][3]
for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) // pylint: disable=not-callable
module_masks.set_output_mask(output_cmask)
return output_cmask