7a6e3b93fb4b97af7b06244b768b1fee4b547c17,ch12/train_scst.py,,,#,44

Before Change


    log.info("Model loaded from %s, continue training in RL mode...", args.load)

    # BEGIN token
    beg_token = Variable(torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]))
    if args.cuda:
        beg_token = beg_token.cuda()

    with ptan.common.utils.TBMeanTracker(writer, batch_size=100) as tb_tracker:
        optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)
        batch_idx = 0
        best_bleu = None
        for epoch in range(MAX_EPOCHES):
            random.shuffle(train_data)
            dial_shown = False

            total_samples = 0
            skipped_samples = 0
            bleus_argmax = []
            bleus_sample = []

            for batch in data.iterate_batches(train_data, BATCH_SIZE):
                batch_idx += 1
                optimiser.zero_grad()
                input_seq, input_batch, output_batch = model.pack_batch_no_out(batch, net.emb, cuda=args.cuda)
                enc = net.encode(input_seq)

                net_policies = []
                net_actions = []
                net_advantages = []
                beg_embedding = net.emb(beg_token)

                for idx, inp_idx in enumerate(input_batch):
                    total_samples += 1
                    ref_indices = [
                        indices[1:]
                        for indices in output_batch[idx]
                    ]
                    item_enc = net.get_encoded_item(enc, idx)
                    r_argmax, actions = net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS,
                                                                stop_at_token=end_token)
                    argmax_bleu = utils.calc_bleu_many(actions, ref_indices)
                    bleus_argmax.append(argmax_bleu)

                    if not args.disable_skip and argmax_bleu > 0.99:
                        skipped_samples += 1
                        continue

                    if not dial_shown:
                        log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict)))
                        ref_words = [utils.untokenize(data.decode_words(ref, rev_emb_dict)) for ref in ref_indices]
                        log.info("Refer: %s", " ~~|~~ ".join(ref_words))
                        log.info("Argmax: %s, bleu=%.4f", utils.untokenize(data.decode_words(actions, rev_emb_dict)),
                                 argmax_bleu)

                    for _ in range(args.samples):
                        r_sample, actions = net.decode_chain_sampling(item_enc, beg_embedding,
                                                                      data.MAX_TOKENS, stop_at_token=end_token)
                        sample_bleu = utils.calc_bleu_many(actions, ref_indices)

                        if not dial_shown:
                            log.info("Sample: %s, bleu=%.4f", utils.untokenize(data.decode_words(actions, rev_emb_dict)),
                                     sample_bleu)

                        net_policies.append(r_sample)
                        net_actions.extend(actions)
                        net_advantages.extend([sample_bleu - argmax_bleu] * len(actions))
                        bleus_sample.append(sample_bleu)
                    dial_shown = True

                if not net_policies:
                    continue

                policies_v = torch.cat(net_policies)
                actions_t = torch.LongTensor(net_actions)
                adv_v = Variable(torch.FloatTensor(net_advantages))
                if args.cuda:
                    actions_t = actions_t.cuda()
                    adv_v = adv_v.cuda()

                log_prob_v = F.log_softmax(policies_v, dim=1)
                log_prob_actions_v = adv_v * log_prob_v[range(len(net_actions)), actions_t]
                loss_policy_v = -log_prob_actions_v.mean()

After Change


    log.info("Model loaded from %s, continue training in RL mode...", args.load)

    # BEGIN token
    beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)

    with ptan.common.utils.TBMeanTracker(writer, batch_size=100) as tb_tracker:
        optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 17

Instances


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 7a6e3b93fb4b97af7b06244b768b1fee4b547c17
Time: 2018-04-29
Author: max.lapan@gmail.com
File Name: ch12/train_scst.py
Class Name:
Method Name:


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 7a6e3b93fb4b97af7b06244b768b1fee4b547c17
Time: 2018-04-29
Author: max.lapan@gmail.com
File Name: ch12/train_scst.py
Class Name:
Method Name:


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 1e9c3ee592be5e11dcce932a73009488d6f85474
Time: 2018-04-29
Author: max.lapan@gmail.com
File Name: ch17/02_imag.py
Class Name:
Method Name:


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 7a6e3b93fb4b97af7b06244b768b1fee4b547c17
Time: 2018-04-29
Author: max.lapan@gmail.com
File Name: ch12/train_crossent.py
Class Name:
Method Name: