fix total sparsity in sa & auto compress pruner bug (#4474)
J-shang committed Jan 18, 2022
1 parent 6b8efe3 commit fd8fb78
Showing 4 changed files with 22 additions and 4 deletions.
@@ -99,7 +99,7 @@ def evaluator(model):
     trainer(model, optimizer, criterion, i)
     evaluator(model)
 
-config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
 
 # evaluator in 'SimulatedAnnealingPruner' could not be None.
 pruner = SimulatedAnnealingPruner(model, config_list, pruning_algorithm=args.pruning_algo,
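For context, the renamed key changes how the pruner may distribute sparsity. A minimal sketch of the two config styles (semantics as described by the warning added later in this commit; the values and op types are illustrative):

# `sparsity` / `sparsity_per_layer`: every matched layer is pruned at the
# same fixed ratio, e.g. each Conv2d loses 80% of its weights.
fixed_config_list = [{'op_types': ['Conv2d'], 'sparsity_per_layer': 0.8}]

# `total_sparsity`: 80% of all weights covered by this config are pruned
# overall, and SimulatedAnnealingPruner is free to allocate different
# ratios to individual layers while searching.
allocatable_config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]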
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import logging
 from pathlib import Path
 from typing import Dict, List, Callable, Optional

@@ -13,6 +14,8 @@
 from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
 from .tools import LotteryTicketTaskGenerator
 
+_logger = logging.getLogger(__name__)
+
 
 class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
     def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
@@ -29,6 +32,13 @@ def __init__(self, total_iteration: int, origin_model: Module, origin_config_list
                          log_dir=log_dir,
                          keep_intermediate_result=keep_intermediate_result)
 
+    def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
+        # TODO: replace with validation here
+        for config in config_list:
+            if 'sparsity' in config or 'sparsity_per_layer' in config:
+                _logger.warning('Only `total_sparsity` allows the pruner to allocate a different sparsity ratio to each layer; `sparsity` and `sparsity_per_layer` fix the same ratio for every layer. Use `total_sparsity` unless a fixed per-layer ratio is intended.')
+        return super().reset(model, config_list, masks)
+
     def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
         self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
         self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
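Since the guard above tests dict keys exactly, a config using only `sparsity_per_layer` is caught by its own branch rather than by the `'sparsity' in config` test. A standalone sketch of which config lists would trigger the new warning (plain Python, hypothetical configs, not NNI code):

def would_warn(config_list):
    # Mirrors the key check in the reset() override above.
    return any('sparsity' in config or 'sparsity_per_layer' in config
               for config in config_list)

assert would_warn([{'op_types': ['Conv2d'], 'sparsity': 0.8}])
assert would_warn([{'op_types': ['Conv2d'], 'sparsity_per_layer': 0.8}])
assert not would_warn([{'op_types': ['Conv2d'], 'total_sparsity': 0.8}])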
@@ -187,6 +187,11 @@ def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_
     def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
         self.current_temperature = self.start_temperature
 
+        # TODO: replace with validation here
+        for config in config_list:
+            if 'sparsity' in config or 'sparsity_per_layer' in config:
+                _logger.warning('Only `total_sparsity` allows the pruner to allocate a different sparsity ratio to each layer; `sparsity` and `sparsity_per_layer` fix the same ratio for every layer. Use `total_sparsity` unless a fixed per-layer ratio is intended.')
+
         self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
         self.target_sparsity_list = config_list_canonical(model, config_list)
         self._adjust_target_sparsity()
@@ -281,7 +286,10 @@ def _update_with_perturbations(self):
         magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
         for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
             if len(current_sparsity) == 0:
-                self._temp_config_list.extend(deepcopy(config))
+                sub_temp_config_list = [deepcopy(config) for _ in range(len(config['op_names']))]
+                for temp_config, op_name in zip(sub_temp_config_list, config['op_names']):
+                    temp_config.update({'total_sparsity': 0, 'op_names': [op_name]})
+                self._temp_config_list.extend(sub_temp_config_list)
                 self._temp_sparsity_list.append([])
                 continue
             while True:
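The one-line removal above looks like the actual bug: `config` is a dict, and `list.extend()` iterates whatever it is given, so extending with a (deep-copied) dict appended the config's key strings rather than per-layer config dicts. A standalone sketch of the old and new behaviour (plain Python, hypothetical op names):

from copy import deepcopy

config = {'total_sparsity': 0.8, 'op_names': ['conv1', 'conv2']}

# Old behaviour: extend() iterates the dict, so only its keys land in the list.
broken = []
broken.extend(deepcopy(config))
assert broken == ['total_sparsity', 'op_names']

# New behaviour: one zero-sparsity config per op name, keeping the list of
# dicts well-formed for the perturbation step.
fixed = [deepcopy(config) for _ in config['op_names']]
for cfg, name in zip(fixed, config['op_names']):
    cfg.update({'total_sparsity': 0, 'op_names': [name]})
assert fixed == [{'total_sparsity': 0, 'op_names': ['conv1']},
                 {'total_sparsity': 0, 'op_names': ['conv2']}]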
test/ut/compression/v2/test_iterative_pruner_torch.py (4 changes: 2 additions & 2 deletions)
@@ -98,7 +98,7 @@ def test_lottery_ticket_pruner(self):

     def test_simulated_annealing_pruner(self):
         model = TorchModel()
-        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
         pruner = SimulatedAnnealingPruner(model, config_list, evaluator, start_temperature=40, log_dir='../../../logs')
         pruner.compress()
         _, pruned_model, masks, _, _ = pruner.get_best_result()
@@ -107,7 +107,7 @@ def test_simulated_annealing_pruner(self):

     def test_auto_compress_pruner(self):
         model = TorchModel()
-        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
         admm_params = {
             'trainer': trainer,
             'traced_optimizer': get_optimizer(model),