68f970aca1f95cddbc1c5fc4e4c7e6b6bffb2293,models/AttModel.py,AttModel,_sample,#AttModel#Any#Any#Any#Any#,177

Before Change


        seqLogprobs = Variable(fc_feats.data.new(batch_size, self.seq_length).zero_())
        for t in range(self.seq_length + 1):
            if t == 0: # input <bos>
                it = fc_feats.data.new(batch_size).long().zero_()
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data, temperature))
                it = torch.multinomial(prob_prev, 1)
                sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions
                it = it.view(-1).long() # and flatten indices for downstream processing

            if t >= 1:
                # stop when all sequences are finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it) # zero out words of finished sequences
                seq[:,t-1] = it
                # seq.append(it)  # seq[t] is the input at time step t+2

                # seqLogprobs.append(sampleLogprobs.view(-1))
                seqLogprobs[:,t-1] = sampleLogprobs.view(-1)

            it = Variable(it)
            logprobs, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state)
            if decoding_constraint and t > 0:
                # forbid repeating the previous word: scatter -inf into its vocab slot
                tmp = output.data.new(output.size(0), self.vocab_size + 1).zero_()
                tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float("-inf"))
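
For context, a minimal standalone sketch of the per-step sampling logic above, assuming modern PyTorch. Here logprobs, temperature, and sample_max are hypothetical stand-ins for the method's locals, and vocab index 0 plays the role of the <eos> token:

import torch

torch.manual_seed(0)
batch_size, vocab_size = 4, 10
logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=1)
temperature, sample_max = 0.8, False

if sample_max:
    # greedy decoding: pick the highest-scoring word per row
    sampleLogprobs, it = torch.max(logprobs, 1)
    it = it.view(-1).long()
else:
    # temperature < 1 sharpens the distribution, > 1 flattens it
    prob_prev = torch.exp(logprobs / temperature)
    it = torch.multinomial(prob_prev, 1)      # [batch, 1] sampled word ids
    sampleLogprobs = logprobs.gather(1, it)   # log-prob of each sampled word
    it = it.view(-1).long()                   # flatten to [batch]

# "unfinished" bookkeeping: once a row emits <eos> (id 0), it stays zeroed
unfinished = it > 0
it = it * unfinished.type_as(it)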

After Change


        seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
        for t in range(self.seq_length + 1):
            if t == 0: # input <bos>
                it = fc_feats.new_zeros(batch_size, dtype=torch.long)
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
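
The core of this pattern is the PyTorch 0.4 API migration: Variable(x.data.new(...).zero_()) chains become x.new_zeros(...), which inherits the device and dtype of the source tensor. A minimal sketch under that assumption (names are illustrative, not the repository's):

import torch

fc_feats = torch.randn(5, 2048)   # stand-in for the image feature batch
batch_size, seq_length = fc_feats.size(0), 16

# pre-0.4 idiom (Before Change), roughly:
#   seqLogprobs = Variable(fc_feats.data.new(batch_size, seq_length).zero_())
#   it = fc_feats.data.new(batch_size).long().zero_()

# 0.4+ idiom (After Change): new_zeros keeps fc_feats's device and dtype,
# and the dtype can be overridden per call
seqLogprobs = fc_feats.new_zeros(batch_size, seq_length)
it = fc_feats.new_zeros(batch_size, dtype=torch.long)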
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 8

Instances


Project Name: ruotianluo/self-critical.pytorch
Commit Name: 68f970aca1f95cddbc1c5fc4e4c7e6b6bffb2293
Time: 2018-04-26
Author: rluo@ttic.edu
File Name: models/AttModel.py
Class Name: AttModel
Method Name: _sample


Project Name: ruotianluo/self-critical.pytorch
Commit Name: 68f970aca1f95cddbc1c5fc4e4c7e6b6bffb2293
Time: 2018-04-26
Author: rluo@ttic.edu
File Name: models/AttModel.py
Class Name: AttModel
Method Name: _sample_beam