# Per-shard losses for the current batch.
Ls = []
# Run the forward passes inside autograd.record() so the operations are
# tracked for a later backward pass (presumably MXNet autograd -- confirm).
with autograd.record():
    # Walk the shards in lockstep: inputs, targets, masks, samples and the
    # hidden state that belongs to each shard.
    for j, (X, y, m, s, h) in enumerate(zip(data, target, mask, sample, hiddens)):
        output, h, new_target = model(X, y, h, s)
        # NOTE(review): under MXNet reshape semantics the special code -3
        # merges two consecutive axes; here that would collapse the first
        # two axes of `output` into one -- confirm against the model's
        # output layout.
        output = output.reshape((-3, -1))
        # Flatten targets and mask to 1-D so they line up element-wise with
        # the per-element loss.
        new_target = new_target.reshape((-1,))
        # Multiply by the flattened mask to weight each loss element
        # (presumably zeroing out padded positions -- confirm).
        l = loss(output, new_target) * m.reshape((-1,))
After Change
# Retrieve one (hidden, ls) result from `parallel` per entry in `data`.
for _ in range(len(data)):
    hidden, ls = parallel.get()
    # hidden states are ordered by context id: results are slotted by the
    # position of their context in `context`, not by arrival order.
    index = context.index(hidden[0].context)
    hiddens[index] = hidden
    Ls.append(ls)
# prefetch the next batch of data