# Get the probability assigned to the correct token and divide it by the
# total probability mass in `probs`, so the ratio is a proper probability
# even if `probs` is not normalized.
prob_true = probs.get(parsed[i], 0)
if prob_true > 0:
    prob_true /= sum(probs.values())
    # accumulate negative log-likelihood of the correct token
    loss -= math.log(prob_true)
else:
    # The correct token got zero mass: math.log(0) would raise, and the
    # likelihood is zero, so the loss becomes infinite (and any later
    # `loss -= ...` leaves it infinite).
    loss = float("inf")
num_tokens += 1
After Change
# Get the probability assigned to the correct token and divide it by the
# probability mass of the tokens in the dictionary (`self.dict.keys()`),
# ignoring mass that `probs` puts on tokens outside the dictionary.
prob_true = probs.get(parsed[i], 0)
if prob_true > 0:
    norm = sum(probs.get(k, 0) for k in self.dict.keys())
    if parsed[i] not in self.dict.keys():
        # The correct token is outside the dictionary, so its mass is not
        # part of `norm`. Without adding it, the ratio could exceed 1
        # (making -log negative and lowering the loss) or divide by zero.
        # Including it keeps prob_true in (0, 1] and norm > 0.
        norm += prob_true
    prob_true /= norm
    # accumulate negative log-likelihood of the correct token
    loss -= math.log(prob_true)
else:
    # The correct token got zero mass: math.log(0) would raise, and the
    # likelihood is zero, so the loss becomes infinite (and any later
    # `loss -= ...` leaves it infinite).
    loss = float("inf")
num_tokens += 1