Implements Ragged Dot API #20940
Conversation
PiperOrigin-RevId: 623910451
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Assigning @mattjj, who has thought a bit about how to support ragged operations more broadly in JAX.
jax/_src/lax/lax.py
Outdated
```python
        precision=precision,
        preferred_element_type=preferred_element_type,
    )
    return np.sum(result, axis=0, dtype=result.dtype)
```
If we're using ordinary numpy and this is just a reference implementation, can you put it in lax_reference.py instead?
If this is meant to be a lax (i.e. HLO-ish) implementation, can you avoid calling ordinary numpy and just call jnp or lax functions instead?
Given we already have a reference implementation in lax_reference.py, I'm not sure what this one is for (unless it's meant to be turned into one that doesn't call ordinary numpy).
This is meant to be a lax implementation that is lowerable to HLO. I modified the implementation so that no numpy functions are used.
I am quite surprised that even when using numpy (np.cumsum and np.sum) in lax.ragged_dot, it is still lowered to HLO as expected, yielding:
```
HloModule jit_wrapped_fun
cumsum.12 {
  ROOT Arg_0.13 = s32[1]{0} parameter(0)
}
_cumulative_reduction.14 {
  Arg_0.15 = s32[1]{0} parameter(0)
  ROOT call.16 = s32[1]{0} call(Arg_0.15), to_apply=cumsum.12
}
region_0.25 {
  Arg_0.26 = f32[] parameter(0)
  Arg_1.27 = f32[] parameter(1)
  ROOT add.28 = f32[] add(Arg_0.26, Arg_1.27)
}
ENTRY main.31 {
  constant.6 = s32[] constant(0)
  broadcast.7 = s32[1,5,4]{2,1,0} broadcast(constant.6), dimensions={}
  iota.10 = s32[5]{0} iota(), iota_dimension=0
  broadcast.11 = s32[1,5,4]{2,1,0} broadcast(iota.10), dimensions={1}
  compare.19 = pred[1,5,4]{2,1,0} compare(broadcast.7, broadcast.11), direction=LE
  Arg_2.3 = s32[1]{0} parameter(2)
  call.17 = s32[1]{0} call(Arg_2.3), to_apply=_cumulative_reduction.14
  broadcast.18 = s32[1,5,4]{2,1,0} broadcast(call.17), dimensions={0}
  compare.20 = pred[1,5,4]{2,1,0} compare(broadcast.11, broadcast.18), direction=LT
  and.21 = pred[1,5,4]{2,1,0} and(compare.19, compare.20)
  Arg_0.1 = bf16[5,4]{1,0} parameter(0)
  reshape.9 = bf16[1,5,4]{2,1,0} reshape(Arg_0.1)
  constant.4 = bf16[] constant(0)
  broadcast.5 = bf16[1,5,4]{2,1,0} broadcast(constant.4), dimensions={}
  select.22 = bf16[1,5,4]{2,1,0} select(and.21, reshape.9, broadcast.5)
  Arg_1.2 = bf16[1,4,3]{2,1,0} parameter(1)
  dot.23 = bf16[1,5,3]{2,1,0} dot(select.22, Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  convert.24 = f32[1,5,3]{2,1,0} convert(dot.23)
  constant.8 = f32[] constant(0)
  reduce.29 = f32[5,3]{1,0} reduce(convert.24, constant.8), dimensions={0}, to_apply=region_0.25
  ROOT convert.30 = bf16[5,3]{1,0} convert(reduce.29)
} // main.31
```
> I am quite surprised that when using numpy (np.cumsum and np.sum) in the lax.ragged_dot, it is still lowered to HLO as expected, yielding:
Yeah interesting! Me too. I haven't read this whole PR yet, but I guess it must be that, at least in the tests you have, the numpy stuff was all trace-time static data, i.e. stuff that could be evaluated while running the Python? In that case, writing ordinary numpy calls is like saying "I want this to be evaluated at Python tracing time, not staged out into the HLO".
But I would expect that in general the sizes involved here would be dynamic, i.e. not available until runtime.
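The trace-time-versus-staged distinction can be seen in a small sketch (a hypothetical example, not code from this PR): numpy calls on concrete, non-traced data are evaluated while Python runs during tracing, while jnp calls on traced arguments are staged out into the compiled computation.

```python
import numpy as np
import jax
import jax.numpy as jnp

sizes = np.array([2, 3])  # plain numpy: static data, known at trace time

@jax.jit
def f(x):
  # np.cumsum runs during Python tracing, because `sizes` is a concrete
  # numpy array rather than a traced value; its result is baked into
  # the jaxpr as a constant.
  ends = np.cumsum(sizes)
  # jnp.sum on the traced argument `x` is staged into the compiled
  # computation.
  return jnp.sum(x) + ends[-1]

print(f(jnp.ones(4)))  # prints 9.0
```

If `sizes` were instead a traced function argument, calling np.cumsum on it would raise an error at trace time, which is why dynamic (runtime-only) group sizes need jnp or lax ops.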
Got it, all uses of np are removed now. Thank you.
remove uses of numpy in the ragged dot implementation PiperOrigin-RevId: 623910451
tests/lax_test.py
Outdated
```python
lhs_shape = (m, k)
rhs_shape = (num_groups, k, n)
def group_sizes(m, num_groups):
  ends_no_final = jnp.sort(np.random.choice(m, size=num_groups - 1))
```
Should the random choice here use rng, or otherwise should the numpy RNG being used be seeded first?
Thank you, switched to using self.rng().choice() instead.
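For reference, the cut-point trick in the snippet above can be completed into a standalone group-sizes helper (a sketch with an assumed signature, not the PR's actual test code): sorted cut points in [0, m) are turned into sizes by taking successive differences, so the sizes always sum to exactly m.

```python
import numpy as np

def random_group_sizes(m, num_groups, rng):
  # Draw num_groups - 1 sorted cut points in [0, m), then take
  # successive differences against the endpoints 0 and m so the
  # resulting sizes sum exactly to m. Some sizes may be zero when
  # cut points coincide, which ragged dot permits.
  ends_no_final = np.sort(rng.choice(m, size=num_groups - 1))
  ends = np.concatenate([ends_no_final, [m]])
  starts = np.concatenate([[0], ends_no_final])
  return ends - starts

rng = np.random.default_rng(0)  # seeded, as the review asks
sizes = random_group_sizes(10, 4, rng)
assert sizes.sum() == 10 and len(sizes) == 4
```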
…ntation and does JAX compilation
PTAL.
Implemented ragged_dot as a JAX primitive.
Background
Ragged Dot is a specialized matrix multiplication operation commonly used in Mixture of Experts (MoE) models. MoE models are a neural network architecture consisting of a collection of independent expert networks, each responsible for processing a specific subset of the input data. A routing mechanism determines which expert should process a given input, and Ragged Dot plays a crucial role in this routing process.
At the linear algebra level, Ragged Dot can be defined as follows:
Given matrices A (m x k) and B (g x k x n), and a vector of group sizes G (length g), where m is the number of input samples, k is the dimensionality of the input features, and g is the number of experts, the Ragged Dot operation produces a matrix C (m x n) in which each row of C is the product of the corresponding row of A with the expert matrix in B assigned to that input sample.
More formally, the (i, j)-th element of C is computed as

C[i, j] = sum_l A[i, l] * B[e(i), l, j],

where e(i) is the expert assigned to the i-th input sample and l ranges over the k feature columns of A (equivalently, over the rows of the expert's slice of B).
The key characteristic of Ragged Dot is that the rows of A are grouped into contiguous disjoint sets, each corresponding to one expert, i.e. one slice of B along its 0th dimension (of size g). This grouping structure allows the routing mechanism to efficiently assign input samples to the appropriate experts.
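As a plain-numpy illustration of these semantics (a sketch for exposition only, not the PR's implementation): consecutive blocks of rows of A, with block lengths given by the group sizes, are each multiplied by the corresponding expert matrix in B.

```python
import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
  # lhs: (m, k), rhs: (g, k, n), group_sizes: length g, summing to <= m.
  m = lhs.shape[0]
  out = np.zeros((m, rhs.shape[2]), dtype=lhs.dtype)
  start = 0
  for g, size in enumerate(group_sizes):
    # Rows [start, start + size) are routed to expert g.
    out[start:start + size] = lhs[start:start + size] @ rhs[g]
    start += size
  return out

# 5 samples, 2 experts: the first 2 rows hit the identity expert,
# the last 3 hit an expert that doubles its input.
lhs = np.arange(10, dtype=np.float32).reshape(5, 2)
rhs = np.stack([np.eye(2, dtype=np.float32),
                2.0 * np.eye(2, dtype=np.float32)])
out = ragged_dot_reference(lhs, rhs, [2, 3])
```

The loop over experts makes the "ragged" structure explicit; the PR's lax implementation instead expresses the same semantics with masking and batched dots so it lowers to HLO.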
Requirements
Arguments:
Results:
Preconditions:
group_sizes = [s_1, ..., s_g2], where s_1 + ... + s_g2 <= m
g1 <= g2 (number of local groups <= number of groups).
If g1 == g2, group_offsets must be either unspecified or explicitly set to [0].
If g1 < g2, group_offsets must contain a single value in [0, g2) where group_offsets[0] + g1 <= g2. In this case, we perform g1 dots, where irrelevant slices of the results remain unchanged.
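The preconditions above can be collected into a small validation sketch (a hypothetical helper for illustration, not part of the PR):

```python
def check_ragged_dot_preconditions(m, g1, g2, group_sizes, group_offsets=None):
  # group_sizes = [s_1, ..., s_g2] with s_1 + ... + s_g2 <= m.
  assert len(group_sizes) == g2 and sum(group_sizes) <= m
  # Number of local groups cannot exceed the number of groups.
  assert g1 <= g2
  if g1 == g2:
    # group_offsets must be unspecified or exactly [0].
    assert group_offsets is None or list(group_offsets) == [0]
  else:
    # A single offset in [0, g2) with group_offsets[0] + g1 <= g2.
    assert group_offsets is not None and len(group_offsets) == 1
    (offset,) = group_offsets
    assert 0 <= offset < g2 and offset + g1 <= g2

# All groups are local: offsets left unspecified.
check_ragged_dot_preconditions(m=10, g1=2, g2=2, group_sizes=[3, 4])
# Two local groups out of four, starting at group 1.
check_ragged_dot_preconditions(m=10, g1=2, g2=4,
                               group_sizes=[2, 3, 1, 4],
                               group_offsets=[1])
```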