weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get("sparsity") < 1, "sparsity must in the range [0, 1)"
assert op_type in ["Conv2d"], "only support Conv2d"
assert op_type in config.get("op_types")
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, "bias") and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
else:
After Change
assert 0 <= config.get("sparsity") < 1, "sparsity must in the range [0, 1)"
assert op_type in ["Conv2d"], "only support Conv2d"
assert op_type in config.get("op_types")
if_calculated = kwargs["if_calculated"]
if if_calculated:
return None
mask_weight = torch.ones(weight.size()).type_as(weight).detach()
if hasattr(layer.module, "bias") and layer.module.bias is not None:
mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
else: