seqLogprobs = Variable(fc_feats.data.new(batch_size, self.seq_length).zero_())
for t in range(self.seq_length + 1):
if t == 0: // input <bos>
it = fc_feats.data.new(batch_size).long().zero_()
elif sample_max:
sampleLogprobs, it = torch.max(logprobs.data, 1)
it = it.view(-1).long()
else:
if temperature == 1.0:
prob_prev = torch.exp(logprobs.data) // fetch prev distribution: shape Nx(M+1)
else:
// scale logprobs by temperature
prob_prev = torch.exp(torch.div(logprobs.data, temperature))
it = torch.multinomial(prob_prev, 1)
sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) // gather the logprobs at sampled positions
it = it.view(-1).long() // and flatten indices for downstream processing
if t >= 1:
// stop when all finished
if t == 1:
unfinished = it > 0
else:
unfinished = unfinished * (it > 0)
if unfinished.sum() == 0:
break
it = it * unfinished.type_as(it)
seq[:,t-1] = it
// seq.append(it) //seq[t] the input of t+2 time step
// seqLogprobs.append(sampleLogprobs.view(-1))
seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
it = Variable(it)
logprobs, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state)
if decoding_constraint and t > 0:
tmp = output.data.new(output.size(0), self.vocab_size + 1).zero_()
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float("-inf"))
After Change
seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
for t in range(self.seq_length + 1):
if t == 0: // input <bos>
it = fc_feats.new_zeros(batch_size, dtype=torch.long)
elif sample_max:
sampleLogprobs, it = torch.max(logprobs.data, 1)
it = it.view(-1).long()
else: