2de52a8976971da4836727ba9242fedcc7474878,src/sdk/pynni/nni/compression/torch/pruners.py,SlimPruner,calc_mask,#SlimPruner#Any#Any#,202
Before Change
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert op_type == "BatchNorm2d", "SlimPruner only supports 2d batch normalization layer pruning"
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {"weight": base_mask.detach(), "bias": base_mask.clone().detach()}
try:
filters = weight.size(0)
After Change
op_type = layer.type
if_calculated = kwargs["if_calculated"]
assert op_type == "BatchNorm2d", "SlimPruner only supports 2d batch normalization layer pruning"
if if_calculated:
return None
base_mask = torch.ones(weight.size()).type_as(weight).detach()
mask = {"weight": base_mask.detach(), "bias": base_mask.clone().detach()}
filters = weight.size(0)
num_prune = int(filters * config.get("sparsity"))
if filters >= 2 and num_prune >= 1:
w_abs = weight.abs()
mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
mask_bias = mask_weight.clone()
mask = {"weight": mask_weight.detach(), "bias": mask_bias.detach()}
if_calculated.copy_(torch.tensor(True)) // pylint: disable=not-callable
return mask
class LotteryTicketPruner(Pruner):
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 7
Instances
Project Name: microsoft/nni
Commit Name: 2de52a8976971da4836727ba9242fedcc7474878
Time: 2020-01-16
Author: 656569648@qq.com
File Name: src/sdk/pynni/nni/compression/torch/pruners.py
Class Name: SlimPruner
Method Name: calc_mask
Project Name: microsoft/nni
Commit Name: c7d58033db0e25736d33406ed262cb5232d366e8
Time: 2020-02-09
Author: 38930155+chicm-ms@users.noreply.github.com
File Name: src/sdk/pynni/nni/compression/torch/pruners.py
Class Name: AGP_Pruner
Method Name: calc_mask
Project Name: microsoft/nni
Commit Name: 4e21e721a65d0ac7c8465c6b7842dd39338bb3d0
Time: 2020-02-09
Author: 656569648@qq.com
File Name: src/sdk/pynni/nni/compression/torch/pruners.py
Class Name: LevelPruner
Method Name: calc_mask