b0c0eb7b1f2ac9a983c550ee971cea275463d8fc,src/sdk/pynni/nni/compression/torch/builtin_pruners.py,AGP_Pruner,calc_mask,#AGP_Pruner#,86

Before Change


        freq = config.get("frequency", 1)
        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
                and (self.now_epoch - start_epoch) % freq == 0:
            mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight))
            target_sparsity = self.compute_target_sparsity(config)
            k = int(weight.numel() * target_sparsity)
            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
                return mask
            # if we want to generate a new mask, we should update the weight first
            w_abs = weight.abs() * mask
            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
            new_mask = torch.gt(w_abs, threshold).type_as(weight)
            self.mask_dict.update({op_name: new_mask})

After Change


        freq = config.get("frequency", 1)
        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
                and (self.now_epoch - start_epoch) % freq == 0:
            mask = self.mask_dict.get(op_name, {"weight": torch.ones(weight.shape).type_as(weight)})
            target_sparsity = self.compute_target_sparsity(config)
            k = int(weight.numel() * target_sparsity)
            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
                return mask
            # if we want to generate a new mask, we should update the weight first
            w_abs = weight.abs() * mask
            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
            new_mask = {"weight": torch.gt(w_abs, threshold).type_as(weight)}
            self.mask_dict.update({op_name: new_mask})
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 3

Instances


Project Name: microsoft/nni
Commit Name: b0c0eb7b1f2ac9a983c550ee971cea275463d8fc
Time: 2019-12-23
Author: lanny@mail.hfut.edu.cn
File Name: src/sdk/pynni/nni/compression/torch/builtin_pruners.py
Class Name: AGP_Pruner
Method Name: calc_mask