Challenges porting Gaussian splatting to JAX #19124
Replies: 3 comments
-
Thanks for the question! On the first point: if I understand correctly, JAX does not provide an equivalent API. All arrays in JAX must have static shapes. In practice, we've found that most use cases that seem to require dynamic shapes can, with some thought, be re-expressed in terms of static shapes; perhaps this one can as well. Regarding variables stored in the runtime context for the backward pass: this is handled as part of the vector-Jacobian product (VJP) code. For example, you can see the documentation of … Does that help unblock you?
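As a rough illustration of re-expressing a dynamic shape statically, here is a minimal sketch of one common pattern (not something proposed in this thread): choose a fixed `capacity` up front and fill a buffer of that size plus a validity mask, instead of allocating an array whose length depends on the data. The function name `expand_to_capacity`, its `counts` input (e.g. the number of tiles each Gaussian touches), and `capacity` are all hypothetical. The prefix-sum idea is loosely analogous to how a tile-based rasterizer assigns output offsets per Gaussian, but this is not a port of the CUDA code.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="capacity")
def expand_to_capacity(counts, capacity):
    """Map each slot of a fixed-size buffer to the item that owns it.

    counts:   int array [N], number of entries item i wants to emit
              (hypothetical input, e.g. tiles touched by Gaussian i).
    capacity: static Python int, the buffer size chosen ahead of time.
    """
    ends = jnp.cumsum(counts)        # inclusive prefix sum: end offset of each item
    starts = ends - counts           # exclusive prefix sum: start offset of each item
    total = ends[-1]                 # entries actually needed (data-dependent value, static shape)
    slot = jnp.arange(capacity)      # static-length index over the buffer
    # First i with ends[i] > slot, i.e. the item whose range contains this slot.
    owner = jnp.searchsorted(ends, slot, side="right")
    valid = slot < total             # slots at or beyond `total` are padding
    owner = jnp.where(valid, owner, -1)
    # Position of the slot within its owner's range (e.g. which of its tiles).
    local = jnp.where(valid, slot - starts[jnp.maximum(owner, 0)], -1)
    return owner, local, valid, total


owner, local, valid, total = expand_to_capacity(
    jnp.array([2, 0, 1], dtype=jnp.int32), capacity=5
)
```

If `total` exceeds `capacity`, the extra entries are simply not represented; one option is to check `int(total) <= capacity` outside of `jit` and rerun with a larger capacity, which costs a recompile.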
-
Thanks! That reverse-mode tip covers storing the intermediates. I'll have to see whether I can replicate this memory precomputation (https://github.com/graphdeco-inria/diff-gaussian-rasterization/blob/main/cuda_rasterizer/rasterizer_impl.cu#L155); it's rather tricky, but it should be deterministic. So, as I understand it, once I've computed that, I should initialize the array in the lowering function and pass it to the custom call?
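To make the residual mechanism concrete, here is a minimal sketch using `jax.custom_vjp`; the op `scaled_exp` is a hypothetical toy stand-in, not the rasterizer. The second value returned by the `fwd` function is the set of residuals that JAX passes to `bwd`, which plays roughly the role of `ctx.save_for_backward` in a PyTorch `autograd.Function`. In a real port, `fwd` would presumably call the forward custom call and return its intermediate buffers as residuals, and `bwd` would hand them to the backward custom call; that wiring is an assumption and is not shown here.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scaled_exp(x, s):
    # Hypothetical stand-in for the forward kernel.
    return s * jnp.exp(x)


def scaled_exp_fwd(x, s):
    e = jnp.exp(x)  # intermediate the backward pass will reuse
    # Second return value = residuals (analogous to ctx.save_for_backward).
    return s * e, (e, s)


def scaled_exp_bwd(residuals, g):
    e, s = residuals
    # One cotangent per primal input, matching the shapes of (x, s):
    # d(s*e^x)/dx = s*e^x and d(s*e^x)/ds = e^x.
    return g * s * e, g * e


scaled_exp.defvjp(scaled_exp_fwd, scaled_exp_bwd)

x = jnp.array([0.0, 1.0])
s = jnp.array([2.0, 3.0])
gx, gs = jax.grad(lambda x, s: scaled_exp(x, s).sum(), argnums=(0, 1))(x, s)
```

Because residuals are ordinary arrays, anything the forward pass computes (including scratch buffers returned from a custom call) can be carried to the backward pass this way without recomputation.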
-
@peabody124 Do you have any progress on this that you can share? I'm very interested in bringing Gaussian splatting to JAX.
-
I love JAX and Equinox, so I want to use them for Gaussian splatting. I've been trying to port the CUDA code for it over to JAX. I've got many of the pieces connected, but I'm running into two key barriers. I suspect they can be overcome, but I'm not finding documentation on what I need:
Any suggestions or pointers?