[export] Add backwards compatibility test for Pallas call on GPUs.
Note that this adds only a minimal safety net to protect against
non-backwards-compatible changes. We really should have more tests
that cover more of the Triton MLIR.

Also enable serialization of such calls.

PiperOrigin-RevId: 630033989
gnecula authored and jax authors committed May 2, 2024
1 parent 2730cf3 commit b40a310
Showing 7 changed files with 164 additions and 7 deletions.
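
As an illustration of the "Also enable serialization of such calls" note above, here is a minimal sketch of exporting a jitted Pallas call. It assumes the jax.experimental.export API of this era (export.export and the Exported.mlir_module_serialized field) and a CUDA-capable JAX install; the add_one kernel mirrors the one in the new test added by this commit, and the export only succeeds because "__gpu$xla.gpu.triton" is now listed as a stable custom-call target in jax/experimental/export/_export.py.

import jax
import jax.numpy as jnp
from jax.experimental import export  # assumed export API at the time of this commit
from jax.experimental import pallas as pl

def func(x):
  def add_one(x_ref, o_ref):
    o_ref[0] = x_ref[0] + 1
  return pl.pallas_call(add_one,
                        out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
                        in_specs=[pl.BlockSpec(lambda i: i, (1,))],
                        out_specs=pl.BlockSpec(lambda i: i, (1,)),
                        grid=8)(x)

# Exporting no longer rejects the "__gpu$xla.gpu.triton" custom call emitted by pallas_call.
exported = export.export(jax.jit(func))(jax.ShapeDtypeStruct((8,), jnp.float32))
blob = exported.mlir_module_serialized  # serialized StableHLO, replayable by newer JAX versions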
5 changes: 4 additions & 1 deletion jax/BUILD
@@ -200,7 +200,10 @@ py_library(
py_library(
name = "internal_export_back_compat_test_data",
testonly = 1,
srcs = glob(["_src/internal_test_util/export_back_compat_test_data/*.py"]),
srcs = glob([
"_src/internal_test_util/export_back_compat_test_data/*.py",
"_src/internal_test_util/export_back_compat_test_data/pallas/*.py",
]),
visibility = [
":internal",
],
47 changes: 47 additions & 0 deletions jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py
@@ -0,0 +1,47 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from numpy import array, float32


# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_05_02 = dict(
testdata_version=1,
platform='cuda',
custom_call_targets=['__gpu$xla.gpu.triton'],
serialized_date=datetime.date(2024, 5, 2),
inputs=(array([0., 1., 2., 3., 4., 5., 6., 7.], dtype=float32),),
expected_outputs=(array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32),),
mlir_module_text=r"""
#loc1 = loc("x")
#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":43:13)
#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc2))
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<8xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = call @wrapped(%arg0) : (tensor<8xf32>) -> tensor<8xf32> loc(#loc3)
return %0 : tensor<8xf32> loc(#loc)
} loc(#loc)
func.func private @wrapped(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc2))) -> (tensor<8xf32> {mhlo.layout_mode = "default"}) {
%0 = stablehlo.custom_call @__gpu$xla.gpu.triton(%arg0) {mhlo.backend_config = {debug = false, grid_x = 8 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = "ML\EFR\0DMLIRgooglex-trunk\00\01-\07\01\05\09\17\01\03\0F\03\0D\13\17\1B\1F#'\05\09+/37\03O1\0B\01-\07\0F\0F\0F\0F\13\13\13\0B\0F\0F\13\0B\0F\0B\0B\0B\0B\1F\0B\0B\13\05\05YY\01\09\0F\07\17\0B\03\035\02\16\02\1F\11\01\05\1D)+\1D#\0F\11\01\01#\01\01\01\03\03\19\1B\17\11U'\05\1D\11\07\00\1D'\0F\01\05\0D\0D\05\1F\11\01\81\0D\05\05!\05#\05%\13\03\10\00\00\E0\0F\05'\05)\17\11U\11#arith.overflow<none>\00#arith.fastmath<none>\00\01\02\02\0B\05\05\09\09\01\01\09!tt.ptr<f32>\00\04\D2\02\05\01P\01\01\07\04\AE\02\03\01\05\07P\01\03\07\04\82\02\03+W\05\11\11\00\09B\01\05\03\01\0FB\01\07\03\01\11F\01\09\03\01\05\05\07\0FB\01\07\03\01\11F\01\09\03\01\05\05\0B\0FB\07\05\03\01\13F\07\09\03\01\05\0F\09\0FB\07\07\03\01\11F\07\09\03\01\05\11\13\03\06\07\03\09\05\01\15\05F\07\0B\03\03\03\17\0FB\15\0D\03\03\15F\15\0F\03\03\05\19\1B\0FB\05\05\03\01\13F\05\09\03\01\05\1F\0D\0FB\05\07\03\01\11F\05\09\03\01\05!#\03\06\05\03\09\05\03%\05F\05\0B\03\03\03'\0BD\05\11\05'\1D\0D\00\01\06\03\01\05\01\00\F2\05+\A5\0B\A3\0F\11!\85\0B\0B\0B\13\0F\0D\1F\0B\0B\0F\0F\0D\07\11builtin\00tt\00arith\00module\00addptr\00load\00func\00get_program_id\00store\00return\00constant\00muli\00addi\00addf\00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\00tt.divisibility\00add_one\00public\00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\00/add\00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\00\08C\13\05\01\01\0B/\1D\01\1FC\03\09\03\03\03[\11\17\07\07'\01\07\01\03\03%\03_\07\17\07\07", name = "add_one", num_stages = 3 : i32, num_warps = 4 : i32}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<8xf32>) -> tensor<8xf32> loc(#loc4)
return %0 : tensor<8xf32> loc(#loc3)
} loc(#loc3)
} loc(#loc)
#loc = loc(unknown)
#loc4 = loc("jit(func)/jit(main)/jit(wrapped)/pallas_call[name=add_one which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(8,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>), BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xaf\x8b\x11\x01G\x07\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x13\x0b\x03E\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x13\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0bK\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f/\x01\x05\x0b\x0f\x03\r\x13\x07\x17\x07\x13\x07\x02\xee\x03\x1f\x1d#\x11\x05\x0f\x11\x03\x05\x05\x11\x05\x13\x05\x15\x05\x17\x17%W\x1b\x03\t\x15\x17\x19\x07\x1b\x07\x05\x1d\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\tG\x0bM\r]\x05c\x0fe\x03\x0b\tG\x0bM\rG\x05Q\x0fg\x05!\x05#\x03\x13)i+O-k/S1U3m5Y7S9Y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x055\x1d=\x11\x057\x1dA\x01\x059\x03\x03EQ\x05;\x03\x03[\x1d=\x1d?#\t\x1dA\x1dC\x03\x01\x05\x01\x13\x07\x05\x03\x03\x89\r\x03IK\x03\x03_\r\x05aOIK\x1dE\x1dG\x1dI\x1dK\x0b\x03\x1dM\r\x11oUqsuWwWy{}\x7f\x81\x83\x85\x87\x1dO\x1dQ\x13\x07!\x1dS\x1dU\x1dW\x1dY\x1d[\x1d]\x1d_\x13\x07\r\x1da\x13\x07\x11\x1f\r\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x03!\x0b\x1b\x11\x03\x05\x03\x05\t)\x03\x05\x0f\x13\x04s\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x1f\x07\x03\x05\x0b\x03\x05?\t\x07\x03C\x03\x05\x03\x01\x05\x04\x01\x03\x03\x03\x11\x03!\x07\x03\x05\x0b\x03\x05\x03\x07\x07;'\x03\x05\x03\x01\x05\x04\x03\x03\x03\x06\x03\x01\x05\x01\x00\x12%c\x15\x17\x11\x0b\xfe\x0c\x07\x0f\x0f\x0f\r+\x11\x0f\x0b!\x11\x03\x11#\x0f\x05\xd2\n\x1f/!)!)#\x1f\x19\x85j\x03\x13%)9\x1f\x15\x1d\x15\x13\x11\x1f\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/jit(wrapped)/pallas_call[name=add_one which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(8,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>), BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. 
let in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]\x00x\x00callee\x00mhlo.layout_mode\x00default\x00\x00wrapped\x00jax.result_info\x00main\x00public\x00private\x00__gpu$xla.gpu.triton\x00debug\x00grid_x\x00grid_y\x00grid_z\x00ir\x00ML\xefR\rMLIRgooglex-trunk\x00\x01-\x07\x01\x05\t\x17\x01\x03\x0f\x03\r\x13\x17\x1b\x1f#'\x05\t+/37\x03O1\x0b\x01-\x07\x0f\x0f\x0f\x0f\x13\x13\x13\x0b\x0f\x0f\x13\x0b\x0f\x0b\x0b\x0b\x0b\x1f\x0b\x0b\x13\x05\x05YY\x01\t\x0f\x07\x17\x0b\x03\x035\x02\x16\x02\x1f\x11\x01\x05\x1d)+\x1d#\x0f\x11\x01\x01#\x01\x01\x01\x03\x03\x19\x1b\x17\x11U'\x05\x1d\x11\x07\x00\x1d'\x0f\x01\x05\r\r\x05\x1f\x11\x01\x81\r\x05\x05!\x05#\x05%\x13\x03\x10\x00\x00\xe0\x0f\x05'\x05)\x17\x11U\x11#arith.overflow<none>\x00#arith.fastmath<none>\x00\x01\x02\x02\x0b\x05\x05\t\t\x01\x01\t!tt.ptr<f32>\x00\x04\xd2\x02\x05\x01P\x01\x01\x07\x04\xae\x02\x03\x01\x05\x07P\x01\x03\x07\x04\x82\x02\x03+W\x05\x11\x11\x00\tB\x01\x05\x03\x01\x0fB\x01\x07\x03\x01\x11F\x01\t\x03\x01\x05\x05\x07\x0fB\x01\x07\x03\x01\x11F\x01\t\x03\x01\x05\x05\x0b\x0fB\x07\x05\x03\x01\x13F\x07\t\x03\x01\x05\x0f\t\x0fB\x07\x07\x03\x01\x11F\x07\t\x03\x01\x05\x11\x13\x03\x06\x07\x03\t\x05\x01\x15\x05F\x07\x0b\x03\x03\x03\x17\x0fB\x15\r\x03\x03\x15F\x15\x0f\x03\x03\x05\x19\x1b\x0fB\x05\x05\x03\x01\x13F\x05\t\x03\x01\x05\x1f\r\x0fB\x05\x07\x03\x01\x11F\x05\t\x03\x01\x05!#\x03\x06\x05\x03\t\x05\x03%\x05F\x05\x0b\x03\x03\x03'\x0bD\x05\x11\x05'\x1d\r\x00\x01\x06\x03\x01\x05\x01\x00\xf2\x05+\xa5\x0b\xa3\x0f\x11!\x85\x0b\x0b\x0b\x13\x0f\r\x1f\x0b\x0b\x0f\x0f\r\x07\x11builtin\x00tt\x00arith\x00module\x00addptr\x00load\x00func\x00get_program_id\x00store\x00return\x00constant\x00muli\x00addi\x00addf\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00tt.divisibility\x00add_one\x00public\x00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\x00/add\x00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\x00\x08C\x13\x05\x01\x01\x0b/\x1d\x01\x1fC\x03\t\x03\x03\x03[\x11\x17\x07\x07'\x01\x07\x01\x03\x03%\x03_\x07\x17\x07\x07\x00name\x00add_one\x00num_stages\x00num_warps\x00",
xla_call_module_version=9,
nr_devices=1,
) # End paste
13 changes: 8 additions & 5 deletions jax/_src/internal_test_util/export_back_compat_test_util.py
@@ -17,7 +17,9 @@
JAX serialized format, we need to guarantee that custom calls continue to
work as before. We test this here.
The tests in this file refer to the test data in ./back_compat_testdata.
The tests in this file refer to the test data in
jax/_src/internal_test_util/export_back_compat_test_data.
There is one test for each version of a custom call target, e.g.,
`test_ducc_fft` tests the FFT custom calls on CPU.
Only custom call targets tested here should be listed in
@@ -32,11 +34,12 @@
Write the JAX function `func` that exercises the custom call `foo_call` you
want, then pick some inputs, and then add this to the new test to get started.
Add the following code to your test file, e.g., `export_back_compat_test.py`.
import dataclasses
from jax._src.internal_test_util import export_back_compat_test_util as bctu
class BackCompatTest(bctu.CompatTestBase)
class CompatTest(bctu.CompatTestBase)
...
def test_foo_call(self):
@@ -48,13 +51,13 @@ def func(...): ...
The test will fail, but will save to a file the test data you will need. The
file name will be printed in the logs. Create a new
file ./back_compat_testdata/foo_call.py and paste the test data that
you will see printed in the logs.
file jax/_src/internal_test_util/export_back_compat_test_data/foo_call.py
and paste the test data that you will see printed in the logs.
Name the literal `data_YYYY_MM_DD` to include the date of serialization
(for readability only). Then add to this file:
from jax.experimental.jax2tf.tests.back_compat_testdata import foo_call
from jax._src.internal_test_util.export_back_compat_test_data import foo_call
then update `test_custom_call_coverage`, and then update your `test_foo_call`:
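
For reference, a hypothetical `test_foo_call` wired up according to this recipe might look like the sketch below. `foo_call` and the `call_foo` helper are placeholders for your custom call and the JAX code that lowers to it; `load_testdata` and `run_one_test` are the `CompatTestBase` helpers that the new Pallas test in this commit also uses.

from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.internal_test_util.export_back_compat_test_data import foo_call

class CompatTest(bctu.CompatTestBase):

  def test_foo_call(self):
    def func(x):
      return call_foo(x)  # placeholder: whatever JAX code lowers to the foo_call custom call

    data = self.load_testdata(foo_call.data_2024_05_02)  # the data_YYYY_MM_DD literal you pasted
    self.run_one_test(func, data)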
1 change: 1 addition & 0 deletions jax/experimental/export/_export.py
@@ -794,6 +794,7 @@ def _check_lowering(lowering) -> None:
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
"dynamic_ducc_fft", "cu_threefry2x32",
"__gpu$xla.gpu.triton", # Pallas call on GPU
# cholesky on CPU
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
# eigh on CPU
5 changes: 4 additions & 1 deletion tests/export_back_compat_test.py
@@ -28,6 +28,7 @@
import jax
from jax import lax
from jax.experimental.export import _export

from jax._src.internal_test_util import export_back_compat_test_util as bctu

from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft
@@ -64,8 +65,9 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_legacy_prng_key='allow',
@jtu.with_config(jax_legacy_prng_key="allow",
jax_debug_key_reuse=False,
jax_include_full_tracebacks_in_locations=False,
jax_threefry_gpu_kernel_lowering=True)
class CompatTest(bctu.CompatTestBase):
def test_dummy(self):
@@ -131,6 +133,7 @@ def test_custom_call_coverage(self):
covered_targets = covered_targets.union({
"tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py
"tpu_custom_call", # tested separately
"__gpu$xla.gpu.triton", # tested in pallas/export_back_compat_pallas_test.py
})
not_covered = targets_to_cover.difference(covered_targets)
self.assertEmpty(not_covered,
32 changes: 32 additions & 0 deletions tests/pallas/BUILD
@@ -244,3 +244,35 @@ jax_test(
"//jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("absl/flags") + py_deps("numpy") + py_deps("hypothesis"),
)

jax_test(
name = "export_back_compat_pallas_test",
srcs = ["export_back_compat_pallas_test.py"],
config_tags_overrides = {
"gpu_a100_x32": {
"ondemand": False, # Include in presubmit.
},
},
disable_backends = [
"cpu",
"tpu",
],
disable_configs = [
"gpu",
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100_x32",
"gpu_pjrt_c_api",
],
enable_configs = [
"gpu_a100_x32",
"gpu_h100_x32",
],
tags = [],
deps = [
"//jax:internal_export_back_compat_test_data",
"//jax:internal_export_back_compat_test_util",
"//jax:pallas_gpu",
],
)
68 changes: 68 additions & 0 deletions tests/pallas/export_back_compat_pallas_test.py
@@ -0,0 +1,68 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for backwards compatibility of exporting code with Pallas custom calls.
See the export_back_compat_test_util module docstring for how to set up and
update these tests.
"""

from absl.testing import absltest

import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.internal_test_util import export_back_compat_test_util as bctu

from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one

from jax.experimental import pallas as pl
try:
from jax.experimental.pallas import gpu as plgpu
except ImportError:
plgpu = None
import jax.numpy as jnp


config.parse_flags_with_absl()


@jtu.with_config(jax_include_full_tracebacks_in_locations=False)
class CompatTest(bctu.CompatTestBase):

def setUp(self):
if jax.config.x64_enabled:
self.skipTest("Only works in 32-bit")
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Only works on GPU")
if (jtu.test_device_matches(["cuda"]) and
(plgpu is None or plgpu.get_compute_capability(0) < 80)):
self.skipTest("Only works on GPUs with capability >= sm80")
super().setUp()

def test_cuda_add_one(self):
def func(x):
def add_one(x_ref, o_ref):
o_ref[0] = x_ref[0] + 1
return pl.pallas_call(add_one,
out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
in_specs=[pl.BlockSpec(lambda i: i, (1,))],
out_specs=pl.BlockSpec(lambda i: i, (1,)),
grid=8)(x)
data = self.load_testdata(cuda_add_one.data_2024_05_02)

self.run_one_test(func, data)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
