
Help to speed up code #20932

Answered by jakevdp
JohannesEsslinger asked this question in Q&A
Apr 25, 2024 · 1 comment · 1 reply


Hi, thanks for the question! Unfortunately, this is close to a worst case for efficient computation with JAX: you're looping over many dispatches of small functions, each of which operates on differently sized arrays. As you discovered, vmap and scan won't help here, because both require every iteration to have the same array shapes. jit won't help much either, because it compiles a new version of the function for every new input shape it encounters.
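To make the recompilation cost concrete, here is a minimal sketch (not code from the thread; the function `normalize` and the counter `trace_count` are hypothetical). It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so the counter counts compilations rather than calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only at trace time

@jax.jit
def normalize(x):
    # This Python side effect runs when JAX traces the function for a new
    # input shape/dtype, not on cached calls, so it counts compilations.
    global trace_count
    trace_count += 1
    return x / jnp.sum(x)

# Three distinct shapes, each seen twice.
for n in [3, 4, 5, 3, 4, 5]:
    normalize(jnp.ones(n))

print(trace_count)  # 3: one trace (and compile) per distinct input shape
```

With a loop over many distinct sizes, each new size pays this tracing and compilation cost, which is usually far more expensive than the small computation itself.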

A common strategy in a case like this is to re-express each step in terms of operations over uniformly shaped (padded) inputs. This looks like it may be possible in your case, because you're essentially just sampling from the input arrays at each step. You'd hav…
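A minimal sketch of the padding idea, assuming each step amounts to drawing one element uniformly from a variable-length 1-D array (the original question's code isn't shown here, so `pad_to_uniform` and `sample_one_per_row` are hypothetical names). The arrays are padded to a common length, their true lengths are kept alongside, and a single jitted, vmapped function then handles all of them with one compilation.

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_to_uniform(arrays, fill=0.0):
    """Stack variable-length 1-D arrays into one (n, max_len) array plus their lengths."""
    max_len = max(len(a) for a in arrays)
    padded = np.full((len(arrays), max_len), fill, dtype=np.float32)
    for i, a in enumerate(arrays):
        padded[i, : len(a)] = a
    lengths = np.array([len(a) for a in arrays], dtype=np.int32)
    return jnp.asarray(padded), jnp.asarray(lengths)

@jax.jit
def sample_one_per_row(key, padded, lengths):
    """Draw one element uniformly from the valid (unpadded) part of each row."""
    keys = jax.random.split(key, padded.shape[0])

    def sample_row(k, row, n):
        # Draw an index in [0, n) so padding entries are never selected.
        idx = jax.random.randint(k, (), 0, n)
        return row[idx]

    return jax.vmap(sample_row)(keys, padded, lengths)

arrays = [np.array([1.0, 2.0, 3.0]), np.array([10.0]), np.array([5.0, 6.0])]
padded, lengths = pad_to_uniform(arrays)
samples = sample_one_per_row(jax.random.PRNGKey(0), padded, lengths)
print(samples)
```

Because `padded` always has the same shape for a given batch, `sample_one_per_row` compiles once instead of once per array size. The trade-off is some wasted work on padding entries, which is usually much cheaper than repeated compilation and per-call dispatch.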

@JohannesEsslinger

Answer selected by JohannesEsslinger