
Can't run the deep learning based source separation model #230

Open
HCIMaker opened this issue Jan 18, 2022 · 3 comments

Describe the bug
When I use the deep mask estimation model for source separation, I get the following error at this step:
separator = nussl.separation.deep.DeepMaskEstimation(
audio_signal, mask_type='soft', model_path=model_path)

ERROR:
Error(s) in loading state_dict for SeparationModel:
Missing key(s) in state_dict: "normalization.batch_norm.weight", "normalization.batch_norm.bias", "normalization.batch_norm.running_mean", "normalization.batch_norm.running_var", "recurrent_stack.rnn.weight_ih_l0", "recurrent_stack.rnn.weight_hh_l0", "recurrent_stack.rnn.bias_ih_l0", "recurrent_stack.rnn.bias_hh_l0", "recurrent_stack.rnn.weight_ih_l0_reverse", "recurrent_stack.rnn.weight_hh_l0_reverse", "recurrent_stack.rnn.bias_ih_l0_reverse", "recurrent_stack.rnn.bias_hh_l0_reverse", "recurrent_stack.rnn.weight_ih_l1", "recurrent_stack.rnn.weight_hh_l1", "recurrent_stack.rnn.bias_ih_l1", "recurrent_stack.rnn.bias_hh_l1", "recurrent_stack.rnn.weight_ih_l1_reverse", "recurrent_stack.rnn.weight_hh_l1_reverse", "recurrent_stack.rnn.bias_ih_l1_reverse", "recurrent_stack.rnn.bias_hh_l1_reverse", "recurrent_stack.rnn.weight_ih_l2", "recurrent_stack.rnn.weight_hh_l2", "recurrent_stack.rnn.bias_ih_l2", "recurrent_stack.rnn.bias_hh_l2", "recurrent_stack.rnn.weight_ih_l2_reverse", "recurrent_stack.rnn.weight_hh_l2_reverse", "recurrent_stack.rnn.bias_ih_l2_reverse", "recurrent_stack.rnn.bias_hh_l2_reverse", "recurrent_stack.rnn.weight_ih_l3", "recurrent_stack.rnn.weight_hh_l3", "recurrent_stack.rnn.bias_ih_l3", "recurrent_stack.rnn.bias_hh_l3", "recurrent_stack.rnn.weight_ih_l3_reverse", "recurrent_stack.rnn.weight_hh_l3_reverse", "recurrent_stack.rnn.bias_ih_l3_reverse", "recurrent_stack.rnn.bias_hh_l3_reverse", "mask.linear.weight", "mask.linear.bias".

Expected behavior
According to the tutorial, it should produce the separated audio sources.

HCIMaker added the bug label on Jan 18, 2022
ethman (Collaborator) commented Jan 18, 2022

Have you trained a model? In the code snippet you provided:

separator = nussl.separation.deep.DeepMaskEstimation(audio_signal, mask_type='soft', model_path=model_path)

it is assumed that you have a trained model at the path specified by the variable model_path. It looks like whatever file is at that path is corrupted somehow. As the error message says, PyTorch is looking for specific parts of the model (the batch normalization layer, the RNN layers, and the mask layer) but can't find them in the file you provided.
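
A minimal diagnostic sketch, assuming the checkpoint is an ordinary torch pickle (model_path below stands for wherever the downloaded checkpoint file lives), is to load the file and list what it actually contains, so the names can be compared against the "Missing key(s)" in the error:

import torch

model_path = 'mask-inference-wsj2mix-model-v1.pth'  # wherever the downloaded checkpoint lives
ckpt = torch.load(model_path, map_location='cpu')

# A checkpoint is usually a dict; if it nests a 'state_dict', list its
# parameter names so they can be compared against the missing keys above.
if isinstance(ckpt, dict):
    print('top-level keys:', list(ckpt.keys()))
    state_dict = ckpt.get('state_dict', ckpt)
else:
    state_dict = ckpt
if isinstance(state_dict, dict):
    for name in list(state_dict.keys())[:10]:
        print(name)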

HCIMaker (Author) commented

Hi Ethan,
I'm sorry, the complete code didn't get attached. Here is my complete code (almost the same as https://nussl.github.io/docs/examples/deep/deep_mask_estimation.html):

import nussl
import torch
import matplotlib.pyplot as plt
import time
import warnings

warnings.filterwarnings("ignore")
start_time = time.time()

def visualize_and_embed(sources):
    plt.figure(figsize=(10, 6))
    plt.subplot(211)
    nussl.utils.visualize_sources_as_masks(sources,
        y_axis='linear', db_cutoff=-40, alpha_amount=2.0)
    plt.subplot(212)
    nussl.utils.visualize_sources_as_waveform(
        sources, show_legend=False)
    plt.show()
    nussl.play_utils.multitrack(sources)

model_path = nussl.efz_utils.download_trained_model(
    'mask-inference-wsj2mix-model-v1.pth')
audio_path = nussl.efz_utils.download_audio_file(
    'wsj_speech_mixture_ViCfBJj.mp3')
audio_signal = nussl.AudioSignal(audio_path)

saved_model = torch.load(model_path, map_location=torch.device('cpu'))
saved_model["nussl_version"] = "1.1.9" # Due to the nussl_version keyword missing error
torch.save(saved_model, "C:/Users/USERNAME/.nussl/models/mask-inference-wsj2mix-model-v1.pth")

separator = nussl.separation.deep.DeepMaskEstimation(
    audio_signal, mask_type='soft', model_path=model_path)
estimates = separator()

estimates = {
    f'Speaker {i}': e for i, e in enumerate(estimates)
}

visualize_and_embed(estimates)
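
A small sanity check here, assuming nussl exposes a __version__ attribute (not shown in this thread), is to compare the installed nussl version with whatever version the downloaded checkpoint records before patching it by hand:

import torch
import nussl

# Sanity-check sketch: compare the installed nussl version with whatever
# version, if any, the downloaded checkpoint records. nussl.__version__ is
# assumed here; it is not shown anywhere in this thread.
model_path = nussl.efz_utils.download_trained_model(
    'mask-inference-wsj2mix-model-v1.pth')
ckpt = torch.load(model_path, map_location=torch.device('cpu'))
print('installed nussl:', nussl.__version__)
if isinstance(ckpt, dict):
    print('checkpoint nussl_version:', ckpt.get('nussl_version', '<not recorded>'))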

Is the downloaded model pre-trained, or do I have to train it myself? Thanks!

ethman (Collaborator) commented Jan 18, 2022

Your problem is here:

saved_model["nussl_version"] = "1.1.9" # Due to the nussl_version keyword missing error

You're trying to use an old model with a new version of nussl, and that won't work. That's what the error was trying to tell you before you added that line: the line only bypasses the original error message, and the old model still won't load with a newer version of nussl.

Unfortunately, we don't have any pre-trained models available at the moment. If you want to use one with nussl, you will have to train your own. Alternatively, there are pre-trained models available from Asteroid. We don't currently have the bandwidth to provide a nice suite of available models. Sorry, and best of luck!
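
As a rough illustration of that alternative, loading an Asteroid pre-trained model might look like the sketch below; the model identifier is an assumed example from the Hugging Face hub, not something verified in this thread, so check Asteroid's model list for one that fits your use case.

import torch
from asteroid.models import BaseModel

# Sketch: pull a pre-trained separation model via Asteroid. The model ID is an
# assumed example; browse the Asteroid / Hugging Face model hub for real options.
model = BaseModel.from_pretrained('mpariente/ConvTasNet_WHAM!_sepclean')
model.eval()

# Asteroid models are plain torch nn.Modules mapping a waveform tensor of shape
# (batch, time) to estimated sources of shape (batch, n_src, time).
mixture = torch.randn(1, 8000)  # one second of placeholder audio at 8 kHz
with torch.no_grad():
    est_sources = model(mixture)
print(est_sources.shape)  # e.g. torch.Size([1, 2, 8000])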
