Pruned Yolov8 model not loading? #363

Open
ashray21 opened this issue Apr 1, 2024 · 3 comments

Comments

@ashray21

ashray21 commented Apr 1, 2024

@VainF I have trained a custom YOLOv8 model. After training, I successfully pruned it with the following code:

```python
import torch
import torch_pruning as tp
from ultralytics.nn.modules import Detect

# `model` is the trained ultralytics YOLO object. replace_c2f_with_c2f_v2 (and the
# C2f_v2 class it inserts) come from Torch-Pruning's YOLOv8 pruning example.

# Ensure every parameter requires gradients
for name, param in model.model.named_parameters():
    param.requires_grad = True

# Swap each C2f block for C2f_v2 (two separate convs instead of a chunk(2, 1) split)
replace_c2f_with_c2f_v2(model.model)

model.model.eval()
example_inputs = torch.randn(1, 3, 800, 800).to(model.device)
imp = tp.importance.MagnitudeImportance(p=2)  # L2-norm importance

ignored_layers = []
unwrapped_parameters = []

# Keep the detection head unpruned
modules_list = list(model.model.modules())
for i, m in enumerate(modules_list):
    if isinstance(m, (Detect,)):
        ignored_layers.append(m)

iterative_steps = 1  # number of progressive pruning steps
pruner = tp.pruner.MagnitudePruner(
    model.model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5,  # remove 50% of the channels
    ignored_layers=ignored_layers,
    unwrapped_parameters=unwrapped_parameters
)
base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)
pruner.step()

pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs)
print("Before Pruning: MACs=%f G, #Params=%f G" % (base_macs / 1e9, base_nparams / 1e9))
print("After Pruning: MACs=%f G, #Params=%f G" % (pruned_macs / 1e9, pruned_nparams / 1e9))
```
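For readers wondering roughly what `MagnitudeImportance(p=2)` with `ch_sparsity=0.5` does, here is a simplified, single-layer sketch in plain Python (no torch). It assumes each output channel is scored by the L2 norm of its weights and that the lowest-scoring half is removed. Torch-Pruning itself groups dependent layers through its dependency graph and combines their scores, so treat this as an illustration, not its implementation; the helper names and weight values are hypothetical.

```python
import math


def l2_channel_importance(weights):
    """L2 norm of each output channel; `weights` is a list of per-channel weight lists."""
    return [math.sqrt(sum(w * w for w in channel)) for channel in weights]


def channels_to_keep(weights, ch_sparsity):
    """Indices of the channels that survive when the lowest-norm fraction is pruned."""
    scores = l2_channel_importance(weights)
    n_prune = int(len(weights) * ch_sparsity)
    # Sort channel indices from least to most important and drop the first n_prune
    pruned = set(sorted(range(len(weights)), key=lambda i: scores[i])[:n_prune])
    return [i for i in range(len(weights)) if i not in pruned]


def conv_weight_params(c_out, c_in, k):
    """Weight count of a k x k convolution (bias ignored)."""
    return c_out * c_in * k * k


# Toy layer with 4 output channels, 3 weights each (hypothetical values)
layer = [
    [0.9, -0.8, 0.7],   # L2 norm ~ 1.39
    [0.01, 0.02, 0.0],  # ~ 0.02
    [0.5, 0.5, 0.5],    # ~ 0.87
    [0.1, -0.1, 0.05],  # ~ 0.15
]
print(channels_to_keep(layer, ch_sparsity=0.5))  # [0, 2]

# If both the input and output channels of a 3x3 conv are halved,
# only a quarter of its weights remain.
before = conv_weight_params(64, 64, 3)
after = conv_weight_params(32, 32, 3)
print(before, after, after / before)  # 36864 9216 0.25
```

The last lines show why a 50% channel ratio can remove more than half of a layer's weights. Because the `Detect` head is in `ignored_layers`, the whole-model reduction reported by `count_ops_and_params` need not match that per-layer figure.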

After that, I saved the pruned model as shown below. Is this the correct way to save a pruned model?

```python
torch.save(pruner.model, "prune.pt")
```

After saving, I load the model and get the following error:

```python
pruned_model = YOLO("prune.pt")
```

```
AttributeError                            Traceback (most recent call last)
Cell In[88], line 1
----> 1 pruned_model = YOLO("prune.pt")

File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/engine/model.py:94, in Model.__init__(self, model, task)
     92     self._new(model, task)
     93 else:
---> 94     self._load(model, task)

File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/engine/model.py:140, in Model._load(self, weights, task)
    138 suffix = Path(weights).suffix
    139 if suffix == '.pt':
--> 140     self.model, self.ckpt = attempt_load_one_weight(weights)
    141     self.task = self.model.args['task']
    142     self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)

File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/nn/tasks.py:609, in attempt_load_one_weight(weight, device, inplace, fuse)
    607 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    608     """Loads a single model weights."""
--> 609     ckpt, weight = torch_safe_load(weight)  # load ckpt
    610     args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}  # combine model and default args, preferring model args
    611     model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/nn/tasks.py:548, in torch_safe_load(weight)
...
   1413         pass
   1414 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1415 return super().find_class(mod_name, name)

AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)>
```

Also, is it necessary to retrain the model after pruning?

@luoshiyong

Have you solved this problem? I have run into the same issue.

@ashray21
Author

ashray21 commented Apr 8, 2024

> Have you solved this problem? I have run into the same issue.

@luoshiyong Not yet. How did you save your model?

@CloudRider-pixel

Hi,

I managed to run the YOLOv8 example demo.
After that, I can load and run the pruned model using:

```python
from ultralytics.nn.tasks import attempt_load_one_weight

# weights: path to the saved pruned model, e.g. "prune.pt"
model, _ = attempt_load_one_weight(weights)
```

However, this only works if the `C2f_v2` class is defined (or imported) in the loading script, because it is not part of Ultralytics YOLOv8:
```python
import torch
import torch.nn as nn
from ultralytics.nn.modules import Bottleneck, Conv


class C2f_v2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        # cv0 and cv1 replace the chunk(2, 1) split used by the original C2f
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        # y = list(self.cv1(x).chunk(2, 1))
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
```
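Why the class has to be present: `torch.save(pruner.model, ...)` pickles the whole module object, and pickle stores only a reference (module name plus class name) for each class such as `C2f_v2`, not its source code. The loading process therefore has to find a class with that name in that module; the traceback in the original post also ends in pickle's `find_class`. The sketch below reproduces the effect with the standard `pickle` module only. The module name `saving_script` and the plain-Python `C2f_v2` stand-in are hypothetical, and treating `torch.save` as ordinary pickling is a simplification.

```python
import pickle
import sys
import types

# Hypothetical module playing the role of the script/notebook where pruning ran
saving_script = types.ModuleType("saving_script")
sys.modules["saving_script"] = saving_script


class C2f_v2:
    """Plain-Python stand-in for the custom C2f_v2 layer (no torch needed)."""

    def __init__(self, hidden_channels):
        self.hidden_channels = hidden_channels


# Register the class as if it had been defined inside saving_script
C2f_v2.__module__ = "saving_script"
C2f_v2.__qualname__ = "C2f_v2"
saving_script.C2f_v2 = C2f_v2

# Pickling stores the reference "saving_script.C2f_v2", not the class body
blob = pickle.dumps(C2f_v2(128))

# 1) Loading where the class is not defined fails with AttributeError
del saving_script.C2f_v2
try:
    pickle.loads(blob)
    load_error = None
except AttributeError as exc:
    load_error = str(exc)

# 2) Defining/importing the class under the same name makes loading work
saving_script.C2f_v2 = C2f_v2
restored = pickle.loads(blob)

print("without class:", load_error)
print("with class:", restored.hidden_channels)  # with class: 128
```

The same reasoning suggests defining `C2f_v2` in the loading script under the same module name it had when the model was saved (for a notebook or a directly run script, that is `__main__`).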

Hope this helps!


3 participants