# Batch-translate the source file: buffer tokenised lines until a full batch
# is collected, translate it, add to the running score / word totals, and
# write the top hypothesis of every sentence to outF.
#
# A trailing batch shorter than opt.batch_size is never translated by the
# code shown here, because the loop simply ends when the source file does.
# NOTE(review): srcBatch / tgtBatch are not cleared anywhere in this excerpt;
# presumably they are reset after each translated batch further down - confirm.
count = 0
tgtF = open(opt.tgt) if opt.tgt else None
for line in open(opt.src):
    srcTokens = line.split()
    srcBatch += [srcTokens]
    if tgtF:
        # Pull the reference line that sits at the same position as `line`.
        tgtTokens = tgtF.readline().split()
        tgtBatch += [tgtTokens]

    # Not enough sentences buffered yet: go read the next line.
    if len(srcBatch) < opt.batch_size:
        continue

    predBatch, predScore, goldScore = translator.translate(srcBatch, tgtBatch)

    # Only the first (index 0) hypothesis of each sentence is counted.
    predScoreTotal += sum(scores[0] for scores in predScore)
    predWordsTotal += sum(len(hyps[0]) for hyps in predBatch)
    if tgtF is not None:
        goldScoreTotal += sum(goldScore)
        goldWordsTotal += sum(len(ref) for ref in tgtBatch)

    for b, hyps in enumerate(predBatch):
        count += 1
        predSent = " ".join(hyps[0])
        outF.write(predSent + "\n")

        if not opt.verbose:
            continue

        srcSent = " ".join(srcBatch[b])
        if translator.tgt_dict.lower:
            srcSent = srcSent.lower()
        print("SENT %d: %s" % (count, srcSent))
        print("PRED %d: %s" % (count, predSent))
        print("PRED SCORE: %.4f" % predScore[b][0])
After Change
# Batch-translate the source file: buffer tokenised lines until a full batch
# is collected, translate it, add to the running score / word totals, and
# write the top hypothesis of every sentence to outF.
#
# The loop treats a `None` item as the end-of-file marker and uses it to
# translate whatever is still buffered, so a final short batch is not lost.
# NOTE(review): addone() is defined outside this excerpt; it is assumed to
# yield every line of the file followed by a single None - confirm.
# NOTE(review): srcBatch / tgtBatch are not cleared anywhere in this excerpt;
# presumably they are reset after each translated batch further down - confirm.
count = 0
tgtF = open(opt.tgt) if opt.tgt else None
for line in addone(open(opt.src)):
    if line is not None:
        srcTokens = line.split()
        srcBatch += [srcTokens]
        if tgtF:
            # Pull the reference line that sits at the same position as `line`.
            tgtTokens = tgtF.readline().split()
            tgtBatch += [tgtTokens]

        # Not enough sentences buffered yet: go read the next line.
        if len(srcBatch) < opt.batch_size:
            continue
    else:
        # at the end of file, check last batch
        if not srcBatch:
            break

    predBatch, predScore, goldScore = translator.translate(srcBatch, tgtBatch)

    # Only the first (index 0) hypothesis of each sentence is counted.
    predScoreTotal += sum(score[0] for score in predScore)
    predWordsTotal += sum(len(x[0]) for x in predBatch)
    if tgtF is not None:
        goldScoreTotal += sum(goldScore)
        goldWordsTotal += sum(len(x) for x in tgtBatch)

    for b, hyps in enumerate(predBatch):
        count += 1
        predSent = " ".join(hyps[0])
        outF.write(predSent + "\n")
        # Push each translation out immediately instead of waiting for the
        # file buffer to fill.
        outF.flush()

        if opt.verbose:
            srcSent = " ".join(srcBatch[b])
            if translator.tgt_dict.lower:
                srcSent = srcSent.lower()
            print("SENT %d: %s" % (count, srcSent))
            print("PRED %d: %s" % (count, predSent))
            print("PRED SCORE: %.4f" % predScore[b][0])