model_cls, cfg, tokenizer, local_params_path, _ = get_backbone(name, root=root)
if name == "gpt2_1558M":
// skip gpt2 1558M due to the space
return
net = model_cls.from_cfg(cfg)
net.load_parameters(local_params_path)
net.hybridize()
num_params, num_fixed_params = count_parameters(net.collect_params())
# NOTE(review): the branch header on this line was lost to extraction residue
# ("After Change"). Reconstructed as the seq2seq (BART-style) case, since the
# 4-argument call below passes (src_data, src_valid_length, tgt_data,
# tgt_valid_length) — confirm against the original file before relying on it.
if "bart" in name:
    out = net(inputs, valid_length, inputs, valid_length)
elif "gpt2" in name:
    # Decoder-only model: carries incremental decoding states across calls.
    states = net.init_states(batch_size=batch_size, ctx=ctx)
    out, new_states = net(inputs, states)
    out_np = out.asnumpy()
else:
    # Encoder-style (BERT-like) models take token types and valid lengths.
    out = net(inputs, token_types, valid_length)
mx.npx.waitall()  # block until all asynchronous MXNet operations finish