            if alive_attn is not None else None)
        for i in range(is_finished.size(0)):
            b = batch_offset[i]
            if end_condition[i]:
                is_finished[i].fill_(1)
            finished_hyp = is_finished[i].nonzero().view(-1)
            # Store finished hypotheses for this batch.
After Change
memory_bank = tile(memory_bank, beam_size, dim=1)
memory_lengths = tile(src_lengths, beam_size)
top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
batch_offset = torch.arange(batch_size, dtype=torch.long)
beam_offset = torch.arange(
    0,
    batch_size * beam_size,
    step=beam_size,
    dtype=torch.long,
    device=memory_bank.device)
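# beam_offset marks where each sentence's beam group starts in the flattened
# (batch_size * beam_size) layout: row i * beam_size + k is beam k of sentence i.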
alive_seq = torch.full(
    [batch_size * beam_size, 1],
    start_token,
    dtype=torch.long,
    device=memory_bank.device)
alive_attn = None
# Give full probability to the first beam on the first step.
topk_log_probs = (
    torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                 device=memory_bank.device).repeat(batch_size))
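# Only the first beam's scores are finite, so the first top-k picks the
# beam_size best first words instead of beam_size copies of the same word.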
# Structure that holds finished hypotheses.
hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

results = {}
results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
results["attention"] = [[] for _ in range(batch_size)]  # noqa: F812
results["gold_score"] = [0] * batch_size
results["batch"] = batch
for step in range(max_length):
    decoder_input = alive_seq[:, -1].view(1, -1, 1)

    # Decoder forward.
    dec_out, dec_states, attn = self.model.decoder(
        decoder_input,
        memory_bank,
        dec_states,
        memory_lengths=memory_lengths,
        step=step)
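    # decoder_input holds only the last predicted token of every live beam,
    # with a leading step dimension of 1, so dec_out covers a single step.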
    # Generator forward.
    log_probs = self.model.generator.forward(dec_out.squeeze(0))
    vocab_size = log_probs.size(-1)

    if step < min_length:
        log_probs[:, end_token] = -1e20
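    # Pushing the end token's score to -1e20 keeps EOS from being chosen
    # until at least min_length tokens have been generated.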
    # Multiply probs by the beam probability (i.e. add the running log probs).
    log_probs += topk_log_probs.view(-1).unsqueeze(1)

    alpha = self.global_scorer.alpha
    length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha
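    # This is the GNMT length penalty from Wu et al. (2016):
    # lp = ((5 + len) / 6) ** alpha, with len = step + 1 tokens so far.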
    # Flatten probs into a list of possibilities.
    curr_scores = log_probs / length_penalty
    curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
    topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

    # Recover log probs.
    topk_log_probs = topk_scores * length_penalty
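    # length_penalty is the same for every candidate at this step, so it does
    # not change the top-k winners here; it normalizes topk_scores for comparing
    # hypotheses of different lengths, while multiplying it back keeps
    # topk_log_probs as the raw cumulative log probs for the next step.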
    # Resolve beam origin and true word ids.
    topk_beam_index = topk_ids.div(vocab_size)
    topk_ids = topk_ids.fmod(vocab_size)

    # Map beam_index to batch_index in the flat representation.
    batch_index = (
        topk_beam_index
        + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
    select_indices = batch_index.view(-1)
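    # select_indices maps every surviving beam back to its parent row in the
    # flat (batch * beam) layout; it is used below to reorder per-beam tensors.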
    # Append last prediction.
    alive_seq = torch.cat(
        [alive_seq.index_select(0, select_indices),
         topk_ids.view(-1, 1)], -1)
    if return_attention:
        current_attn = attn["std"].index_select(1, select_indices)
        if alive_attn is None:
            alive_attn = current_attn
        else:
            alive_attn = alive_attn.index_select(1, select_indices)
            alive_attn = torch.cat([alive_attn, current_attn], 0)

    is_finished = topk_ids.to("cpu").eq(end_token)
    if step + 1 == max_length:
        is_finished.fill_(1)
    top_beam_finished |= is_finished[:, 0].eq(1)
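    # top_beam_finished flags sentences whose best-ranked beam has produced
    # EOS; a sentence only stops once this is set and n_best hypotheses exist
    # (see the end condition below).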
    # Save finished hypotheses.
    if is_finished.any():
        predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
        attention = (
            alive_attn.view(
                alive_attn.size(0), -1, beam_size, alive_attn.size(-1))
            if alive_attn is not None else None)
        non_finished_batch = []
        for i in range(is_finished.size(0)):
            b = batch_offset[i]
            finished_hyp = is_finished[i].nonzero().view(-1)
            # Store finished hypotheses for this batch.
            for j in finished_hyp:
                hypotheses[b].append((
                    topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start_token.
                    attention[:, i, j, :memory_lengths[i]]
                    if attention is not None else None))
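            # topk_scores are already length-normalized, so the sort below
            # ranks hypotheses of different lengths on a comparable scale.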
            # End condition: the top beam is finished and we can return
            # n_best hypotheses.
            if top_beam_finished[i] and len(hypotheses[b]) >= n_best:
                best_hyp = sorted(
                    hypotheses[b], key=lambda x: x[0], reverse=True)
                for n, (score, pred, attn) in enumerate(best_hyp):
                    if n >= n_best:
                        break
                    results["scores"][b].append(score)
                    results["predictions"][b].append(pred)
                    results["attention"][b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)
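        # non_finished_batch lists the in-batch positions of sentences that
        # are not done; the tensors below are pruned down to just those rows.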
        non_finished = torch.tensor(non_finished_batch)
        # If all sentences are translated, no need to go further.
        if len(non_finished) == 0:
            break
        # Remove finished batches for the next step.
        top_beam_finished = top_beam_finished.index_select(
            0, non_finished)
        batch_offset = batch_offset.index_select(0, non_finished)
        non_finished = non_finished.to(topk_ids.device)
        topk_log_probs = topk_log_probs.index_select(0, non_finished)
        batch_index = batch_index.index_select(0, non_finished)