How to use pretrained models #201

Open
expectopatronum opened this issue Jan 20, 2021 · 2 comments

@expectopatronum

Your question here
I don't understand how the pretrained models (e.g. musdb+slakhv0_TG3EvX6.pth) should be used. What is the model class for each of the models? I can only see that the first two should be Deep Clustering models (if I recall correctly, I managed to use them some time ago), but I can't figure out what model types the other models are.

Could you point me to the relevant documentation or provide some usage examples please?

What you tried

I tried loading the model in the following ways:

model = DeepClustering(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))
model = DeepAudioEstimation(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))
model = DeepMaskEstimation(signal, os.path.join(model_dir, 'musdb+slakhv0_TG3EvX6.pth'))

but each of them gives a different error when calling model.run().

Also the usage description in the External File Zoo seems outdated:

import nussl
nussl.utils.print_available_audio_files()

Here, utils should be efz_utils. The same applies to model_path = nussl.utils.download_trained_model('example.model').

Thanks a lot & best regards
Verena

@abugler
Collaborator

abugler commented Jan 24, 2021

Hi Verena,

Sorry for the late response, but I can help you use the pretrained models. The models that are 100.00 MiB in size were used in the paper https://arxiv.org/abs/2010.12650, and they all follow the same general usage pattern. All of these are recurrent deep clustering models: https://arxiv.org/pdf/1508.04306.pdf

The following code should load the model.

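from nussl.ml import SeparationModel
import torch

# Load the saved checkpoint from disk
checkpoint = torch.load("path/to/checkpoint.pth")
# Build the network from the configuration stored in the checkpoint
model = SeparationModel(checkpoint["config"])
# Copy the trained weights into the freshly built network
model.load_state_dict(checkpoint["state_dict"])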

You should now have a loaded model! Note that this model can separate STFTs with a window length of 512 samples.

Let me know if you have any other questions. Also, this page in the documentation will be useful in handling this model: https://nussl.github.io/docs/tutorials/training.html

EDIT: fixed code

@expectopatronum
Author

Hi!
Thanks a lot for your response and the link to your paper, I wasn't aware of it. I am looking forward to reading it!

Which nussl version should I be using? I am using 1.1.3 (I think that was the most current one when I downloaded the models 3 weeks ago), and I am getting the following error:

ValueError: Expected keys ['connections', 'modules', 'name', 'output'], got ['connections', 'modules', 'output']

in the line model = SeparationModel(checkpoint["config"]).

Thanks and best regards
Verena
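
The ValueError quoted above lists both the keys that were expected and the keys that were actually found in the checkpoint's config. A minimal sketch of how the two lists could be compared to see what differs is given below. It uses only the key lists copied from the error message. The helper name diff_config_keys is hypothetical and not part of nussl, and the dictionary is a stand-in for checkpoint["config"] with its real values left out.

```python
# Key lists copied verbatim from the ValueError reported in this thread.
EXPECTED_KEYS = ['connections', 'modules', 'name', 'output']
FOUND_KEYS = ['connections', 'modules', 'output']


def diff_config_keys(config, expected=EXPECTED_KEYS):
    """Hypothetical helper (not a nussl function).

    Returns two sorted lists: keys that are expected but absent from
    `config`, and keys present in `config` that are not expected.
    """
    found = set(config)
    wanted = set(expected)
    return sorted(wanted - found), sorted(found - wanted)


# Stand-in for checkpoint["config"]; the real values are not known here.
config = {key: None for key in FOUND_KEYS}

missing, unexpected = diff_config_keys(config)
print("missing:", missing)        # missing: ['name']
print("unexpected:", unexpected)  # unexpected: []
```

With a real checkpoint, the same helper could be called on checkpoint["config"] before constructing the model. This only narrows down the mismatch. Whether it comes from a difference between the nussl version that saved the checkpoint and the installed one is not confirmed in this thread.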
