
ModelSpeedup error: assert len(set(num_channels_list)) == 1, possible incorrect layers in dependency set #5736

Open · saravanabalagi opened this issue Jan 17, 2024 · 1 comment · May be fixed by #5751

saravanabalagi commented Jan 17, 2024

ModelSpeedup fails to speed up a model with three successive conv blocks.

Environment:

  • NNI version: 3.0
  • Python version: 3.8.16
  • PyTorch version: 1.13.0
  • CPU or CUDA version: CUDA 11.6

Reproduce the problem

  • Create a model and a config list with the desired sparsity ratio
  • Obtain pruning masks using L1NormPruner
  • Call ModelSpeedup with the dummy input and the masks
Minimal Code
# %%
import torch
import torch.nn as nn

from nni.compression.pruning import L1NormPruner
from nni.compression.utils import auto_set_denpendency_group_ids
from nni.compression.speedup import ModelSpeedup

# %%
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 40, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(40)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(40, 80, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(80)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(80, 1, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x
    
model = ConvNet()
num_params_unpruned = sum(p.numel() for p in model.parameters())
dummy_input = torch.randn(1, 3, 32, 32)
dummy_output = model(dummy_input)
print(dummy_output.shape)

# %%
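# prune 50% of the output channels of every Conv2d layer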
sparsity_ratio = 0.5
config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
}]
config_list = auto_set_denpendency_group_ids(model, config_list, [dummy_input])
pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner.unwrap_model()
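# propagate the masks through the graph and physically shrink the pruned layers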
model = ModelSpeedup(model, [dummy_input], masks, garbage_collect_values=False).speedup_model()

# %%
num_params_pruned = sum(p.numel() for p in model.parameters())
print(f'Number of parameters before pruning: {num_params_unpruned}')
print(f'Number of parameters after pruning: {num_params_pruned}')

num_params_diff = num_params_unpruned - num_params_pruned
prune_ratio = num_params_diff / num_params_unpruned
print(f'Number of parameters pruned: {num_params_diff}')
print(f'Remaining parameter ratio: {(1-prune_ratio)*100:.2f}%')

Error:

Assertion error: number of channels in same set should be identical

Error Trace
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[108], line 1
----> 1 model = ModelSpeedup(model, [dummy_input], masks, garbage_collect_values=False).speedup_model()

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/model_speedup.py:429, in ModelSpeedup.speedup_model(self)
    427 self.logger.info('Resolve the mask conflict before mask propagate...')
    428 # fix_mask_conflict(self.masks, self.graph_module, self.dummy_input)
--> 429 self.fix_mask_conflict()
    430 self.logger.info('Infer module masks...')
    431 self.initialize_propagate(self.dummy_input)

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/model_speedup.py:243, in ModelSpeedup.fix_mask_conflict(self)
    241 def fix_mask_conflict(self):
    242     fix_group_mask_conflict(self.graph_module, self.masks)
--> 243     fix_channel_mask_conflict(self.graph_module, self.masks)
    244     fix_weight_sharing_mask_conflict(self.graph_module, self.masks)

File /usr/local/lib/python3.8/dist-packages/nni/compression/speedup/mask_conflict.py:296, in fix_channel_mask_conflict(graph_module, masks)
    294 num_channels_list = [len(x) for x in channel_masks if x is not None]
    295 # number of channels in same set should be identical
--> 296 assert len(set(num_channels_list)) == 1
    297 num_channels = num_channels_list[0]
    299 for i, dim_mask in enumerate(channel_masks):

AssertionError: 
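
For context, the failing check in fix_channel_mask_conflict collects the per-layer channel masks belonging to one dependency set and asserts that they all cover the same number of channels. A toy illustration of what that check does (the mask sizes below are hypothetical, not the actual contents of the failing set):

import torch

# masks from two layers that ended up in the same dependency set (hypothetical sizes)
channel_masks = [torch.ones(40), torch.ones(80)]
num_channels_list = [len(x) for x in channel_masks if x is not None]
assert len(set(num_channels_list)) == 1  # fails: the set is {40, 80}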

The same code works fine without self.conv3 and self.bn3.
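A possible workaround (untested sketch, assuming the exclude_op_names key supported by NNI 3.0 config lists): keep conv3 and bn3 in the model but exclude the 1-output-channel layer from pruning:

config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
    'exclude_op_names': ['conv3'],  # assumption: skipping the 1-channel layer avoids the conflicting set
}]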

saravanabalagi (author) commented:
The error is thrown specifically when the last layer has one output channel, even with only two successive conv blocks:

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 1, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x
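
Running this variant through the same pipeline as the minimal code above (a sketch reusing the earlier imports; the input shape is assumed to be 1×3×32×32) hits the same assertion:

model = ConvNet()
dummy_input = torch.randn(1, 3, 32, 32)
config_list = auto_set_denpendency_group_ids(
    model, [{'op_types': ['Conv2d'], 'sparse_ratio': 0.5}], [dummy_input])
pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner.unwrap_model()
# raises AssertionError in fix_channel_mask_conflict, as in the trace above
model = ModelSpeedup(model, [dummy_input], masks).speedup_model()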
