2de52a8976971da4836727ba9242fedcc7474878,src/sdk/pynni/nni/compression/torch/pruners.py,SlimPruner,calc_mask,#SlimPruner#Any#Any#,202

Before Change


        

        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert op_type == "BatchNorm2d", "SlimPruner only supports 2d batch normalization layer pruning"
        if op_name in self.mask_calculated_ops:
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        base_mask = torch.ones(weight.size()).type_as(weight).detach()
        mask = {"weight": base_mask.detach(), "bias": base_mask.clone().detach()}
        try:
            filters = weight.size(0)

After Change


        op_type = layer.type
        if_calculated = kwargs["if_calculated"]
        assert op_type == "BatchNorm2d", "SlimPruner only supports 2d batch normalization layer pruning"
        if if_calculated:
            return None
        base_mask = torch.ones(weight.size()).type_as(weight).detach()
        mask = {"weight": base_mask.detach(), "bias": base_mask.clone().detach()}
        filters = weight.size(0)
        num_prune = int(filters * config.get("sparsity"))
        if filters >= 2 and num_prune >= 1:
            w_abs = weight.abs()
            mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
            mask_bias = mask_weight.clone()
            mask = {"weight": mask_weight.detach(), "bias": mask_bias.detach()}
        if_calculated.copy_(torch.tensor(True))  # pylint: disable=not-callable
        return mask

class LotteryTicketPruner(Pruner):
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 7

Instances


Project Name: microsoft/nni
Commit Name: 2de52a8976971da4836727ba9242fedcc7474878
Time: 2020-01-16
Author: 656569648@qq.com
File Name: src/sdk/pynni/nni/compression/torch/pruners.py
Class Name: SlimPruner
Method Name: calc_mask


Project Name: microsoft/nni
Commit Name: c7d58033db0e25736d33406ed262cb5232d366e8
Time: 2020-02-09
Author: 38930155+chicm-ms@users.noreply.github.com
File Name: src/sdk/pynni/nni/compression/torch/pruners.py
Class Name: AGP_Pruner
Method Name: calc_mask


Project Name: microsoft/nni
Commit Name: 4e21e721a65d0ac7c8465c6b7842dd39338bb3d0
Time: 2020-02-09
Author: 656569648@qq.com
File Name: src/sdk/pynni/nni/compression/torch/pruners.py
Class Name: LevelPruner
Method Name: calc_mask