log.info("Model loaded from %s, continue training in RL mode...", args.load)
# BEGIN token
beg_token = Variable(torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]))
if args.cuda:
beg_token = beg_token.cuda()
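# TBMeanTracker smooths tracked scalar metrics over 100-sample windows before writing them to TensorBoard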
with ptan.common.utils.TBMeanTracker(writer, batch_size=100) as tb_tracker:
optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)
batch_idx = 0
best_bleu = None
for epoch in range(MAX_EPOCHES):
random.shuffle(train_data)
dial_shown = False
total_samples = 0
skipped_samples = 0
bleus_argmax = []
bleus_sample = []
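# Train on mini-batches: for every sample, compare stochastic decodes against the greedy (argmax) baseline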
for batch in data.iterate_batches(train_data, BATCH_SIZE):
batch_idx += 1
optimiser.zero_grad()
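# Pack the batch for the encoder; keep the raw input/output token indices for BLEU computation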
input_seq, input_batch, output_batch = model.pack_batch_no_out(batch, net.emb, cuda=args.cuda)
enc = net.encode(input_seq)
net_policies = []
net_actions = []
net_advantages = []
beg_embedding = net.emb(beg_token)
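# Decode every sample of the batch individually, starting from the BEGIN token embedding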
for idx, inp_idx in enumerate(input_batch):
total_samples += 1
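# Reference sequences with the leading BEGIN token stripped off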
ref_indices = [
indices[1:]
for indices in output_batch[idx]
]
item_enc = net.get_encoded_item(enc, idx)
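# Greedy (argmax) decoding provides the baseline BLEU score for this sample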
r_argmax, actions = net.decode_chain_argmax(item_enc, beg_embedding, data.MAX_TOKENS,
stop_at_token=end_token)
argmax_bleu = utils.calc_bleu_many(actions, ref_indices)
bleus_argmax.append(argmax_bleu)
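# Skip samples the model already decodes almost perfectly (unless skipping is disabled)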
if not args.disable_skip and argmax_bleu > 0.99:
skipped_samples += 1
continue
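# Log one example dialogue per epoch for inspection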
if not dial_shown:
log.info("Input: %s", utils.untokenize(data.decode_words(inp_idx, rev_emb_dict)))
ref_words = [utils.untokenize(data.decode_words(ref, rev_emb_dict)) for ref in ref_indices]
log.info("Refer: %s", " ~~|~~ ".join(ref_words))
log.info("Argmax: %s, bleu=%.4f", utils.untokenize(data.decode_words(actions, rev_emb_dict)),
argmax_bleu)
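# Draw several stochastic decodes; their BLEU relative to the argmax baseline becomes the advantage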
for _ in range(args.samples):
r_sample, actions = net.decode_chain_sampling(item_enc, beg_embedding,
data.MAX_TOKENS, stop_at_token=end_token)
sample_bleu = utils.calc_bleu_many(actions, ref_indices)
if not dial_shown:
log.info("Sample: %s, bleu=%.4f", utils.untokenize(data.decode_words(actions, rev_emb_dict)),
sample_bleu)
net_policies.append(r_sample)
net_actions.extend(actions)
net_advantages.extend([sample_bleu - argmax_bleu] * len(actions))
bleus_sample.append(sample_bleu)
dial_shown = True
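# If every sample of the batch was skipped, there is nothing to optimise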
if not net_policies:
continue
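# REINFORCE-style policy gradient: advantage-weighted log-probabilities of the sampled actions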
policies_v = torch.cat(net_policies)
actions_t = torch.LongTensor(net_actions)
adv_v = Variable(torch.FloatTensor(net_advantages))
if args.cuda:
actions_t = actions_t.cuda()
adv_v = adv_v.cuda()
log_prob_v = F.log_softmax(policies_v, dim=1)
log_prob_actions_v = adv_v * log_prob_v[range(len(net_actions)), actions_t]
loss_policy_v = -log_prob_actions_v.mean()
After Change
log.info("Model loaded from %s, continue training in RL mode...", args.load)
# BEGIN token
beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)
with ptan.common.utils.TBMeanTracker(writer, batch_size=100) as tb_tracker:
optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)