
RuntimeError: istft requires a complex-valued input tensor matching the output from stft with return_complex=True. #662

Open
lucyanddarlin opened this issue Apr 5, 2023 · 3 comments
Labels
question Further information is requested


@lucyanddarlin

lucyanddarlin commented Apr 5, 2023

I ran /asteroid-master/egs/musdb18/X-UMX/run.sh, but got the error: RuntimeError: istft requires a complex-valued input tensor matching the output from stft with return_complex=True.

I tried setting return_complex=True in x_umx.py:

    stft_f = torch.stft(
        x,
        n_fft=self.n_fft,
        hop_length=self.n_hop,
        window=self.window,
        center=self.center,
        normalized=False,
        onesided=True,
        pad_mode="reflect",
        return_complex=True,
    )

but it didn't work. Could someone tell me how to solve it? Thank you so much!
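Why the change above cannot help is visible in the traceback further down: the error is raised by torch.istft inside the decoder (x_umx.py, line 401), and the decoder is called as `self.decoder(spec, ang)` (line 169), i.e. it rebuilds its own input from a magnitude and a phase instead of reusing the encoder's STFT tensor. Changing return_complex in the encoder's torch.stft therefore does not change what reaches torch.istft, and it may also break any code that expects the old layout (an assumption; we have not checked every use in x_umx.py). The sketch below, using a toy signal and arbitrary sizes, shows the two layouts involved: the complex tensor returned with return_complex=True, and the old real layout with a trailing dimension of size 2.

```python
import torch

# Toy stereo signal, shape (channels, samples); values and sizes are arbitrary.
x = torch.randn(2, 8000)
n_fft, hop = 512, 128
window = torch.hann_window(n_fft)

# New-style output: complex tensor of shape (channels, freq_bins, frames).
spec_c = torch.stft(x, n_fft=n_fft, hop_length=hop, window=window,
                    center=True, return_complex=True)

# Old-style layout: real tensor whose last dimension (size 2) holds
# the (real, imag) pair. view_as_real returns a view, not a copy.
spec_r = torch.view_as_real(spec_c)

print(spec_c.shape, spec_c.dtype)
print(spec_r.shape, spec_r.dtype)
```

With center=True the signal is padded by n_fft // 2 on both sides, so there are 1 + 8000 // 128 = 63 frames and n_fft // 2 + 1 = 257 frequency bins.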

@lucyanddarlin lucyanddarlin added the question Further information is requested label Apr 5, 2023
@lucyanddarlin
Author

lucyanddarlin commented Apr 5, 2023

Here are the details:

❯ /bin/zsh /Volumes/noEntry/study/asteroid-master/egs/musdb18/X-UMX/run.sh
Results from the following experiment will be stored in exp/train_xumx_d727eb8a
Stage 1: Training
101it [00:00, 705.80it/s]
0it [00:00, ?it/s]train_dataset <asteroid.data.musdb18_dataset.MUSDB18Dataset object at 0x14ea54820>
101it [00:00, 27558.20it/s]
valid_dataset <asteroid.data.musdb18_dataset.MUSDB18Dataset object at 0x14ea54a30>
Compute dataset statistics:   0%|                                                                                                                                                                         | 0/86 [00:00<?, ?it/s]/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/functional.py:641: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.
Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/SpectralOps.cpp:867.)
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
Compute dataset statistics: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 86/86 [01:34<00:00,  1.09s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/setup.py:201: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.
  rank_zero_warn(
/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------


  | Name      | Type            | Params
----------------------------------------------
0 | model     | XUMX            | 35.6 M
1 | loss_func | MultiDomainLoss | 4.1 K 
----------------------------------------------
35.6 M    Trainable params
8.2 K     Non-trainable params
35.6 M    Total params
142.326   Total estimated model params size (MB)
Combination Loss: True
Multi Domain Loss: True, scaling parameter for time-domain loss=10.0
Sanity Checking: 0it [00:00, ?it/s]/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
  File "train.py", line 498, in <module>
    main(arg_dic, plain_args)
  File "train.py", line 465, in main
    trainer.fit(system)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 92, in launch
    return function(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 976, in _run_stage
    self._run_sanity_check()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1005, in _run_sanity_check
    val_loop.run()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 115, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 375, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 288, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 337, in validation_step
    return self.model(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1113, in _run_ddp_forward
    return module_to_run(*inputs, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 102, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "train.py", line 356, in validation_step
    loss_tmp += self.common_step(batch_tmp, batch_nb, train=False)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/asteroid/engine/system.py", line 101, in common_step
    est_targets = self(inputs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/asteroid/engine/system.py", line 73, in forward
    return self.model(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/asteroid/models/x_umx.py", line 169, in forward
    time_signals = self.decoder(spec, ang)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/asteroid/models/x_umx.py", line 401, in forward
    wav = torch.istft(
RuntimeError: istft requires a complex-valued input tensor matching the output from stft with return_complex=True.

@DavidDiazGuerra
Contributor

It seems that newer versions of PyTorch have changed torch.stft and torch.istft: torch.istft now only accepts a complex-valued tensor, while the X-UMX decoder passes it a real tensor whose last dimension holds the real and imaginary parts. I've just run into the same issue, and I think I fixed it by adding 'x = torch.view_as_complex(x)' just before the torch.istft call on the line that raises the error.
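A minimal sketch of that fix, assuming a decoder that, like the one in the traceback, receives a magnitude `spec` and a phase `ang`. The helper name `istft_from_mag_phase` and its parameters are hypothetical, not Asteroid's actual code; the point is only where torch.view_as_complex goes relative to torch.istft. The round trip at the end checks that STFT, then magnitude/phase, then waveform recovers the input.

```python
import torch


def istft_from_mag_phase(spec, ang, n_fft, hop_length, window,
                         center=True, length=None):
    """Hypothetical helper: rebuild a waveform from magnitude and phase.

    spec, ang: real tensors of shape (batch, freq_bins, frames).
    """
    # Old-style layout: real and imaginary parts stacked in a trailing dim of size 2.
    x = torch.stack([spec * torch.cos(ang), spec * torch.sin(ang)], dim=-1)
    # The suggested fix: reinterpret (..., 2) as a complex tensor, which newer
    # torch.istft requires. contiguous() guarantees the stride layout
    # view_as_complex needs.
    x = torch.view_as_complex(x.contiguous())
    return torch.istft(x, n_fft=n_fft, hop_length=hop_length, window=window,
                       center=center, length=length)


# Round trip on a toy signal: STFT -> magnitude/phase -> waveform.
signal = torch.randn(2, 8000)
n_fft, hop = 512, 128
window = torch.hann_window(n_fft)
spec_c = torch.stft(signal, n_fft=n_fft, hop_length=hop, window=window,
                    center=True, return_complex=True)
recon = istft_from_mag_phase(spec_c.abs(), spec_c.angle(), n_fft, hop, window,
                             length=signal.shape[-1])
print((recon - signal).abs().max())
```

Passing `length` makes torch.istft return exactly as many samples as the input; without it, the output here would be hop * (frames - 1) = 7936 samples long.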

By the way, you can also get rid of the deprecation warning you're getting by setting return_complex=True in the call to torch.stft and then calling stft_f = torch.view_as_real(stft_f) right after it.
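The encoder-side change can be sketched as follows, using the same torch.stft arguments as the snippet in the original report. `stft_real_layout` is a hypothetical helper, not Asteroid's API: it asks for a complex result, so the return_complex=False deprecation warning is not triggered, and then recovers the old real format with torch.view_as_real, as the PyTorch warning in the log itself suggests.

```python
import torch


def stft_real_layout(x, n_fft, hop_length, window, center=True):
    """Hypothetical helper: STFT in the old (..., freq, frames, 2) real layout."""
    stft_f = torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=center,
        normalized=False,
        onesided=True,
        pad_mode="reflect",
        return_complex=True,  # avoids the deprecated return_complex=False path
    )
    # Recover the old real-valued format: last dim holds (real, imag).
    return torch.view_as_real(stft_f)


x = torch.randn(2, 8000)
window = torch.hann_window(512)
spec = stft_real_layout(x, n_fft=512, hop_length=128, window=window)
print(spec.shape)  # torch.Size([2, 257, 63, 2])
```

Downstream code that indexes spec[..., 0] and spec[..., 1] keeps working unchanged, since view_as_real produces the same layout the old return_complex=False call returned.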

@r-sawata
Contributor

Thank you so much, @DavidDiazGuerra!

As he said, this is caused by a mismatch between old and new PyTorch versions. If PR #684 is accepted, this problem should be resolved.
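Since the root cause is a version mismatch, one way to keep an istft call site working regardless of the installed PyTorch version is to always hand torch.istft a complex tensor, which, to our knowledge, recent 1.x releases also accept. The shim below is a hypothetical sketch (`to_complex_spec` is not part of Asteroid or PyTorch, and this is not necessarily what PR #684 does): complex tensors pass through untouched, and the old real layout with a trailing dimension of size 2 is converted with torch.view_as_complex.

```python
import torch


def to_complex_spec(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical shim: return a complex STFT tensor for either input layout."""
    if torch.is_complex(x):
        return x  # already in the layout torch.istft expects
    if x.shape[-1] != 2:
        raise ValueError(
            f"expected a complex tensor or a trailing dim of size 2, got {tuple(x.shape)}"
        )
    # view_as_complex needs the last dim to have stride 1 (and even strides
    # elsewhere); contiguous() guarantees that for sliced or permuted inputs.
    return torch.view_as_complex(x.contiguous())


c = torch.complex(torch.randn(3, 4), torch.randn(3, 4))
r = torch.view_as_real(c)
print(torch.equal(to_complex_spec(r), c))
```

A call such as torch.istft(to_complex_spec(x), ...) then behaves the same whether the upstream code produced the old or the new layout.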


3 participants