// as we need to predict the eos as well, the future window at that point is N past it
// Step through the sentence
losses = []
for i, word in enumerate(sent):
for j in range(1, N + 1):
for direction in [-1, 1]:
c = torch.tensor([word]).type(type) // This is tensor for center word
context_id = sent[i + direction * j] if 0 <= i + direction * j < len(sent) else S
context = torch.tensor([context_id]).type(type) // Tensor for context word
logits = model(c)
loss = criterion(logits, context)
losses.append(loss)
return torch.stack(losses).sum()
# Module-level length cap. Nothing in this view reads it; presumably the
# maximum sentence/sequence length used elsewhere in the file (TODO confirm).
MAX_LEN: int = 100