# Merge the detections kept for this image into a single tensor.
# torch.cat joins along dim 0, so each detection remains its own row.
# NOTE(review): assumes max_detections is a non-empty list of 2-D tensors
# with matching column counts — confirm against the code that builds it.
max_detections = torch.cat(max_detections)
# Build an (N, 1) column holding this image's index, one entry per detection.
# new_full keeps the dtype/device of max_detections, exactly like the legacy
# `.new(size).fill_(value)` idiom it replaces.
image_index = max_detections.new_full((max_detections.size(0), 1), image_i)
# Prepend the image index to every detection row so each row records which
# image in the batch it came from.
max_detections = torch.cat((image_index, max_detections), 1)
# Add max detections to outputs