argfirst support for bool array or int array with (0, 1) values #16638

Open · 1 of 2 tasks
chaokunyang opened this issue Jul 6, 2023 · 4 comments
Labels: enhancement (New feature or request)

@chaokunyang

I'm using JAX to implement a case-when operator; the implementation uses jax.numpy.argmax to find the first true condition:

import jax
import jax.numpy as jnp

@jax.jit
def _jax_execute(conditions, scalar_array):
    xnp = jnp
    # Append an all-true "else" branch so every row matches at least one condition.
    else_condition = xnp.broadcast_to(1, len(conditions[0]))
    conditions.append(else_condition)
    conditions = xnp.column_stack(conditions)
    results = xnp.array(scalar_array)
    # nanargmax returns the first column holding the maximum, i.e. the first true condition.
    index = xnp.nanargmax(conditions, axis=1)
    tensor = results[index]
    return tensor

_jax_execute([jnp.arange(100000) > 50000, jnp.arange(100000) < 40000], [0, 1, 2])
_jax_execute([jnp.arange(5) > 3, jnp.arange(5) < 2], [0, 1, 2])

It works, but as the number of condition branches grows, the performance starts to degrade. Is there any possibility of implementing a new argfirst operator in JAX using the lax API, one that stops iterating once it finds the first True/1 in a condition row?

I took a look at the JAX code; it seems argmax is implemented using lax.reduce:

def _compute_argminmax(value_comparator, get_identity,
                       operand, *, index_dtype, axes):
  # value_comparator is either lax.lt (for argmin) or lax.gt
  # get_identity(operand.dtype) is inf for argmin or -inf for argmax
  axis, = axes
  indices = broadcasted_iota(index_dtype, np.shape(operand), axis)
  def reducer_fn(op_val_index, acc_val_index):
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    # Pick op_val if Lt (for argmin) or if NaN
    pick_op_val = bitwise_or(value_comparator(op_val, acc_val),
                             ne(op_val, op_val))
    # If x and y are not NaN and x = y, then pick the first
    pick_op_index = bitwise_or(pick_op_val,
                               bitwise_and(eq(op_val, acc_val),
                                           lt(op_index, acc_index)))
    return (select(pick_op_val, op_val, acc_val),
            select(pick_op_index, op_index, acc_index))
  res = reduce([operand, indices],
               [get_identity(operand.dtype), np.array(0, index_dtype)],
               reducer_fn,
               axes)
  return res[1]

Would it be possible to implement argfirst with early stopping using a similar mechanism?

  • Check for duplicate requests.
  • Describe your goal, and if possible provide a code snippet with a motivating example.
chaokunyang added the enhancement (New feature or request) label on Jul 6, 2023
@jakevdp (Collaborator) commented Jul 6, 2023

Thanks for the question! I don't think this would be feasible, because I don't believe XLA has any way to efficiently express a short-circuiting reduction operation. By construction, XLA:Reduce requires the reduction to be monoidal – this is because it does not simply scan linearly over the values, but rather does a sort of map-reduction in parallel under the hood. Given this, even if an early-return reduction were expressible in the IR, I don't think that ending the reduction early would improve your performance much in practice.
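
For illustration: on 0/1 input, jnp.argmax already returns the index of the first 1, but the lowered computation is still a single reduce over every element, so there is no step at which it could stop early:

import jax
import jax.numpy as jnp

mask = jnp.array([0, 0, 1, 0, 1], dtype=jnp.int32)
print(jnp.argmax(mask))  # 2 -- the first index holding the maximum, i.e. the first 1

# Inspect the lowered HLO: the argmax is expressed as one reduce over the whole array.
print(jax.jit(jnp.argmax).lower(mask).as_text())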

jakevdp self-assigned this on Jul 6, 2023
@chaokunyang (Author)

Hi @jakevdp, thanks for your reply. Would it be possible to implement this operator using an XLA custom call, so that the whole computation graph can still be traced by JAX and optimized by XLA without breaking it into multiple graphs when used with other operators?

@shoyer (Member) commented Jul 10, 2023

I would be surprised if a hypothetical short-cutting argfirst sped up actual JAX programs. Even in NumPy (see numpy/numpy#2269) the benefits are somewhat questionable, because you have to create the array first, and that's usually going to be comparably (or more) expensive than the O(N) pass to compute argmax. On accelerators, the situation is even worse, because accelerators are really slow at branching computation. So it would certainly be helpful to see something closer to a complete example that you want to speed up here.

That said, if you know your array only has 0/1 values, I wouldn't be surprised if you could use clever tricks to implement argfirst via different operations that turn out to be faster than Reduce on small arrays, at least on accelerators. E.g., you could do a cumsum implemented via matrix multiplication with an upper-triangular matrix, and then use jnp.searchsorted with method='compare_all'.
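
For concreteness, a rough sketch of that idea (untested, and assuming the conditions are already stacked into a 2-D 0/1 matrix):

import jax
import jax.numpy as jnp

@jax.jit
def argfirst_rows(matrix):
    # matrix: (rows, num_conds) of 0/1 values; returns the first column holding a 1
    # in each row, or num_conds if the row is all zeros.
    num_conds = matrix.shape[1]
    # Row-wise cumulative sum via matmul with an upper-triangular matrix.
    tri = jnp.triu(jnp.ones((num_conds, num_conds), matrix.dtype))
    csum = matrix @ tri
    # Each row of csum is non-decreasing, so the first 1 is where it first reaches 1.
    return jax.vmap(
        lambda row: jnp.searchsorted(row, 1, side='left', method='compare_all')
    )(csum)

mask = jnp.array([[0, 1, 1], [0, 0, 0], [1, 0, 1]], dtype=jnp.int32)
print(argfirst_rows(mask))  # [1 3 0]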

jakevdp removed their assignment on Jul 11, 2023
@chaokunyang (Author)

We are implementing a rule-computing engine using JAX, XLA, and GPU/CPU. The computation contains thousands of casewhen/ifelse operators, and the biggest casewhen has 800 branches. For example, bucketing a continuous numeric array into a discrete array based on very complex conditions composed of other array computations.

With our abstraction, we end up lowering all casewhen and other numeric computations and fusing them into a single JAX/XLA graph. The speedup is huge, but casewhen is still a bottleneck.

All casewhen conditions are bool tensors, which can be treated as 0/1 values. Another case is SQL coalesce, which finds the first non-null value; there we use argmax on a null-mask matrix.
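
A simplified sketch of that coalesce pattern (illustrative only, assuming NaN marks null and every row has at least one non-null value):

import jax
import jax.numpy as jnp

@jax.jit
def coalesce(columns):
    # columns: list of equal-length float arrays where NaN marks "null".
    stacked = jnp.column_stack(columns)            # (rows, num_columns)
    not_null = ~jnp.isnan(stacked)                 # null mask (True where non-null)
    first = jnp.argmax(not_null, axis=1)           # first non-null column per row
    return jnp.take_along_axis(stacked, first[:, None], axis=1)[:, 0]

print(coalesce([jnp.array([jnp.nan, 1.0]), jnp.array([2.0, 3.0])]))  # [2. 1.]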

We did try using matrix multiplication, for example:

import functools
import jax
import jax.numpy as jnp
from typing import Union, List, Tuple

@functools.partial(jax.jit, static_argnums=1)
def _bool_tensors_nanargmax_by_matmul(
        inputs: Union[List, Tuple, jax.Array], row_all_values_false_possible=True):
    input_is_matrix = isinstance(inputs, jax.Array)
    if input_is_matrix:
        num_conds = inputs.shape[1]
        matrix = inputs
    else:
        num_conds = len(inputs)
        # column_stack inside the jitted function makes the fused execution up to 6x more efficient.
        matrix = jnp.column_stack(inputs)
    # Weight column j by 2**(num_conds - 1 - j), so the leftmost 1 dominates the row sum.
    radix_tensor = jnp.power(2, jnp.arange(num_conds - 1, -1, -1))
    matrix = jnp.nan_to_num(matrix, nan=0)
    prod = matrix @ radix_tensor
    # Since the results are close to xxx.999, no 0.5 rounding case occurs.
    index_float = jnp.round(jnp.log2(prod))
    if row_all_values_false_possible:
        # If all values in a row are 0, log2 of the row sum is -inf; map that to the last column.
        index_float = jnp.where(jnp.isinf(index_float), num_conds - 1, index_float)
    index = jnp.abs(index_float - (num_conds - 1))
    return index

inputs = [(jnp.arange(100000) > 50000).astype(jnp.int8) for i in range(4)]
%timeit _bool_tensors_nanargmax_by_matmul(inputs)

@functools.partial(jax.jit, static_argnums=1)
def bool_tensors_nanargmax(inputs: Union[List, Tuple], row_all_values_false_possible=True):
    num_conds = len(inputs)
    matrix = jnp.column_stack(inputs)
    return jnp.nanargmax(matrix, axis=1)
%timeit bool_tensors_nanargmax(inputs)

It's 4 times faster, but the float precision is not reliable, and the result can be wrong sometimes.
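
For what it's worth, part of the unreliability may come from the jnp.round step: when several lower-order conditions are also true, log2(prod) lands just below the next integer (e.g. log2(0b1111) ≈ 3.91 rounds to 4), so jnp.floor may be closer to the intent. Beyond that, the mantissa width is a hard limit once there are many branches:

import jax.numpy as jnp

# float32 has a 24-bit significand: 2**24 and 2**24 + 1 collapse to the same value,
# so two different condition patterns can produce the same radix code.
x = jnp.array(2**24, dtype=jnp.float32)
print(x == x + 1)   # True
# float64 only pushes the limit to 2**53, still far below the 2**799 weight needed
# for an 800-branch casewhen.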
