Read in multiple models for ensemble
shared_fields = None
shared_model_opt = None
models = []
for model_path in opt.models:
fields, model, model_opt = \
onmt.model_builder.load_test_model(opt,
dummy_opt,
model_path=model_path)
if shared_fields is None:
shared_fields = fields
else:
for key, field in fields.items():
if field is not None and "vocab" in field.__dict__:
assert field.vocab.stoi == shared_fields[key].vocab.stoi, \
"Ensemble models must use the same preprocessed data"
models.append(model)
if shared_model_opt is None:
shared_model_opt = model_opt
ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
return shared_fields, ensemble_model, shared_model_opt