jaxsplat

A port of 3D Gaussian Splatting to JAX. Fully differentiable, CUDA accelerated.

Read documentation

Installation

Installation requires a working CUDA toolchain. Installing directly from source with pip should build and install jaxsplat:

$ python -m venv venv && . venv/bin/activate
$ pip install git+https://github.com/yklcs/jaxsplat

Usage

The primary function of this library is jaxsplat.render:

img = jaxsplat.render(
    means3d,   # jax.Array (N, 3)
    scales,    # jax.Array (N, 3)
    quats,     # jax.Array (N, 4) normalized
    colors,    # jax.Array (N, 3)
    opacities, # jax.Array (N, 1)
    viewmat=viewmat,         # jax.Array (4, 4)
    background=background,   # jax.Array (3,)
    img_shape=img_shape,     # tuple[int, int] = (H, W)
    f=f,                     # tuple[float, float] = (fx, fy)
    c=c,                     # tuple[int, int] = (cx, cy)
    glob_scale=glob_scale,   # float
    clip_thresh=clip_thresh, # float
    block_size=block_size,   # int <= 16
)

The rendered output is differentiable w.r.t. means3d, scales, quats, colors, and opacities.
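
As a rough illustration of what this differentiability allows, the sketch below builds randomly initialized inputs with the shapes listed above and defines an L2 loss over a rendered image. The helper names `make_gaussians` and `l2_loss` are hypothetical and not part of jaxsplat, and the value ranges are arbitrary. The renderer is passed in as `render_fn` so the snippet runs without CUDA; in practice it would be `jaxsplat.render`, and `target` is assumed to have the same shape as the rendered image.

```python
import jax
import jax.numpy as jnp


def make_gaussians(key, n):
    """Hypothetical helper: random Gaussian parameters with the documented shapes."""
    k_mean, k_scale, k_quat, k_color, k_opac = jax.random.split(key, 5)
    # Quaternions are expected to be normalized, so divide by their norm.
    quats = jax.random.normal(k_quat, (n, 4))
    quats = quats / jnp.linalg.norm(quats, axis=-1, keepdims=True)
    return {
        "means3d": jax.random.uniform(k_mean, (n, 3), minval=-1.0, maxval=1.0),
        "scales": jax.random.uniform(k_scale, (n, 3), minval=0.01, maxval=0.1),
        "quats": quats,
        "colors": jax.random.uniform(k_color, (n, 3)),
        "opacities": jax.random.uniform(k_opac, (n, 1)),
    }


def l2_loss(params, target, render_fn, **camera):
    """Hypothetical loss: mean squared error between the render and a target image."""
    img = render_fn(
        params["means3d"],
        params["scales"],
        params["quats"],
        params["colors"],
        params["opacities"],
        **camera,  # viewmat, background, img_shape, f, c, ...
    )
    return jnp.mean((img - target) ** 2)


params = make_gaussians(jax.random.PRNGKey(0), 128)

# With jaxsplat installed, gradients w.r.t. all five inputs could be taken as:
#   grads = jax.grad(l2_loss)(params, target, jaxsplat.render, viewmat=viewmat, ...)
```

`jax.grad` differentiates with respect to the first argument only, so `grads` would be a dictionary with the same keys and shapes as `params`, ready to pass to an optimizer.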

Alternatively, jaxsplat.project projects 3D Gaussians to 2D, and jaxsplat.rasterize sorts and rasterizes 2D Gaussians. jaxsplat.render successively calls jaxsplat.project and jaxsplat.rasterize under the hood.

Examples

See /examples for examples. These can be run as follows:

$ python -m venv venv && . venv/bin/activate
$ pip install -r examples/requirements.txt

# Train Gaussians on a single image
$ python -m examples.single_image input.png

Method

We use modified versions of gsplat's kernels. The original INRIA implementation uses a custom license and contains dynamically shaped tensors, which are harder to port to JAX/XLA.
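
To illustrate why dynamic shapes are awkward under JAX/XLA, here is a generic sketch that is unrelated to the actual kernels. Boolean-mask indexing produces an array whose length depends on the data, which works eagerly but cannot be traced by `jax.jit`. A static-shape formulation keeps all N entries and masks out the unwanted ones. The function names and the "visible depth" example are made up for illustration.

```python
import jax
import jax.numpy as jnp


def masked_mean_dynamic(depths, visible):
    # Dynamic shape: the length of depths[visible] depends on the data.
    return jnp.mean(depths[visible])


@jax.jit
def masked_mean_static(depths, visible):
    # Static shape: keep all N entries and zero out the masked ones.
    total = jnp.sum(jnp.where(visible, depths, 0.0))
    count = jnp.maximum(jnp.sum(visible), 1)  # avoid division by zero
    return total / count


depths = jnp.array([0.5, 2.0, -1.0, 4.0])
visible = depths > 0.0

eager_result = masked_mean_dynamic(depths, visible)  # fine outside jit
static_result = masked_mean_static(depths, visible)  # fine under jit

try:
    jax.jit(masked_mean_dynamic)(depths, visible)
    jit_failed = False
except jax.errors.NonConcreteBooleanIndexError:
    # Under jit the mask is a tracer, so the output shape is unknown.
    jit_failed = True
```

Both formulations give the same value here, but only the static one compiles.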