# Seed the decoder with its initial output, and precompute a PAD mask of
# shape (beam x batch x srcLen) so attention can ignore padded positions.
decOut = self.model.make_init_decoder_output(context)
srcIsPad = srcBatch.data.eq(onmt.Constants.PAD)
padMask = srcIsPad.t().unsqueeze(0).repeat(beamSize, 1, 1)

# Map each batch position to its slot in the compacted tensors; every
# sentence starts out active.
batchIdx = list(range(batchSize))
remainingSents = batchSize
# Beam-search decode loop: one decoder step per iteration, advancing every
# still-active beam and compacting tensors as sentences finish.
# NOTE(review): original indentation was lost in extraction; the nesting
# below is reconstructed from the code's control flow — confirm against the
# project's history.
for i in range(self.opt.max_sent_length):
    # Re-attach the attention mask (it must track the compacted padMask).
    self.model.decoder.apply(applyContextMask)

    # Prepare decoder input: current top words of every unfinished beam,
    # laid out as 1 x (beam*remainingSents).  ("inp" avoids shadowing the
    # builtin `input`.)
    inp = torch.stack([b.getCurrentState() for b in beam
                       if not b.done]).t().contiguous().view(1, -1)

    decOut, decStates, attn = self.model.decoder(
        Variable(inp, volatile=True), decStates, context, decOut)
    # decOut: 1 x (beam*batch) x numWords
    decOut = decOut.squeeze(0)
    out = self.model.generator.forward(decOut)

    # Reshape scores/attention to batch x beam x numWords for per-sentence
    # beam advancement.
    wordLk = out.view(beamSize, remainingSents, -1) \
                .transpose(0, 1).contiguous()
    attn = attn.view(beamSize, remainingSents, -1) \
               .transpose(0, 1).contiguous()

    active = []
    for b in range(batchSize):
        if beam[b].done:
            continue

        idx = batchIdx[b]
        if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
            # advance() returns True when the beam finishes; False means
            # this sentence still needs decoding next step.
            active += [b]

        # Reorder this sentence's slice of the decoder hidden state so each
        # beam row follows its parent hypothesis.
        for decState in decStates:  # iterate over h, c
            # layers x beam*sent x dim
            sentStates = decState.view(-1, beamSize,
                                       remainingSents,
                                       decState.size(2))[:, :, idx]
            sentStates.data.copy_(
                sentStates.data.index_select(
                    1, beam[b].getCurrentOrigin()))

    if not active:
        break

    # In this section, the sentences that are still active are compacted
    # so that the decoder is not run on completed sentences.
    activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
    # Remap each surviving batch index to its new (dense) position.
    batchIdx = {b: pos for pos, b in enumerate(active)}

    def updateActive(t):
        # Select only the remaining active sentences along the sentence
        # dimension; `rnnSize` comes from the enclosing scope.
        view = t.data.view(-1, remainingSents, rnnSize)
        newSize = list(t.size())
        newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
        return Variable(view.index_select(1, activeIdx)
                        .view(*newSize), volatile=True)

    decStates = (updateActive(decStates[0]),
                 updateActive(decStates[1]))
    decOut = updateActive(decOut)
    context = updateActive(context)
    padMask = padMask.index_select(1, activeIdx)

    remainingSents = len(active)
# (4) package everything up
# --- "After Change": extraction/diff marker. The lines that follow are the
# opening of a *different revision* of this function, not a continuation. ---
# Infer the batch size from the first element of the source batch: a 2-D
# tensor is (srcLen x batch), so the batch dimension is dim 1; otherwise
# the batch dimension comes first.
if srcBatch[0].dim() == 2:
    batchSize = srcBatch[0].size(1)
else:
    batchSize = srcBatch[0].size(0)
beamSize = self.opt.beam_size

# (1) run the encoder on the src
encStates, context = self.model.encoder(srcBatch)