JAX scan loop not compiling when having a single tile_map call #14

Open
balancap opened this issue Aug 22, 2023 · 0 comments
Labels: bug (Something isn't working)

balancap commented Aug 22, 2023

There is a bug where a JAX scan loop fails to compile when its body contains a single tile_map call: the call seems to be forwarded to the CPU XLA backend instead of the IPU one.

Minimal reproducer:

```python
from functools import partial

import jax
import numpy as np

from tessellate_ipu import tile_map, tile_put_sharded
# Imported in the original report; not used below.
from tessellate_ipu.lax import sqrt_inplace_p

data = np.array([1, 2, 3], np.float32)
# Tile indices to shard `data` across (one per element).
tiles = (1, 2, 5)

def inner_scan(carry, _):
    # Single tile_map call in the loop body: this is what triggers the bug.
    return tile_map(jax.lax.sqrt_p, carry), None

@partial(jax.jit, backend="ipu")
def compute_fn(input):
    x = tile_put_sharded(input, tiles)
    return jax.lax.scan(inner_scan, x, None, length=4)

output = compute_fn(data)
```
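
For comparison, here is a minimal sketch of the same loop structure in plain JAX, which may help isolate the problem to the tile_map path. It is not part of the original report and does not reproduce the bug. It assumes the default backend (no IPU, no tessellate_ipu), swaps `jnp.sqrt` in for `tile_map(jax.lax.sqrt_p, ...)`, and uses the hypothetical name `baseline_fn`.

```python
import jax
import jax.numpy as jnp
import numpy as np

data = np.array([1, 2, 3], np.float32)

def inner_scan(carry, _):
    # Same loop body shape as the reproducer, with jnp.sqrt standing in
    # for the tile_map call (an assumption made for this sketch).
    return jnp.sqrt(carry), None

@jax.jit
def baseline_fn(x):
    # Runs on the default backend; no tile sharding involved.
    return jax.lax.scan(inner_scan, x, None, length=4)

final, _ = baseline_fn(data)

# Four successive square roots are equivalent to x ** (1 / 16).
np.testing.assert_allclose(final, data ** (1.0 / 16.0), rtol=1e-5)

# The jaxpr lists the scan equation and the primitives traced in its body.
jaxpr = jax.make_jaxpr(baseline_fn)(data)
print(jaxpr)
```

If this baseline compiles and runs while the reproducer does not, the scan machinery itself is probably fine and the problem is more likely in how the single tile_map call is lowered inside the loop body.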
balancap added the bug label on Aug 22, 2023