
Implements Ragged Dot API #20940

Open
wants to merge 6 commits into main

Conversation

piotrfilipiuk

Background

Ragged Dot is a specialized matrix multiplication operation that is commonly used in the context of Mixture of Experts (MoE) models. MoE models are a type of neural network architecture that consists of a collection of independent expert networks, each of which is responsible for processing a specific subset of the input data. In order to determine which expert should process a given input, a routing mechanism is employed. Ragged Dot plays a crucial role in this routing process.

At the linear algebra level, Ragged Dot can be defined as follows:

Given matrices A (m x k) and B (g x k x n), and a group-sizes vector G of length g, where m is the number of input samples, k is the dimensionality of the input features, and g is the number of experts, the Ragged Dot operation produces a matrix C (m x n). Each row of C is the product of the corresponding row of A with the (k x n) slice of B belonging to the expert assigned to that input sample.

More formally, the (i, j)-th element of C is computed as follows:

C[i, j] = \sum_{l=1}^{k} A[i, l] * B[g(i), l, j]

where l ranges over the k input-feature columns and g(i) denotes the index of the expert (group) assigned to the i-th input sample, with consecutive rows assigned to each group according to G.

The key characteristic of Ragged Dot is that the rows of A are partitioned into g contiguous, disjoint groups, one per expert, and each group is contracted against the corresponding (k x n) slice of B along its 0th dimension. This grouping structure allows the routing mechanism to efficiently assign input samples to the appropriate experts.
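
For intuition, the semantics above can be sketched as a plain NumPy loop over groups; ragged_dot_reference below is a hypothetical helper name, not part of the proposed API:

import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
  # lhs: (m, k), rhs: (g, k, n), group_sizes: (g,) summing to at most m.
  m, _ = lhs.shape
  _, _, n = rhs.shape
  out = np.zeros((m, n), dtype=lhs.dtype)
  start = 0
  for g, size in enumerate(group_sizes):
    # Rows [start, start + size) are assigned to expert g and use rhs[g].
    out[start:start + size] = lhs[start:start + size] @ rhs[g]
    start += size
  return out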

Requirements

Arguments:

  • lhs: (m, k) shaped array with integer or floating point element type. (REQUIRED)
  • rhs: (g1, k, n) shaped array with integer or floating point element type, where g1 denotes the number of local groups. (REQUIRED)
  • group_sizes: (g2,) shaped array with integer element type, where g2 denotes the number of groups. The i-th element indicates the size of the i-th group. (REQUIRED)
  • precision: consistent with precision argument for jax.lax.dot. (OPTIONAL).
  • preferred_element_type: the element type (jnp.dtype) for the output array. Consistent with preferred_element_type argument for jax.lax.dot. (OPTIONAL).
  • group_offset: (1,) shaped array. Indicates the group in group_sizes to start computing from. (OPTIONAL) If not specified, defaults to [0].
  • existing_output: (m, n) shaped array with elements of type preferred_element_type, to which the output should be written. (OPTIONAL) Defaults to None, in which case the output is written to a newly allocated array.

Results:

  • (m, n) shaped array with preferred_element_type element type.

Preconditions:

  1. group_sizes = [s_1, ..., s_g2], where s_1 + ... + s_g2 <= m

  2. g1 <= g2 (number of local groups <= number of groups).

If g1 == g2, group_offset must be either unspecified or explicitly set to [0].

If g1 < g2, group_offset must contain a single value in [0, g2) such that group_offset[0] + g1 <= g2. In this case, we perform g1 dots, and the slices of the output corresponding to groups outside the selected range remain unchanged.
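
A minimal usage sketch, assuming the operation is exposed as jax.lax.ragged_dot with the arguments listed above (the module path and defaults are this proposal, not settled API):

import jax.numpy as jnp
from jax import lax

m, k, n, g = 6, 4, 3, 2
lhs = jnp.ones((m, k), dtype=jnp.float32)
rhs = jnp.ones((g, k, n), dtype=jnp.float32)
group_sizes = jnp.array([4, 2], dtype=jnp.int32)  # sums to m

# Rows 0..3 contract with rhs[0], rows 4..5 with rhs[1].
out = lax.ragged_dot(lhs, rhs, group_sizes, preferred_element_type=jnp.float32)
print(out.shape)  # (6, 3)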

PiperOrigin-RevId: 623910451

google-cla bot commented Apr 25, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@jakevdp
Collaborator

jakevdp commented Apr 25, 2024

Assigning @mattjj, who has thought a bit about how to support ragged operations more broadly in JAX.

@froystig self-requested a review April 26, 2024 13:02
precision=precision,
preferred_element_type=preferred_element_type,
)
return np.sum(result, axis=0, dtype=result.dtype)
Member

If we're using ordinary numpy and this is just a reference implementation, can you put it in lax_reference.py instead?

If this is meant to be a lax (i.e. HLO-ish) implementation, can you avoid calling ordinary numpy and just call jnp or lax functions instead?

Member

Given we already have a reference implementation in lax_reference.py, I'm not sure what this one is for (unless it's meant to be turned into one that doesn't call ordinary numpy).

Author

This is meant to be a lax implementation that is lowerable to HLO. I modified the implementation so that no numpy functions are used.

I am quite surprised that when using numpy (np.cumsum and np.sum) in the lax.ragged_dot, it is still lowered to HLO as expected, yielding:

HloModule jit_wrapped_fun

cumsum.12 {
  ROOT Arg_0.13 = s32[1]{0} parameter(0)
}

_cumulative_reduction.14 {
  Arg_0.15 = s32[1]{0} parameter(0)
  ROOT call.16 = s32[1]{0} call(Arg_0.15), to_apply=cumsum.12
}

region_0.25 {
  Arg_0.26 = f32[] parameter(0)
  Arg_1.27 = f32[] parameter(1)
  ROOT add.28 = f32[] add(Arg_0.26, Arg_1.27)
}

ENTRY main.31 {
  constant.6 = s32[] constant(0)
  broadcast.7 = s32[1,5,4]{2,1,0} broadcast(constant.6), dimensions={}
  iota.10 = s32[5]{0} iota(), iota_dimension=0
  broadcast.11 = s32[1,5,4]{2,1,0} broadcast(iota.10), dimensions={1}
  compare.19 = pred[1,5,4]{2,1,0} compare(broadcast.7, broadcast.11), direction=LE
  Arg_2.3 = s32[1]{0} parameter(2)
  call.17 = s32[1]{0} call(Arg_2.3), to_apply=_cumulative_reduction.14
  broadcast.18 = s32[1,5,4]{2,1,0} broadcast(call.17), dimensions={0}
  compare.20 = pred[1,5,4]{2,1,0} compare(broadcast.11, broadcast.18), direction=LT
  and.21 = pred[1,5,4]{2,1,0} and(compare.19, compare.20)
  Arg_0.1 = bf16[5,4]{1,0} parameter(0)
  reshape.9 = bf16[1,5,4]{2,1,0} reshape(Arg_0.1)
  constant.4 = bf16[] constant(0)
  broadcast.5 = bf16[1,5,4]{2,1,0} broadcast(constant.4), dimensions={}
  select.22 = bf16[1,5,4]{2,1,0} select(and.21, reshape.9, broadcast.5)
  Arg_1.2 = bf16[1,4,3]{2,1,0} parameter(1)
  dot.23 = bf16[1,5,3]{2,1,0} dot(select.22, Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  convert.24 = f32[1,5,3]{2,1,0} convert(dot.23)
  constant.8 = f32[] constant(0)
  reduce.29 = f32[5,3]{1,0} reduce(convert.24, constant.8), dimensions={0}, to_apply=region_0.25
  ROOT convert.30 = bf16[5,3]{1,0} convert(reduce.29)
} // main.31

Member

I am quite surprised that when using numpy (np.cumsum and np.sum) in the lax.ragged_dot, it is still lowered to HLO as expected, yielding:

Yeah interesting! Me too. I haven't read this whole PR yet, but I guess it must be that, at least in the tests you have, the numpy stuff was all trace-time static data, i.e. stuff that could be evaluated while running the Python? In that case, writing original numpy calls is like saying "I want this to be evaluated at Python tracing time, not staged out into the HLO".

But I would expect that in general the sizes involved here would be dynamic, i.e. not available until runtime.
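
A minimal sketch of that distinction, using a made-up function f: plain NumPy calls on concrete data are folded in at trace time, while jnp calls on the traced argument are staged out into the jaxpr/HLO.

import numpy as np
import jax
import jax.numpy as jnp

def f(x):
  # Concrete NumPy data: evaluated once, while tracing the Python.
  static_ends = np.cumsum(np.array([1, 2, 3]))
  # Traced input: staged out into the jaxpr / HLO.
  return jnp.sum(x) + static_ends[-1]

# The printed jaxpr contains only the staged ops; the cumsum result
# appears as a baked-in constant.
print(jax.make_jaxpr(f)(jnp.ones(3)))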

Author

Got it, all uses of np are removed now. Thank you.

@piotrfilipiuk
Author

  • Removed uses of numpy functions from lax.ragged_dot
  • Added _CompileAndCheck tests in addition to self._CheckAgainstNumpy.

lhs_shape = (m, k)
rhs_shape = (num_groups, k, n)
def group_sizes(m, num_groups):
  ends_no_final = jnp.sort(np.random.choice(m, size=num_groups - 1))
Member

Should the random choice here use rng, or otherwise should the numpy RNG being used be seeded first?

Author

Thank you, switched to using self.rng().choice() instead.

tests/lax_test.py (outdated review thread, resolved)
@piotrfilipiuk
Author

PTAL.

@piotrfilipiuk
Author

Implemented ragged_dot as a JAX primitive.
