Blocked at implementing torch port of mx/distributions/mixture.py #3160

Open
pbruneau opened this issue Apr 10, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@pbruneau

Description

I am trying to port gluonts.mx.distributions.mixture.py to PyTorch, with the aim of using the ported MixtureDistributionOutput with NormalOutput components in DeepAR. My current implementation is here: https://gist.github.com/pbruneau/3c9d62f694c50ead8da7adf50014d13a

I focused on implementing the methods that looked essential in the context of DeepAR. My debugging sessions (a mixture of 3 NormalOutput components, on private data with 3 dynamic real features) suggest that it works correctly, i.e., the parameters associated with the 3 components seem to fit independently. In terms of performance, however, on a (private) benchmark where the MXNet version of DeepAR gets a significant boost from 3 Gaussian output components versus a single component, my PyTorch version is only on par with a single Gaussian. In other words, I currently do no better than using NormalOutput() as my output distribution.

I'm at a loss as to where to investigate next. Debugging the MXNet and PyTorch versions of DeepAR side by side is far from trivial, as they operate quite differently. I would welcome any suggestion: an important feature or method I am missing, a blatant mistake, where to look next, ways to debug the two versions side by side, and/or an appropriate public benchmark. If I can get this working, I would be more than happy to open a pull request for it.
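For comparing the two ports, one option is a reference value for the mixture log-density that depends on neither framework. The block below is a minimal sketch in plain Python, not taken from the gist or from GluonTS: the function names (`normal_log_pdf`, `mixture_log_pdf`, `responsibilities`) are hypothetical, and it assumes scalar observations and strictly positive mixture weights that sum to 1. It evaluates log sum_k w_k N(x | mu_k, sigma_k) with the log-sum-exp trick, so the same numbers could be fed to either implementation's log-probability method and checked against it.

```python
import math


def normal_log_pdf(x, mu, sigma):
    """Log-density of a univariate Gaussian N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def _weighted_log_terms(x, weights, mus, sigmas):
    # log(w_k) + log N(x | mu_k, sigma_k) for each component k.
    # Assumption: every weight is strictly positive (log(0) is undefined).
    return [
        math.log(w) + normal_log_pdf(x, m, s)
        for w, m, s in zip(weights, mus, sigmas)
    ]


def mixture_log_pdf(x, weights, mus, sigmas):
    """Log-density of a Gaussian mixture at x, computed via log-sum-exp."""
    terms = _weighted_log_terms(x, weights, mus, sigmas)
    top = max(terms)  # subtract the max before exponentiating, for stability
    return top + math.log(sum(math.exp(t - top) for t in terms))


def responsibilities(x, weights, mus, sigmas):
    """Posterior probability of each component given the observation x."""
    terms = _weighted_log_terms(x, weights, mus, sigmas)
    log_mix = mixture_log_pdf(x, weights, mus, sigmas)
    return [math.exp(t - log_mix) for t in terms]


if __name__ == "__main__":
    # Illustrative values only; they do not come from the benchmark above.
    weights = [0.2, 0.3, 0.5]
    mus = [-2.0, 0.0, 3.0]
    sigmas = [0.5, 1.0, 2.0]
    for x in (-2.0, 0.0, 3.0):
        print(x, mixture_log_pdf(x, weights, mus, sigmas),
              responsibilities(x, weights, mus, sigmas))
```

Two sanity checks follow directly from the definition: a one-component mixture must match the plain Gaussian log-density, and a mixture of identical components must match it regardless of the weights. If a ported loss fails either check on the same inputs, the discrepancy is in the density computation rather than in the training loop.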

@pbruneau pbruneau added the enhancement New feature or request label Apr 10, 2024
@pbruneau pbruneau changed the title from "Blocked at implementing torch port of mx.distributions.mixture.py" to "Blocked at implementing torch port of mx/distributions/mixture.py" Apr 10, 2024
@kashif
Contributor

kashif commented Apr 22, 2024
