# Map the source token indices back to strings via the "src" field's
# vocabulary, stopping at the first padding token; any tokens after it
# are ignored.
words = []
for f in src_sent:
    word = translator.fields["src"].vocab.itos[f]
    if word != onmt.IO.PAD_WORD:
        words.append(word)
    else:
        # First pad reached: the remaining entries of src_sent are skipped.
        break
# Echo the recovered sentence, tagged with the running count, as UTF-8 bytes
# written directly to file descriptor 1 (this goes around sys.stdout's
# Python-level buffer).
os.write(1, ("\nSENT %d: %s\n" % (count, " ".join(words))).encode("UTF-8"))
After Change
(sent.squeeze(1) for sent in src.split(1, dim=1)))
# Walk the batch one sentence at a time; each element of z_batch unpacks into
# (pred_sents, gold_sent, pred_score, gold_score, src_sent).
# NOTE(review): gold_sent, pred_score, gold_score and src_sent are not used in
# the lines visible here; the loop body likely continues beyond this excerpt
# — confirm before restructuring.
for pred_sents, gold_sent, pred_score, gold_score, src_sent in z_batch:
# Keep at most opt.n_best hypotheses and render each one as a space-joined
# string (assumes each pred is a sequence of token strings — TODO confirm).
n_best_preds = [" ".join(pred) for pred in pred_sents[:opt.n_best]]
# Advance the running sentence counter.
count += 1
# Write the hypotheses one per line, then terminate the last line.
out_file.write("\n".join(n_best_preds))
out_file.write("\n")