def load_test_model(opt, dummy_opt):
    """Read in multiple models for ensemble.

    Every path in ``opt.models`` is loaded through
    ``onmt.model_builder.load_test_model`` and the resulting models are
    wrapped in a single ``EnsembleModel``.

    Args:
        opt: options object. This function reads ``opt.models`` (iterable of
            model paths) and ``opt.avg_raw_probs``, and forwards ``opt``
            itself to the per-model loader.
        dummy_opt: forwarded unchanged to the per-model loader.

    Returns:
        A tuple ``(shared_fields, ensemble_model, shared_model_opt)``. The
        fields and model options are those of the first model loaded; both
        stay ``None`` when ``opt.models`` is empty.

    Raises:
        AssertionError: if a later model has a non-``None`` field with a
            ``vocab`` attribute whose ``vocab.stoi`` differs from that of the
            first model's field under the same key.
    """
    shared_fields = None
    shared_model_opt = None
    models = []
    for model_path in opt.models:
        fields, model, model_opt = onmt.model_builder.load_test_model(
            opt, dummy_opt, model_path=model_path)
        if shared_fields is None:
            # The first model's fields become the reference for the ensemble.
            shared_fields = fields
        else:
            # Later models must map tokens to the same indices as the first.
            for key, field in fields.items():
                if field is not None and "vocab" in field.__dict__:
                    # NOTE: this check is an ``assert`` and is therefore
                    # skipped when Python runs with -O.
                    assert field.vocab.stoi == shared_fields[key].vocab.stoi, \
                        "Ensemble models must use the same preprocessed data"
        models.append(model)
        if shared_model_opt is None:
            shared_model_opt = model_opt
    ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
    return shared_fields, ensemble_model, shared_model_opt
# After Change
def init_state(self, src, memory_bank, enc_hidden):
    """Initialise the state of every decoder in the ensemble.

    See :obj:`RNNDecoderBase.init_state()`.

    Args:
        src: passed unchanged to each decoder's ``init_state``.
        memory_bank: indexable collection; element ``i`` is passed to the
            ``i``-th decoder in ``self.model_decoders``.
        enc_hidden: indexable collection; element ``i`` is passed to the
            ``i``-th decoder in ``self.model_decoders``.
    """
    for i, model_decoder in enumerate(self.model_decoders):
        # Each decoder receives only its own slice of the encoder outputs.
        model_decoder.init_state(src, memory_bank[i], enc_hidden[i])
def map_state(self, fn):
    """Forward ``fn`` to the ``map_state`` method of every decoder.

    Args:
        fn: passed unchanged to each decoder in ``self.model_decoders``.
    """
    for model_decoder in self.model_decoders:
        model_decoder.map_state(fn)
class EnsembleGenerator(nn.Module):
Dummy Generator that delegates to individual real Generators,
and then averages the resulting target distributions.
def __init__(self, model_generators, raw_probs=False):
super(EnsembleGenerator, self).__init__()
self.model_generators = nn.ModuleList(model_generators)
self._raw_probs = raw_probs
def forward(self, hidden, attn=None, src_map=None):
Compute a distribution over the target dictionary
by averaging distributions from models in the ensemble.
All models in the ensemble must share a target vocabulary.
distributions = torch.stack(