Help to speed up code #20932
Hello,
I cannot simply `vmap` (and the same goes for `scan`) over my components, because `samples_per_component` is different for each component, and therefore so is the resulting shape, e.g.,
I would really appreciate any help! Regards
Hi, thanks for the question! Unfortunately, this represents somewhat of a worst case when it comes to efficient computation with JAX: you're looping over dispatches of many small functions, each of which operates on differently-sized arrays. As you discovered, `vmap` and `scan` won't help in this case. Further, `jit` won't be much help either, because it will re-compile the kernel each time it encounters a differently-sized input.

A common strategy in a case like this is to re-express each step in terms of operations over uniformly-shaped inputs (with padding). This looks like it may be possible in your case, because you're essentially just sampling from the input arrays at each step. You'd have to modify how …

If rewriting in terms of padded operations is not possible, then your current approach is probably the best available.

Finally, a side note: the way you are splitting your seeds might lead to correlation between the steps, because you are using … Deriving a separate key for each step with `jax.random.fold_in` looks like this:

```python
import jax

for i in range(gmm_state.num_components):
    if samples_per_component[i] == 0:
        continue  # nothing to draw for this component
    # Derive a distinct key for step i from the base key `seed`.
    seed_i = jax.random.fold_in(seed, i)
    samples.append(sample_from_component_fn(gmm_state, i, samples_per_component[i], seed_i))
```