
Bug: Porting model to another device (CPU/GPU) #274

Open
satabios opened this issue Dec 13, 2023 · 4 comments


satabios commented Dec 13, 2023

  • snntorch version: 0.7.0
  • Python version: 3.10.13
  • Operating System: Ubuntu x64

Description

When I load a pre-trained model and push it from GPU to CPU or vice versa, certain variables remain on the original device and are not pushed to the destination device.

Specifically, the "mem" variable under snntorch._neurons.leaky.Leaky.

snn_pretrained_model_path = "snn_model.pth"
snn_model.load_state_dict(torch.load(snn_pretrained_model_path))  
snn_model.to("cpu")  # or "cuda" for the GPU
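
For reference, continuing the snippet above, a quick check makes the mismatch visible (a sketch, assuming the Leaky layers were built with init_hidden=True so the membrane potential is stored on the layer as mem, and that the model has already run a forward pass on the GPU):

import torch
import snntorch as snn

snn_model.to("cpu")
# Parameters follow .to(); the report is that the hidden state does not.
print(next(snn_model.parameters()).device)  # cpu
for module in snn_model.modules():
    if isinstance(module, snn.Leaky) and torch.is_tensor(getattr(module, "mem", None)):
        print(module.mem.device)  # reportedly still cuda:0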

What I Did

As a workaround, I deliberately iterate over the whole model to find such instances of Leaky and push them to the destination device.

if isinstance(model, nn.Sequential):
    for layer in model:
        if isinstance(layer, snntorch._neurons.leaky.Leaky):
            layer.mem = layer.mem.to("cpu")  # or "cuda", depending on the destination device
else:
    for internal_layer in model.modules():
        if isinstance(internal_layer, snntorch._neurons.leaky.Leaky):
            internal_layer.mem = internal_layer.mem.to("cpu")
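
The likely cause is that nn.Module.to() only moves registered parameters and buffers, while mem appears to be held as a plain tensor attribute, so it stays behind. The workaround can be wrapped into a small helper so the destination device is not hard-coded (a sketch; move_snn_to is a hypothetical name, not part of snntorch, and it only covers Leaky, so other stateful neuron types would need the same treatment):

import torch
import torch.nn as nn
import snntorch as snn

def move_snn_to(model: nn.Module, device) -> nn.Module:
    """Move a model plus any Leaky membrane potentials that .to() leaves behind."""
    model.to(device)  # moves registered parameters and buffers
    for module in model.modules():
        mem = getattr(module, "mem", None)
        if isinstance(module, snn.Leaky) and torch.is_tensor(mem):
            module.mem = mem.to(device)  # move the hidden state as well
    return model

# Example usage:
snn_model = move_snn_to(snn_model, torch.device("cpu"))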
ahenkes1 added the bug label on Jan 5, 2024

ahenkes1 commented Jan 5, 2024

@satabios, have you tried the following steps?
Saving and loading models across devices in PyTorch
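
For completeness, the device-remapping pattern from that tutorial looks roughly like this (a sketch; it reuses the checkpoint path from the snippet above):

import torch

device = torch.device("cpu")  # or torch.device("cuda")
# map_location remaps every storage in the checkpoint to the target device at load time.
state_dict = torch.load("snn_model.pth", map_location=device)
snn_model.load_state_dict(state_dict)
snn_model.to(device)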


satabios commented Jan 5, 2024

I followed Saving and loading models across devices in PyTorch.

The issue isn't with saving or reloading; it is with porting from one device to another. As mentioned above, certain members of the model aren't transferred automatically.


ahenkes1 commented Jan 5, 2024

So the issue is also present with newly created models, without any saving/loading?


satabios commented Jan 7, 2024

Either way, the bug persists: whether the model is in memory and is moved to a different device (say, from GPU to CPU or vice versa) or is loaded from a file, the porting leaves those variables stuck on the original device.
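
A file-free way to reproduce the reported behaviour might look like this (a sketch, assuming a CUDA device is available and that the Leaky layer is built with init_hidden=True so it stores its membrane potential as mem):

import torch
import torch.nn as nn
import snntorch as snn

# Build a small network on the GPU and run one step so the Leaky layer
# creates its membrane potential there.
net = nn.Sequential(nn.Linear(10, 5), snn.Leaky(beta=0.9, init_hidden=True)).to("cuda")
net(torch.rand(1, 10, device="cuda"))

# Move the network to the CPU and compare devices.
net.to("cpu")
print(next(net.parameters()).device)  # cpu
print(net[1].mem.device)              # reportedly still cuda:0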
