def make_features(batch, fields):
    """Stack the source tokens and their word features into one tensor.

    Args:
        batch: batch object. ``batch.src[0]`` is used as the word-level
            tensor (presumably ``batch.src`` is a ``(data, lengths)`` pair —
            TODO confirm). Each feature tensor is read from
            ``batch.__dict__`` under the names returned by
            ``ONMTDataset.collect_features``.
        fields: passed to ``ONMTDataset.collect_features`` to obtain the
            names of the feature attributes on ``batch``.

    Returns:
        The word tensor followed by each feature tensor, each given a new
        axis at dim 2 and concatenated along that axis with ``torch.cat``.
        Presumably every input is (seq_len, batch), giving a result of
        (seq_len, batch, 1 + n_features) — TODO confirm against callers.
    """
    # TODO: This is a bit hacky; the features should be attached to the
    # batch somehow instead of being looked up through its __dict__.
    feature_names = ONMTDataset.collect_features(fields)
    levels = [batch.src[0]] + [batch.__dict__[k] for k in feature_names]
    # Add a trailing axis at dim 2 to every level so they stack along it.
    return torch.cat([level.unsqueeze(2) for level in levels], 2)
After Change
else:
data = batch.__dict__[side]
feat_start = side + "_feat_"
features = sorted(batch.__dict__[k]
for k in batch.__dict__ if feat_start in k)
levels = [data] + features
return torch.cat([level.unsqueeze(2) for level in levels], 2)