Error when pruning a group with GLU Layer #366

Open
saravanabalagi opened this issue Apr 10, 2024 · 4 comments

Comments

@saravanabalagi

Pruning a model that contains GLU raises an error during importance estimation. GLU has no parameters, but it halves the size of its input along the given dimension. This halving is not accounted for during tracing, index assignment, or importance estimation.

Here's a minimal example with a simple model

import torch
import torch.nn as nn
import torch_pruning as tp

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv1d(in_channels=96, out_channels=24, kernel_size=3, padding=1, dilation=1),
            nn.GLU(1),
            nn.Conv1d(in_channels=12, out_channels=48, kernel_size=3, padding=2, dilation=2),
        )

    def forward(self, x):
        return self.layer(x)

model = MyModel()
example_inputs = torch.randn(1, 96, 100)
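To make the channel bookkeeping concrete, here is an illustrative check (not part of the original report) that runs example_inputs through each layer of the same model and records the output shapes. The GLU step halves the 24 channels produced by the first conv to 12, which is why the second conv is declared with in_channels=12. The helper name trace_shapes is hypothetical.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv1d(in_channels=96, out_channels=24, kernel_size=3, padding=1, dilation=1),
            nn.GLU(1),
            nn.Conv1d(in_channels=12, out_channels=48, kernel_size=3, padding=2, dilation=2),
        )

    def forward(self, x):
        return self.layer(x)


def trace_shapes(model, x):
    """Return the output shape of each submodule of model.layer, in order."""
    shapes = []
    with torch.no_grad():
        for layer in model.layer:
            x = layer(x)
            shapes.append(tuple(x.shape))
    return shapes


model = MyModel()
example_inputs = torch.randn(1, 96, 100)
shapes = trace_shapes(model, example_inputs)
for shape in shapes:
    print(shape)
```

The middle entry, (1, 12, 100), is the halved GLU output that the pruner has to map back onto the first conv's 24 output channels.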

I then prune using GroupNormPruner

imp = tp.importance.GroupNormImportance()
pruner = tp.pruner.GroupNormPruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    pruning_ratio=0.5,
    ignored_layers=[],
)
pruner.step()

This gives an index out of bounds error

Exception has occurred: IndexError
index 12 is out of bounds for dimension 0 with size 12
  File "/masked_path/lib/python3.9/site-packages/torch_pruning/pruner/importance.py", line 205, in __call__
    local_imp = local_imp[idxs]
  File "/masked_path/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/masked_path/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 247, in estimate_importance
    return self.importance(group)
  File "/masked_path/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 362, in prune_local
    imp = self.estimate_importance(group)
  File "/masked_path/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 228, in step
    for group in pruning_method():
  File "masked_path/my_file.py", line 39, in <module>
    pruner.step()
  File "/masked_path/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/masked_path/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
IndexError: index 12 is out of bounds for dimension 0 with size 12

Note that this error is raised before the group is returned, so setting interactive=True in pruner.step() does not help.
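One way to read the traceback, assuming (as described above) that the GLU's halving is not accounted for when indices are assigned: the first conv's 24 output-channel indices are forwarded unchanged to the second conv, whose per-channel importance only has 12 entries, so indexing fails at index 12. The pure-PyTorch sketch below reproduces only that indexing step; local_imp and idxs are hypothetical stand-ins named after the failing line, not torch_pruning's actual data.

```python
import torch

# Hypothetical stand-ins: importance scores for the second conv's 12 input
# channels, and the 24 output-channel indices of the first conv.
local_imp = torch.rand(12)
idxs = torch.arange(24)  # indices not halved by the GLU

try:
    local_imp[idxs]
    error_message = None
except IndexError as exc:
    error_message = str(exc)

print(error_message)
```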

@janthmueller

GLU is currently not supported, so it's treated as an element-wise operation. However, since split is supported, you can create your own GLU operation like this:

import torch
import torch.nn as nn

class CustomGLU(nn.Module):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # Split explicitly so torch_pruning traces a split op instead of an element-wise op
        first_half, second_half = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return first_half * torch.sigmoid(second_half)

If you don't use it extensively, the performance degradation shouldn't be significant.
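A quick way to confirm that CustomGLU is a drop-in replacement for nn.GLU is to compare the two on random input. This is a sketch; it assumes the split dimension has even size, which nn.GLU also requires.

```python
import torch
import torch.nn as nn


class CustomGLU(nn.Module):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        first_half, second_half = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return first_half * torch.sigmoid(second_half)


torch.manual_seed(0)
x = torch.randn(1, 24, 100)
reference = nn.GLU(dim=1)(x)
custom = CustomGLU(dim=1)(x)

print(custom.shape)  # torch.Size([1, 12, 100])
print(torch.allclose(reference, custom))
```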

@saravanabalagi
Author

Hi @janthmueller, thanks for the workaround. I tried it, but only the last conv layer gets pruned: no dependency group containing the first conv layer is returned for pruning.

@janthmueller

janthmueller commented Apr 15, 2024

> Hi @janthmueller, thanks for the workaround, I tried that but the network comes back with only the last conv layer pruned. No dep group with first conv layer is being returned for pruning.

After running the get_pruning_group method within the prune_local function of the MetaPruner class, you may notice that the group containing the first layer has twice as many indices as expected. This is probably done to prevent shape mismatch errors. With a pruning ratio of 0.5, however, the pruner then tries to prune the entire output of the first layer. Because a group is skipped whenever all of its filters or channels would be pruned, nothing gets pruned in your case.

To handle this case, adjust the count before pruning_idxs is gathered: for groups that contain the custom GLU's split operation, halve the number of pruned indices (n_pruned). This keeps the pruned proportion at the intended ratio.

To implement this adjustment, insert the following code snippet before collecting pruning_idxs within both the prune_local and prune_global methods:

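# Inside MetaPruner.prune_local / prune_global, before pruning_idxs is collected.
# Each GLU output channel couples two channels of the preceding layer,
# so halve n_pruned when the group passes through a split.
for dep, _ in group:
    if isinstance(dep.target.module, ops._SplitOp):
        n_pruned = n_pruned // 2
        break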

With this adjustment, groups that pass through the custom GLU operation are pruned at the intended ratio.
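The arithmetic behind this fix can be sketched in plain Python for the example model. This is a simplified, hypothetical model of the bookkeeping written from the explanation above (plan_pruning and its skip rule are illustrative, not copied from torch_pruning). The first conv's group lists 24 indices, but GLU output channel j depends on conv channels j and j + 12, so there are only 12 independent positions to prune.

```python
def plan_pruning(group_len, pruning_ratio, halve_for_split):
    """Hypothetical model of the n_pruned computation for the first conv's group.

    group_len: number of indices listed in the group (24 for the first conv).
    Each pruned position removes a coupled pair of conv channels (j and j + 12),
    so only group_len // 2 independent positions exist.
    Returns the number of positions to prune, or 0 if the group would be skipped.
    """
    n_positions = group_len // 2
    n_pruned = int(group_len * pruning_ratio)
    if halve_for_split:
        n_pruned //= 2
    # Mirror the rule described above: skip a group if everything would be pruned.
    if n_pruned >= n_positions:
        return 0
    return n_pruned


without_fix = plan_pruning(group_len=24, pruning_ratio=0.5, halve_for_split=False)
with_fix = plan_pruning(group_len=24, pruning_ratio=0.5, halve_for_split=True)
print(without_fix)  # 0: all 12 positions would go, so the group is skipped
print(with_fix)     # 6: removes 12 of 24 first-conv channels and 6 of 12 second-conv inputs
```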

I think it would be best to fix this for every scenario involving a split, perhaps with an _is_split_group check similar to _is_attn_group. @VainF
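A rough sketch of the _is_split_group idea, written as free functions so it can be tested in isolation. The names is_split_group and adjust_n_pruned are hypothetical, and the split op type is passed in as a parameter; inside metapruner.py it would be ops._SplitOp, as in the snippet above. The group is assumed to iterate as (dep, idxs) pairs exposing dep.target.module, as that snippet does.

```python
from types import SimpleNamespace


def is_split_group(group, split_op_type):
    """Return True if any dependency in the group targets a split operation."""
    return any(isinstance(dep.target.module, split_op_type) for dep, _ in group)


def adjust_n_pruned(n_pruned, group, split_op_type):
    """Halve n_pruned for split groups, mirroring the workaround above."""
    if is_split_group(group, split_op_type):
        return n_pruned // 2
    return n_pruned


# Minimal stand-ins for torch_pruning objects, for illustration only.
class FakeSplitOp:
    pass


class FakeConv:
    pass


def make_dep(module):
    return SimpleNamespace(target=SimpleNamespace(module=module))


split_group = [(make_dep(FakeConv()), [0, 1]), (make_dep(FakeSplitOp()), [0, 1])]
plain_group = [(make_dep(FakeConv()), [0, 1])]

print(adjust_n_pruned(12, split_group, FakeSplitOp))  # 6
print(adjust_n_pruned(12, plain_group, FakeSplitOp))  # 12
```

Keeping the check in one named helper means prune_local and prune_global can share it instead of duplicating the loop.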

@saravanabalagi
Author

Great, thanks for the workaround and the explanation!

It would be great to have this merged such that the lib works directly on GLU!
