Skip to content

Commit

Permalink
Remove explicit pallas trace_start/trace_stop primitives. These are n…
Browse files Browse the repository at this point in the history
…ow automatically inserted with the usage of jax.named_scope.

PiperOrigin-RevId: 628553413
  • Loading branch information
justinjfu authored and jax authors committed Apr 27, 2024
1 parent 60af226 commit 0b5f3f8
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 77 deletions.
1 change: 0 additions & 1 deletion jax/_src/pallas/mosaic/__init__.py
Expand Up @@ -14,7 +14,6 @@

"""Module for Mosaic lowering of Pallas call."""

from jax._src.api import named_scope as trace
from jax._src.pallas.mosaic import core
from jax._src.pallas.mosaic.core import dma_semaphore
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
Expand Down
19 changes: 2 additions & 17 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -648,9 +648,9 @@ def write_env(var: jax_core.Var, val):
current_name_stack, name_stack)
current_name_stack = name_stack
for _ in popped:
_trace_stop_lowering_rule(rule_context)
tpu.TraceStopOp()
for name in pushed:
_trace_start_lowering_rule(rule_context, message=name, level=10)
tpu.TraceStartOp(message=name, level=10)

try:
ans = lowering_rules[eqn.primitive](
Expand Down Expand Up @@ -2152,21 +2152,6 @@ def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):

lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule

def _trace_start_lowering_rule(
ctx: LoweringRuleContext, *, message: str, level: int
):
return tpu.TraceStartOp(message=message, level=level).results


lowering_rules[tpu_primitives.trace_start_p] = _trace_start_lowering_rule


def _trace_stop_lowering_rule(ctx: LoweringRuleContext):
return tpu.TraceStopOp().results


lowering_rules[tpu_primitives.trace_stop_p] = _trace_stop_lowering_rule


def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
if isinstance(aval, pl_core.AbstractMemoryRef):
Expand Down
18 changes: 9 additions & 9 deletions jax/_src/pallas/mosaic/pipeline.py
Expand Up @@ -309,7 +309,7 @@ def do_and_advance_buffers():
if accum_allocation is not None:

def accum():
with tpu_primitives.trace("ep_accum_copy"):
with jax.named_scope("ep_accum_copy"):
accum_dtype = jnp.float32
if vmem_ref.dtype == jnp.int32:
accum_dtype = jnp.int32
Expand All @@ -330,7 +330,7 @@ def dont_advance_buffers():
if accum_allocation is not None:

def accum():
with tpu_primitives.trace("ep_accum_store"):
with jax.named_scope("ep_accum_store"):

def zero_accum():
accum_vmem_ref = accum_allocation.vmem_ref
Expand Down Expand Up @@ -852,7 +852,7 @@ def fori_loop_body(
next_indices = _get_next_indices(grid, indices)
copy_indices = (prev_indices, indices, next_indices)

with tpu_primitives.trace("ep_wait_input"):
with jax.named_scope("ep_wait_input"):
input_copy_args = [
pipeline_specs.input,
pipeline_refs.input,
Expand Down Expand Up @@ -918,7 +918,7 @@ def start_next_iteration_in_block_copies():
)
return next_in_buffers, next_in_out_buffers

@tpu_primitives.trace("ep_run_epilogue")
@jax.named_scope("ep_run_epilogue")
def run_epilogue():
if epilogue is None:
return pipeline_buffers.input, pipeline_buffers.in_out
Expand Down Expand Up @@ -948,7 +948,7 @@ def run_epilogue():
start_next_iteration_in_block_copies,
)

with tpu_primitives.trace("ep_kernel"):
with jax.named_scope("ep_kernel"):

def grab_body_ref(
spec_with_nones,
Expand Down Expand Up @@ -1034,7 +1034,7 @@ def accum():
in_out_existing_allocations,
)

with tpu_primitives.trace("ep_wait_output"):
with jax.named_scope("ep_wait_output"):

def run_out_prologue():
if out_prologue is not None:
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def run_out_prologue():
force_skip=skip_out_prologue_wait,
)

@tpu_primitives.trace("ep_wait_prev_iteration_out_block_copies")
@jax.named_scope("ep_wait_prev_iteration_out_block_copies")
def wait_prev_iteration_out_block_copies():
tree_util.tree_map(
partial(
Expand Down Expand Up @@ -1169,8 +1169,8 @@ def set_buffer_ref(buffer_ref, buffer):
pipeline_buffers,
)

with tpu_primitives.trace("ep_end_pipeline"):
with tpu_primitives.trace("ep_wait_output"):
with jax.named_scope("ep_end_pipeline"):
with jax.named_scope("ep_wait_output"):
if out_epilogue is not None:
skip_out_epilogue_wait = out_epilogue(
PipelineCallbackArgs(
Expand Down
38 changes: 0 additions & 38 deletions jax/_src/pallas/mosaic/primitives.py
Expand Up @@ -15,7 +15,6 @@
"""Module for Pallas:TPU-specific JAX primitives and functions."""
from __future__ import annotations

import contextlib
import dataclasses
import enum
from typing import Any, Callable
Expand Down Expand Up @@ -101,10 +100,6 @@ def _bitcast(x):

mlir.register_lowering(bitcast_p, _bitcast_lowering_rule)

trace_start_p = jax_core.Primitive('trace_start')
trace_start_p.multiple_results = True


roll_p = jax_core.Primitive("roll")


Expand Down Expand Up @@ -157,39 +152,6 @@ def _roll(x):
mlir.register_lowering(roll_p, _roll_lowering_rule)


@trace_start_p.def_impl
def _trace_start_impl(*, message: str, level: int):
del message, level
return []

@trace_start_p.def_abstract_eval
def _trace_start_abstract_eval(*, message: str, level: int):
del message, level
return []

mlir.register_lowering(trace_start_p, lambda ctx, **_: [])


trace_stop_p = jax_core.Primitive('trace_stop')
trace_stop_p.multiple_results = True

@trace_stop_p.def_impl
def _trace_stop_impl():
return []

@trace_stop_p.def_abstract_eval
def _trace_stop_abstract_eval():
return []

mlir.register_lowering(trace_stop_p, lambda ctx: [])

@contextlib.contextmanager
def trace(message: str, level: int = 10):
trace_start_p.bind(message=message, level=level)
yield
trace_stop_p.bind()


run_scoped_p = jax_core.Primitive('run_scoped')
run_scoped_p.multiple_results = True

Expand Down
12 changes: 6 additions & 6 deletions jax/experimental/pallas/ops/tpu/all_gather.py
Expand Up @@ -63,15 +63,15 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str,
my_id = lax.axis_index(axis_name)
# TODO(sharadmv): could speed this up having the first remote DMA go from
# x_ref->o_ref immediately instead of a blocking HBM copy.
with pltpu.trace("initial_copy"):
with jax.named_scope("initial_copy"):
pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait()

with pltpu.trace("neighbour_lookup"):
with jax.named_scope("neighbour_lookup"):
axis_size = lax.psum(1, axis_name)
left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left")
right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right")

with pltpu.trace("main_barrier"):
with jax.named_scope("main_barrier"):
sem = pltpu.get_barrier_semaphore()
pltpu.semaphore_signal(sem, 1, device_id=left_neighbor)
pltpu.semaphore_signal(sem, 1, device_id=right_neighbor)
Expand All @@ -86,7 +86,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str,
right_slice = pl.ds(shard_size // 2, shard_size // 2)
slot = jnp.where(right_slot < 0, axis_size + right_slot, right_slot)
if right_dma:
with pltpu.trace("wait_right_dma"):
with jax.named_scope("wait_right_dma"):
right_dma.wait()
right_dma = pltpu.async_remote_copy(
o_ref.at[slot, right_slice],
Expand All @@ -100,7 +100,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str,
left_slice = pl.ds(0, shard_size // 2)
slot = lax.rem(left_slot, axis_size)
if left_dma:
with pltpu.trace("wait_left_dma"):
with jax.named_scope("wait_left_dma"):
left_dma.wait()
left_dma = pltpu.async_remote_copy(
o_ref.at[slot, left_slice],
Expand All @@ -109,7 +109,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str,
recv_sem[0],
device_id=left_neighbor,
)
with pltpu.trace("wait_all_dma"):
with jax.named_scope("wait_all_dma"):
assert right_dma is not None
assert left_dma is not None
right_dma.wait()
Expand Down
1 change: 0 additions & 1 deletion jax/experimental/pallas/tpu.py
Expand Up @@ -43,5 +43,4 @@
from jax._src.pallas.mosaic import semaphore_read
from jax._src.pallas.mosaic import semaphore_signal
from jax._src.pallas.mosaic import semaphore_wait
from jax._src.pallas.mosaic import trace
from jax._src.tpu_custom_call import CostEstimate
10 changes: 5 additions & 5 deletions tests/pallas/pallas_call_tpu_test.py
Expand Up @@ -2314,7 +2314,7 @@ def prologue(prologue_args: pltpu.PipelineCallbackArgs):
del prologue_args

@pl.when(is_start)
@pltpu.trace('sync_and_bwd_init')
@jax.named_scope('sync_and_bwd_init')
def _sync_and_bwd_init():
# barrier at start
barrier_sem = pltpu.get_barrier_semaphore()
Expand All @@ -2327,7 +2327,7 @@ def _sync_and_bwd_init():
initial_bwd_copy.wait()

@pl.when(jnp.logical_and(step != steps - 1, phase == 0))
@pltpu.trace('send_next_dma')
@jax.named_scope('send_next_dma')
def _send_next_dma():
bwd_copy.start()
@pl.when(jnp.logical_not(is_start))
Expand All @@ -2344,13 +2344,13 @@ def _send_next_fwd_dma():
def epilogue(epilogue_args: pltpu.PipelineCallbackArgs):

@pl.when(is_start)
@pltpu.trace('fwd_init')
@jax.named_scope('fwd_init')
def _fwd_init():
initial_fwd_copy.wait()
fwd_copy.start()

@pl.when(jnp.logical_and(step != steps - 1, phase == 1))
@pltpu.trace('wait_on_prev_dma')
@jax.named_scope('wait_on_prev_dma')
def _wait_on_prev_dma():
bwd_copy.wait()
fwd_copy.wait()
Expand Down Expand Up @@ -2388,7 +2388,7 @@ def prefetch_pipeline_inputs():
),
)

with pltpu.trace('dots'):
with jax.named_scope('dots'):

pipeline(
lhs_scratch_ref.at[phase, working_slot],
Expand Down

0 comments on commit 0b5f3f8

Please sign in to comment.