pmap support #645

Open · 6 tasks
josevalim opened this issue Feb 17, 2022 · 2 comments

josevalim (Collaborator) commented Feb 17, 2022

The goal is to add Nx.Defn.pmap. This is a discussion of its API.

Input

When it comes to the input, the first option is to shard it automatically. For example:

Nx.Defn.pmap([arg1, arg2, arg3], fun)

can shard the first argument based on the number of devices. We can also make the sharding axes customizable. Something like:

Nx.Defn.pmap([arg1, arg2, arg3], fun, shards: [{0, 2}, {1, 0}])

will shard the first argument at axis 2 and the second argument at axis 0.

The second option is to accept a list of lists of already-sharded tensors. To convert data to this format, one can use Nx.to_batched_list/2, but perhaps we can also add Nx.shard.

Nx.Defn.pmap([[arg_a1, arg_a2, arg_a3], [arg_b1, arg_b2, arg_b3], ...], fun)
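
For illustration, here is a rough sketch of building that list-of-lists format by hand with today's Nx.to_batched_list/2. The pmap call itself is the proposed API, two arguments and four devices are assumed, and the layout (outer list = one entry per device, inner list = that device's arguments) is an assumption of this sketch:

# Shard two arguments into four per-device batches along axis 0.
shards1 = Nx.to_batched_list(arg1, div(Nx.axis_size(arg1, 0), 4))
shards2 = Nx.to_batched_list(arg2, div(Nx.axis_size(arg2, 0), 4))

# Pair them up so each inner list holds one shard of every argument.
inputs = Enum.zip_with(shards1, shards2, fn s1, s2 -> [s1, s2] end)
Nx.Defn.pmap(inputs, fun)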

Output

When it comes to the output, we have two options. The most logical option, thinking about Elixir, is for it to return a list of results of the same size as the list of inputs. In the GPU/TPU case, the sharded tensors remain allocated on their respective devices. This is great, especially with the second input API above, because you can easily call Nx.Defn.pmap again, with the outputs as a new list of inputs, to continue performing computations:

outputs = Nx.Defn.pmap(inputs, fun1)
Nx.Defn.pmap(outputs, fun2)

If your goal is to put the tensors back together into one, then you need to call Nx.stack/1. However, given that each tensor lives on a separate device, perhaps you will need something like this:

Nx.Defn.pmap(inputs, fun1)
|> Enum.map(&Nx.backend_transfer(&1, {EXLA.Backend, device: 0}))
|> Nx.stack()

Maybe we should make it so Nx.stack/1 automatically performs the transfer across devices (this is something we can discuss separately).

The other approach is to handle it the same way as JAX: create a separate "tensor backend" that knows the tensor is, in practice, split across n other backends. The benefit is that it can still present the data as a single tensor and perhaps encapsulate the backend-transfer code above, at the cost of one additional abstraction.

Personal thoughts section

My personal thought is that the most Elixir-like approach is to have a list of arguments as input and a list of outputs. We can add functions such as Nx.shard/2 to make it easier to shard existing values. We can also add Nx.Defn.num_shards(opts) to return the number of shards that the current compiler supports.
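
As a usage sketch of how these helpers could fit together (every name and signature below is a proposal from this issue, not an existing API):

# Hypothetical: ask the current compiler how many shards/devices it supports.
n = Nx.Defn.num_shards(compiler: EXLA)

# Hypothetical: split the batch into n shards (say, along axis 0).
shards = Nx.shard(batch, n)

# Run fun once per shard; the exact input format follows whichever
# option above we settle on (single-argument lists assumed here).
outputs = Nx.Defn.pmap(Enum.map(shards, &[&1]), fun)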

If we feel like we want to support first-class sharding, then we can add Nx.smap (short for "shard map"), which automatically does the sharding for you but is built on top of pmap.

TODO

  • Implement device transfers in EXLA
  • Discuss if device transfers should happen automatically in EXLA
  • Add Nx.Defn.pmap
  • Add Nx.Defn.num_shards(opts) (any better names than shards?)
  • Potentially add Nx.shard
  • Potentially add Nx.Defn.smap
cigrainger (Member) commented

I agree that a list of arguments as input and a list of outputs makes sense, especially for the reason you suggested re: passing on the outputs. I also like the idea of a shard map function.

I just want to clarify a few things for myself re: terminology and capabilities as well, so I'm clear on where this would fit in. I'm not super familiar with XLA, and I'm thinking in terms of 'data parallel' approaches like PyTorch's and optimizer-state sharding like Fairscale/ZeRO.

Am I right in thinking that really both of these would be enabled by Nx.Defn.pmap? The premise here is to invoke the right XLA ops such that we split a tensor into N shards, transfer them onto separate devices, apply a function to each shard on each device, then (optionally) gather the results back?

So, for example, 'data parallel' a la PyTorch could be achieved by putting a model replica on each device, then using Nx.Defn.pmap on a batch to shard it across devices, running the forward and backward passes on each shard, and finally collecting and aggregating the gradients, with the aggregate then passed to the optimizer.
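
Roughly, something like this is what I picture (purely illustrative: Nx.Defn.pmap is the proposed API, forward_backward/2 is a stand-in for the real training step, the gradient is treated as a single tensor, and pmap is assumed to apply the function to each shard's argument list):

# Split the batch into one shard per device (four devices assumed).
shards = Nx.to_batched_list(batch, div(Nx.axis_size(batch, 0), 4))

# Each device runs forward + backward on its shard with its model replica.
grads = Nx.Defn.pmap(Enum.map(shards, &[params, &1]), &forward_backward/2)

# Collect the per-device gradients on one device and aggregate them
# before handing the result to the optimizer (option name as in the snippet above).
total_grad =
  grads
  |> Enum.map(&Nx.backend_transfer(&1, {EXLA.Backend, device: 0}))
  |> Enum.reduce(&Nx.add/2)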

In that case, being able to easily chain Nx.Defn.pmap functions together would be very convenient.

josevalim (Collaborator, Author) commented

So, for example, 'data parallel' a la PyTorch could be achieved by putting a model replica on each device, then using Nx.Defn.pmap on a batch to shard it across devices, running the forward and backward passes on each shard, and finally collecting and aggregating the gradients, with the aggregate then passed to the optimizer.

Correct, but we also want to add psum and friends so we can "broadcast" tensor updates to all GPUs without having to circle in and out of Nx.Defn.pmap. Going in and out of pmap will only be necessary when the built-in data-parallel primitives in Nx.Defn are not enough. :)
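
For illustration, the kind of thing I mean, mirroring jax.lax.psum (psum does not exist in Nx yet; this whole snippet is a hypothetical sketch, with loss/2 standing in for the model's loss function):

import Nx.Defn

defn distributed_step(params, batch) do
  g = grad(params, &loss(&1, batch))  # per-device gradient of the loss
  psum(g)                             # hypothetical collective: sums g across all devices in the pmap
end

Nx.Defn.pmap(sharded_inputs, &distributed_step/2)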
