A tensor of the same shape as `predictions`.
predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
device = gold_labels.device
// Some sanity checks.
if gold_labels.size() != predictions.size():
raise ValueError(
f"gold_labels must have shape == predictions.size() but "
f"found tensor of shape: {gold_labels.size()}"
)
if mask is not None and mask.size() != predictions.size():
raise ValueError(
f"mask must have shape == predictions.size() but "
f"found tensor of shape: {mask.size()}"
)
batch_size = predictions.size(0)
if mask is not None:
// We can multiply by the mask up front, because we're just checking equality below, and
// this way everything that's masked will be equal.
predictions = predictions * mask
gold_labels = gold_labels * mask
// We want to skip predictions that are completely masked;
// so we'll keep predictions that aren't.
keep = mask.view(batch_size, -1).max(dim=1)[0]
else:
keep = torch.ones(batch_size, device=predictions.device).bool()
predictions = predictions.view(batch_size, -1)
gold_labels = gold_labels.view(batch_size, -1)
// At this point, predictions is (batch_size, rest_of_dims_combined),
// so .eq -> .prod will be 1 if every element of the instance prediction is correct
// and 0 if at least one element of the instance prediction is wrong.
// Because of how we're handling masking, masked positions are automatically "correct".
correct = predictions.eq(gold_labels).prod(dim=1).float()
// Since masked positions are correct, we need to explicitly exclude instance predictions
// where the entire prediction is masked (because they look "correct").
self._correct_count += (correct * keep).sum()
self._total_count += keep.sum()
if is_distributed():
_correct_count = torch.tensor(self._correct_count).to(device)
_total_count = torch.tensor(self._total_count).to(device)
dist.all_reduce(_correct_count, op=dist.ReduceOp.SUM)
dist.all_reduce(_total_count, op=dist.ReduceOp.SUM)
self._correct_count = _correct_count.item()
self._total_count = _total_count.item()
def get_metric(self, reset: bool = False):
// Returns