# Excerpt of the batching loop as it stood before the change; the enclosing
# function header and the remainder of the loop body are outside this view.
s = 0
while s < n:
# Clamp the end of the slice to n.
t = min(s + batch_size, n)
# Stop as soon as fewer than batch_size items remain, so a short final
# batch is dropped rather than emitted.
if t < s + batch_size:
break
# One batch is the slice s:t of x0 followed by the same slice of x1, with
# label 0 for every x0 item and label 1 for every x1 item.
batches.append(get_batch(x0[s:t] + x1[s:t],
[0] * (t-s) + [1]*(t-s),
word2id,
batch_size))
# NOTE(review): no update of s is visible in this excerpt -- confirm that the
# full loop advances s.
After Change
# Halve the batch size: each batch combines batch_size // 2 items from x0 with the same number from x1.
# Body of the batch-building routine (its def line is outside this view).
# Each emitted batch joins a slice of x0 with the equally long slice of x1.
batch_size = batch_size // 2
# NOTE(review): an incoming batch_size of 1 becomes 0 here, and the division
# two lines below would then raise ZeroDivisionError.
n = max(len(x0), len(x1))
# Move n up to the next multiple of batch_size that is strictly greater than
# it; when n is already a multiple this still adds one whole extra batch.
n = (n // batch_size + 1) * batch_size
# makeup() is defined elsewhere; presumably it extends the list to n items --
# TODO confirm against its definition.
if len(x0) < n:
x0 = makeup(x0, n)
if len(x1) < n:
x1 = makeup(x1, n)
# sort is not defined in the visible lines (presumably a parameter of the
# enclosing function -- confirm).
if sort:
# Order each input by item length while pairing every item with its original
# index, so order0/order1 record where each sorted item came from.
# zip(*z) unpacks into tuples, so x0, x1, order0 and order1 are tuples here.
order0 = range(n)
z = sorted(zip(order0, x0), key=lambda i:len(i[1]))
order0, x0 = zip(*z)
order1 = range(n)
z = sorted(zip(order1, x1), key=lambda i:len(i[1]))
order1, x1 = zip(*z)
# Otherwise each input gets its own independent random order.
else:
order0 = range(n)
order1 = range(n)
# NOTE(review): random.shuffle mutates its argument in place, so this relies
# on range() returning a list (Python 2). On Python 3 a range object cannot be
# shuffled -- confirm the target interpreter.
random.shuffle(order0)
random.shuffle(order1)
# Reorder both inputs according to their index lists.
x0 = [x0[i] for i in order0]
x1 = [x1[i] for i in order1]
batches = []
s = 0
# n is a multiple of batch_size, so every slice below is exactly batch_size
# long once x0 and x1 hold n items each.
while s < n:
t = s + batch_size
# x0 slice first (label 0), then the x1 slice (label 1).
batches.append(get_batch(x0[s:t] + x1[s:t],
[0] * (t-s) + [1]*(t-s),
word2id,
batch_size))
s = t
# Return the batches together with the index order applied to each input.
return batches, order0, order1
def strip_eos(sents, eos="_EOS"):
    """Truncate every sentence at its first end-of-sentence marker.

    Args:
        sents: Iterable of sentences. Each sentence must support ``index``
            and slicing (for example a list of tokens).
        eos: Marker to cut at. Defaults to ``"_EOS"``.

    Returns:
        A new list with one entry per input sentence: the part before the
        first ``eos``, or the sentence object itself when no marker occurs.
    """
    def _cut(sent):
        # A single index() call replaces the separate membership test, so
        # each sentence is scanned once; ValueError means "no marker".
        try:
            return sent[:sent.index(eos)]
        except ValueError:
            return sent

    return [_cut(sent) for sent in sents]