Ridding _check_2d #90

Open
lxuechen opened this issue Feb 7, 2021 · 2 comments

@lxuechen
Collaborator

lxuechen commented Feb 7, 2021

I'm proposing to get rid of the 2d shape checks. These were added in #88 as part of v0.2.4.

These checks are creating a huge barrier for applications that don't have vectorized data naturally. Flattening and unflattening will hurt efficiency tremendously.

@patrick-kidger
Collaborator

patrick-kidger commented Feb 7, 2021

The reason they're there is that several parts of the internal code -- for example misc.batch_mvp and ForwardSDE.dg_ga_jvp -- implicitly assume there is a single batch dimension.
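To illustrate the kind of assumption meant here, below is a minimal sketch of a batched matrix-vector product. It is not the library's actual `misc.batch_mvp`; both function names are hypothetical. The first version only accepts a single batch dimension because `torch.bmm` requires 3-D inputs, while the second uses `torch.matmul`, which broadcasts over any number of leading batch dimensions.

```python
import torch


def mvp_single_batch(m, v):
    # Hypothetical helper, assuming m: (batch, d, k) and v: (batch, k).
    # torch.bmm only accepts 3-D tensors, so exactly one batch dim is allowed.
    return torch.bmm(m, v.unsqueeze(-1)).squeeze(-1)


def mvp_any_batch(m, v):
    # Hypothetical helper, assuming m: (*batch, d, k) and v: (*batch, k).
    # torch.matmul treats all leading dims as batch dims.
    return torch.matmul(m, v.unsqueeze(-1)).squeeze(-1)


m = torch.randn(4, 5, 3, 2)  # two batch dims: (4, 5)
v = torch.randn(4, 5, 2)

out = mvp_any_batch(m, v)  # shape (4, 5, 3)

# The single-batch version rejects the same inputs...
try:
    mvp_single_batch(m, v)
    bmm_failed = False
except RuntimeError:
    bmm_failed = True

# ...unless the batch dims are flattened first and restored afterwards.
out_flat = mvp_single_batch(m.reshape(20, 3, 2), v.reshape(20, 2)).reshape(4, 5, 3)
```

Code written in the first style is what forces callers to flatten to 2d before calling in.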

If we can go through and sort those out then I am in favour of this; I agree this is a wart. Ideally we'd be able to have y0 take an arbitrary shape.

Off the top of my head I think the only parts of the code that need to distinguish batch dimensions from channel dimensions are the creation of a default Brownian motion (needing one sample per batch but not one per channel), and ForwardSDE.dg_ga_jvp, so those would need some way of specifying that detail.

In passing, why is flattening/unflattening hurting efficiency? It should be doable just by re-striding the tensor, which is cheap.
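A small sketch of the re-striding point, assuming the last dimension is the channel dimension and everything before it is batch (the helper names are hypothetical). For a contiguous tensor, `reshape` returns a view over the same storage, so flattening and unflattening copy nothing. For a non-contiguous tensor, such as a transposed one, `reshape` may have to copy, which is one case where the cost is real.

```python
import torch


def flatten_batch(y):
    # Hypothetical helper: collapse all leading (batch) dims into one.
    batch_shape = y.shape[:-1]
    return y.reshape(-1, y.shape[-1]), batch_shape


def unflatten_batch(y2d, batch_shape):
    # Hypothetical helper: restore the original leading dims.
    return y2d.reshape(*batch_shape, y2d.shape[-1])


y = torch.randn(4, 5, 3)  # contiguous, batch dims (4, 5), 3 channels
y2d, batch_shape = flatten_batch(y)
y_back = unflatten_batch(y2d, batch_shape)

# Same underlying memory: only shape/stride metadata changed.
shares_storage = y2d.data_ptr() == y.data_ptr() == y_back.data_ptr()

# Non-contiguous input: the two batch dims cannot be merged by re-striding,
# so reshape falls back to copying the data.
yt = y.transpose(0, 1)  # shape (5, 4, 3), strides no longer mergeable
yt2d, _ = flatten_batch(yt)
copied = yt2d.data_ptr() != yt.data_ptr()
```

Using `view` instead of `reshape` would raise an error in the non-contiguous case rather than copying silently, which can be useful for spotting where copies would occur.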

@lxuechen
Collaborator Author

lxuechen commented Feb 8, 2021

I am completely aware of why they are needed. I will come up with a design doc next weekend.
