else:
data = batch.__dict__[side]
feat_start = side + "_feat_"
features = sorted(batch.__dict__[k]
for k in batch.__dict__ if feat_start in k)
levels = [data] + features
return torch.cat([level.unsqueeze(2) for level in levels], 2)
After Change
else:
data = batch.__dict__[side]
feat_start = side + "_feat_"
keys = sorted([k for k in batch.__dict__ if feat_start in k])
features = [batch.__dict__[k] for k in keys]
levels = [data] + features
return torch.cat([level.unsqueeze(2) for level in levels], 2)