// register user specified buffer
for name in self.pruner.buffers:
self.register_buffer(name, self.pruner.buffers[name].clone())
self.registered_buffers[name] = getattr(self, name)
def forward(self, *inputs):
mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.registered_buffers)
if mask is not None:
After Change
// config and pruner
self.config = config
self.pruner = pruner
self.registered_buffers = []
// register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
if hasattr(self.module, "bias") and self.module.bias is not None: