Skip to content

Commit

Permalink
Merge some loops in device_put since it's trivial to do so
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626546322
  • Loading branch information
yashk2810 authored and jax authors committed Apr 20, 2024
1 parent 0943eb3 commit 1837b43
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2439,7 +2439,7 @@ def make_jaxpr_f(*args, **kwargs):

def _infer_src_sharding(src, x) -> Sharding | None:
if src is not None:
return src
return src # type: ignore
if isinstance(x, array.ArrayImpl):
return x.sharding
elif isinstance(x, core.Tracer):
Expand Down Expand Up @@ -2493,21 +2493,20 @@ def device_put(
isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
(src is None or
isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
for leaf in tree_leaves(x):
_check_sharding(shaped_abstractify(leaf), s=device)
return tree_map(
lambda y: dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y)), x)
def _map(y):
_check_sharding(shaped_abstractify(y), s=device)
return dispatch.device_put_p.bind(
y, device=device, src=_infer_src_sharding(src, y))
return tree_map(_map, x)

x_flat, treedef = tree_flatten(x)
device_flat = flatten_axes("device_put device", treedef, device)
src_flat = flatten_axes("device_put source", treedef, src)
for x_leaf, device_leaf in zip(x_flat, device_flat):
_check_sharding(shaped_abstractify(x_leaf), device_leaf)
out_flat = [
dispatch.device_put_p.bind(xf, device=d, src=_infer_src_sharding(s, xf))
for xf, d, s in zip(x_flat, device_flat, src_flat)
]
out_flat = []
for xf, d, s in zip(x_flat, device_flat, src_flat):
_check_sharding(shaped_abstractify(xf), d)
out_flat.append(dispatch.device_put_p.bind(
xf, device=d, src=_infer_src_sharding(s, xf)))
return tree_unflatten(treedef, out_flat)


Expand Down

0 comments on commit 1837b43

Please sign in to comment.