f = ONMTDataset.collect_features(fields)
cat = [batch.src[0]] + [batch.__dict__[k] for k in f]
cat = [c.unsqueeze(2) for c in cat]
return torch.cat(cat, 2)
def join_dicts(*args):
After Change
assert side in ["src", "tgt"]
if isinstance(batch.__dict__[side], tuple):
data = batch.__dict__[side][0]
else:
data = batch.__dict__[side]
feat_start = side + "_feat_"
features = sorted(batch.__dict__[k]