# When fine-tuning the detector, the image_text_alignment is
# sometimes padded too long.
if position_embeddings_visual.size(1) != visual_embeddings.size(1):
assert position_embeddings_visual.size(1) >= visual_embeddings.size(1)
position_embeddings_visual = position_embeddings_visual[
:, : visual_embeddings.size(1), :
After Change
# We want to average along the alignment_number dimension.
image_text_alignment_mask = image_text_alignment_mask.sum(2)
image_text_alignment_mask[image_text_alignment_mask == 0] = torch.tensor(
[1], dtype=torch.long
) # Avoid divide-by-zero error
position_embeddings_visual = (
position_embeddings_visual / image_text_alignment_mask.unsqueeze(-1)