argfirst support for bool array or int array with (0, 1) values #16638
Comments
Thanks for the question! I don't think this would be feasible, because I don't believe XLA has any way to efficiently express a short-circuiting reduction operation. By construction, XLA:Reduce requires the reduction to be monoidal – this is because it does not simply scan linearly over the values, but rather does a sort of map-reduction in parallel under the hood. Given this, even if an early-return reduction were expressible in the IR, I don't think that ending the reduction early would improve your performance much in practice. |
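To illustrate the point about monoidal reductions: a first-true index can still be written as an ordinary associative reduction, for example a row-wise minimum over candidate column indices. The following is only a minimal sketch; the name `argfirst_min` and the sentinel convention (returning the number of columns for an all-false row) are assumptions for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp


@jax.jit
def argfirst_min(conds):
    """Index of the first nonzero entry in each row (hypothetical helper).

    Returns n_cols for rows that contain no nonzero entry.
    """
    n_cols = conds.shape[1]
    col_idx = jnp.arange(n_cols, dtype=jnp.int32)
    # Map each true entry to its column index and each false entry to the
    # sentinel n_cols; the row-wise minimum is then the first true column.
    candidates = jnp.where(conds != 0, col_idx, n_cols)
    return jnp.min(candidates, axis=1)


conds = jnp.array([[0, 1, 1], [0, 0, 0], [1, 0, 1]], dtype=jnp.int8)
first = argfirst_min(conds)
```

Because `min` is associative and commutative, the compiler is free to evaluate it in parallel in any order. This still reads every element, so it does not stop early.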
Hi @jakevdp, thanks for your reply. Is it possible to implement this operator using a |
I would be surprised if a hypothetical short-cutting That said, if you know your array only has 0/1 values I wouldn't be surprised if you could use clever tricks to implement |
We are implementing a rule computing engine using jax, XLA and GPU/CPU. The computation contains thousands of casewhen/ifelse operators, and the biggest casewhen has 800 branches. For example, bucketing a continuous numeric array into a discrete array based on very complex conditions composed of other array computations. With our abstraction, we end up lowering all casewhen and other numeric computations and fusing them into a jax/xla graph. The speedup is huge, but casewhen is still a bottleneck. All casewhen conditions are bool tensors, which can be taken as 0/1 values. Another case is We did try using matrix multiplication, such as:

import functools
from typing import Union, List, Tuple

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=1)
def _bool_tensors_nanargmax_by_malmul(
        inputs: Union[List, Tuple, jax.Array], row_all_values_false_possible=True):
    input_is_matrix = isinstance(inputs, jax.Array)
    if input_is_matrix:
        num_conds = inputs.shape[1]
        matrix = inputs
    else:
        num_conds = len(inputs)
        # Calling column_stack inside the jitted function makes the fused
        # execution more efficient, by up to 6x.
        matrix = jnp.column_stack(inputs)
    # Weight column j by 2**(num_conds - 1 - j), so the first true column dominates.
    radix_tensor = jnp.power(2, jnp.arange(num_conds - 1, -1, -1))
    matrix = jnp.nan_to_num(matrix, nan=0)
    prod = matrix @ radix_tensor
    # Since the results are close to xxx.999, no 0.5 rounding occurs.
    index_float = jnp.round(jnp.log2(prod))
    if row_all_values_false_possible:
        # If all values in a row are 0, the log2 of the product is -inf.
        index_float = jnp.where(jnp.isinf(index_float), num_conds - 1, index_float)
    index = jnp.abs(index_float - (num_conds - 1))
    return index


inputs = [(jnp.arange(100000) > 50000).astype(jnp.int8) for _ in range(4)]
%timeit _bool_tensors_nanargmax_by_malmul(inputs)


@functools.partial(jax.jit, static_argnums=1)
def bool_tensors_nanargmax(inputs: Union[List, Tuple], row_all_values_false_possible=True):
    num_conds = len(inputs)
    matrix = jnp.column_stack(inputs)
    return jnp.nanargmax(matrix, axis=1)


%timeit bool_tensors_nanargmax(inputs)

It's 4 times faster, but the float precision is not reliable; the result can be wrong sometimes. |
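One way to keep the bit-packing idea while avoiding floating point entirely is to pack each row into an unsigned integer and count leading zeros with `jax.lax.clz`. The sketch below is only an illustration under stated assumptions: it handles at most 32 conditions per row, expects inputs without NaN, and the name `argfirst_clz` and the all-false sentinel (the number of columns) are hypothetical choices, not something proposed in this thread. It has not been benchmarked against the versions above.

```python
import jax
import jax.numpy as jnp


@jax.jit
def argfirst_clz(conds):
    """First nonzero column per row via integer bit packing (hypothetical helper).

    Assumes conds has shape (n_rows, n_cols) with n_cols <= 32 and no NaN.
    Returns n_cols for rows that contain no nonzero entry.
    """
    n_cols = conds.shape[1]
    assert n_cols <= 32, "this sketch packs one row into a single uint32"
    bits = (conds != 0).astype(jnp.uint32)
    # Column 0 goes to the most significant bit, column 1 to the next, etc.
    shifts = jnp.uint32(31) - jnp.arange(n_cols, dtype=jnp.uint32)
    # Each column occupies a distinct bit, so the sum acts as a bitwise OR.
    packed = jnp.sum(jnp.left_shift(bits, shifts), axis=1, dtype=jnp.uint32)
    # The number of leading zeros equals the index of the first set column.
    first = jax.lax.clz(packed).astype(jnp.int32)
    # An all-zero row yields clz == 32; clamp it to the sentinel n_cols.
    return jnp.minimum(first, n_cols)


conds = jnp.array(
    [[0, 1, 1], [0, 0, 0], [1, 0, 1], [0, 0, 1]], dtype=jnp.int8)
first_true = argfirst_clz(conds)
```

Wider rows would need several packed words per row, which this sketch does not attempt.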
I'm using jax to implement a case-when operator; the implementation uses jax.numpy.argmax to find the first true condition:

It works, but as the number of condition branches grows, the performance starts to degrade. Is there any possibility of implementing a new argfirst operation in jax using the lax API, which stops iterating when it finds the first True/1 in a condition row?

I took a look at the jax code, and it seems argmax is implemented using lax.reduce:

Is it possible to implement argfirst with early stopping using a similar mechanism?
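For readers without the original snippet, an argmax-based case-when along the lines described in the question might look roughly like this. It is a minimal sketch, not the poster's code: the function name `case_when`, its argument layout, and the explicit `default` for rows where no condition holds are all assumptions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def case_when(conds, values, default):
    """Pick, per row, the value of the first true condition (hypothetical helper).

    conds:  (n_rows, n_branches) boolean array
    values: (n_rows, n_branches) array of branch results
    """
    # argmax returns the first maximal entry, i.e. the first True per row.
    first = jnp.argmax(conds.astype(jnp.int8), axis=1)
    picked = jnp.take_along_axis(values, first[:, None], axis=1)[:, 0]
    # argmax yields 0 for an all-false row, so handle that case explicitly.
    return jnp.where(conds.any(axis=1), picked, default)


x = jnp.array([0.5, 2.5, 7.0, -1.0])
conds = jnp.column_stack([x > 5.0, x > 2.0, x > 0.0])
values = jnp.broadcast_to(jnp.array([3.0, 2.0, 1.0]), conds.shape)
buckets = case_when(conds, values, 0.0)
```

Here each row is bucketed by the first threshold it exceeds, and rows matching no condition fall back to the default.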