// Get index of image
image_index = max_detections.new(max_detections.size(0), 1).fill_(image_i)
// Repeat the batch_id for as many detections of the class cls in the image
max_detections = torch.cat((image_index, max_detections), 1)
// Add max detections to outputs
output = max_detections if output is None else torch.cat((output, max_detections))
return output
After Change
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
output = [None for _ in range(len(prediction))]
for image_i, image_pred in enumerate(prediction):
// Filter out confidence scores below threshold
conf_mask = (image_pred[:, 4] >= conf_thres).squeeze()
image_pred = image_pred[conf_mask]