return mask
// if we want to generate new mask, we should update weigth first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else:
After Change
return mask
// if we want to generate new mask, we should update weigth first
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else: