Assume I am given a 2-d array of data x and two 1-d index arrays i and j containing the index of a column in each row, all sharded along the leading axis.
I would like to update the element given by i with the element given by j in each row of x.
I implemented this using straightforward NumPy-style array indexing (f1 and f2).
When compiled, this performs an all-gather of the data and then indexes into it.
Is there a good way to avoid the unnecessary collective communication without resorting to masks (f3) or shard_map (f4)?
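To make the intended update concrete, here is a minimal dependency-free sketch of the per-row semantics on nested lists. The helper name `rowwise_update` is hypothetical (it is not part of JAX or NumPy), and the small input values are made up for illustration only.

```python
# Plain-Python reference for the row-wise update described above.
# `rowwise_update` is a hypothetical helper name, not a JAX or NumPy API.

def rowwise_update(x, i, j):
    """Return a copy of x where, in each row r, x[r][i[r]] is set to x[r][j[r]]."""
    out = [list(row) for row in x]  # copy so the input is left untouched
    for r, (dst, src) in enumerate(zip(i, j)):
        out[r][dst] = x[r][src]
    return out


# Made-up illustrative data: two rows, three columns.
x = [[0.1, 0.2, 0.3],
     [0.4, 0.5, 0.6]]
i = [0, 2]  # column to overwrite in each row
j = [1, 0]  # column to copy from in each row

print(rowwise_update(x, i, j))  # [[0.2, 0.2, 0.3], [0.4, 0.5, 0.4]]
```

Each row only reads and writes its own elements, which is why no cross-device communication should be needed in principle when everything is sharded along the leading axis. Passing `.tolist()` versions of the arrays to this function would be one way to sanity-check the outputs of the JAX variants.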
Example:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

pos_sharding = PositionalSharding(jax.devices())
mesh = Mesh(jax.devices(), axis_names='i')

x = np.random.uniform(size=(jax.device_count() * 2, 3))
i = np.random.randint(0, x.shape[1], len(x))
j = np.random.randint(0, x.shape[1], len(x))

x = jax.lax.with_sharding_constraint(x, pos_sharding.reshape(-1, 1))
i = jax.lax.with_sharding_constraint(i, pos_sharding)
j = jax.lax.with_sharding_constraint(j, pos_sharding)

@jax.jit
@jax.vmap
def f1(x, i, j):
    return x.at[i].set(x[j])

@jax.jit
def f2(x, i, j):
    a = jnp.arange(len(x))
    return x.at[a, i].set(x[a, j])

# masks
@jax.jit
@jax.vmap
def f3(x, i, j):
    a = jnp.arange(len(x))
    maski = i == a
    maskj = j == a
    return jnp.where(maski, x @ maskj, x)

# shard_map
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P('i'), P('i')), out_specs=P('i'))
def f4(x, i, j):
    return f1(x, i, j)

collective_ops = ['all-reduce', 'collective-permute', 'all-gather', 'all-to-all', 'reduce-scatter']

def any_in(ops, txt):
    return any(map(lambda x: x in txt, ops))

print(any_in(collective_ops, f1.lower(x, i, j).compile().as_text()))  # True
print(any_in(collective_ops, f2.lower(x, i, j).compile().as_text()))  # True
print(any_in(collective_ops, f3.lower(x, i, j).compile().as_text()))  # False
print(any_in(collective_ops, f4.lower(x, i, j).compile().as_text()))  # False
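The any_in check above only reports whether some collective is present. As a possible refinement, the sketch below lists which collective op names occur in a text dump, which may help confirm that the communication is specifically an all-gather. The helper name `collectives_in` is hypothetical, matching is plain substring search, and the sample string is a made-up stand-in rather than real compiler output.

```python
# Sketch: report which collective op names occur in a compiled-HLO text dump.
# `collectives_in` is a hypothetical helper; it uses plain substring matching.

COLLECTIVE_OPS = ('all-reduce', 'collective-permute', 'all-gather',
                  'all-to-all', 'reduce-scatter')


def collectives_in(hlo_text, ops=COLLECTIVE_OPS):
    """Return the op names from `ops` that appear anywhere in `hlo_text`."""
    return [op for op in ops if op in hlo_text]


# Made-up stand-in for compiler output, used only to exercise the helper.
sample = "%ag = f32[4,3] all-gather(f32[2,3] %p), dimensions={0}"
print(collectives_in(sample))       # ['all-gather']
print(collectives_in("add(a, b)"))  # []
```

With the functions from the example, it could be called as `collectives_in(f1.lower(x, i, j).compile().as_text())`.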