# Infer the class count when the caller did not supply one.
if num_classes is None:
    # outputs/targets are expected to hold integer class indices (int64),
    # so the number of classes is taken as the largest value found in
    # either tensor, plus one.
    num_classes = max(
        int(outputs.max().detach().item() + 1),
        int(targets.max().detach().item() + 1),
    )
# One zero-initialised slot per class, allocated on the same device as
# `outputs`. No dtype is passed, so torch's default dtype is used.
# NOTE(review): presumably true-negative / false-positive / false-negative
# accumulators -- confirm against the code that fills them.
tn = torch.zeros((num_classes,), device=outputs.device)
fp = torch.zeros((num_classes,), device=outputs.device)
fn = torch.zeros((num_classes,), device=outputs.device)