diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 533263a282ffbf..d90ff627336285 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -2072,6 +2072,7 @@ py_strict_library( ":array_ops_stack", ":bitwise_ops_gen", ":data_flow_ops_gen", + ":logging_ops_gen", ":math_ops_gen", ":nn_ops_gen", ":sparse_ops_gen", diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 7941ae9bf6594a..510865596fe4b5 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -87,6 +87,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_bitwise_ops from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_sparse_ops @@ -4678,6 +4679,200 @@ def sparse_segment_sum( ) +@tf_export("sparse.sampled_addmm", v1=[]) +def sampled_addmm( + indices, + values, + dense_shape, + mat1, + mat2, + beta=1.0, + alpha=1.0, + output_type=dtypes.float32, +): + """Performs the sampled matrix multiplication of two dense matrices. + + Multiplies matrix `mat1` by matrix `mat2` at the locations defined by + `indices`. The product is scaled and added to `values`, + producing `alpha` * (`mat1` @ `mat2`) * spy(`indices`) + `beta` * `values`. + + The function `spy(indices)` is the sparsity pattern matrix derived from + `indices`. + + The `mat1` and `mat2` inputs must be tensors of rank >= 2 where the inner 2 + dimensions specify valid matrix multiplication dimensions, and any further + dimensions specify matching batch size. + + The `indices`, `values`, and `dense_shape` inputs make up the components of a + `SparseTensor` which defines the sparsity pattern of the output. The sparsity + pattern has values of 1 at the positions defined by the `SparseTensor`, and 0 + elsewhere. + + The `alpha` and `beta` inputs are the scaling factors. + + The supported types for `values`, `mat1`, and `mat2` are: + `bfloat16`, `float16`, `float32`, `float64`. + + A simple 2-D tensor operation: + + >>> indices = tf.constant([0, 0, 1, 1], shape=[2, 2]) + >>> indices + + >>> values = tf.constant([0.5, 0.3]) + >>> values + + >>> dense_shape = tf.constant([2, 2]) + >>> dense_shape + + >>> mat1 = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=tf.float32) + >>> mat1 + + >>> mat2 = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2], dtype=tf.float32) + >>> mat2 + + >>> tf.sparse.sampled_addmm(indices, values, dense_shape, mat1, mat2, + ... alpha=0.75, beta=0.25) + (, , + ) + + A batch operation: + + >>> indices = tf.constant([0, 1, 1, 0, 0, 0, 1, 0], shape=[2, 2, 2]) + >>> indices + + >>> values = tf.constant([3, 5, 2, 7], shape=[2, 2], dtype=tf.float32) + >>> values + + >>> dense_shape = tf.constant([2, 2]) + >>> dense_shape + + >>> mat1 = tf.constant(np.arange(1, 13), shape=[2, 2, 3], dtype=tf.float32) + >>> mat1 + + >>> mat2 = tf.constant(np.arange(13, 25), shape=[2, 3, 2], dtype=tf.float32) + >>> mat2 + + >>> tf.sparse.sampled_addmm(indices, values, dense_shape, mat1, mat2, + ... alpha=0.75, beta=0.25) + (, , + ) + + Args: + indices: `tf.Tensor` containing coordinates for the rows and columns to be + multiplied. Must have rank > 1. + values: `tf.Tensor` containing the values to be scaled and added to the + sampled dot product. + dense_shape: `tf.Tensor` defining the dense shape of the output. + mat1: `tf.Tensor` to be multiplied. 
Must have rank > 1. + mat2: `tf.Tensor` to be multiplied. Must have rank > 1. + beta: Number to be multipled with `values`. Defaults to 1.0. + alpha: Number to be multiplied with the sampled dot product of `mat1` and + `mat2`. Defaults to 1.0. + output_type: The output datatype if needed. Defaults to float32. + + Returns: + A tuple representing the `SparseTensor` components of the result of the + operation. + + Raises: + ValueError: If `dense_shape` does not match the shape of the product. + """ + indices = ops.convert_to_tensor(indices) + values = ops.convert_to_tensor(values, dtype=output_type) + dense_shape = ops.convert_to_tensor(dense_shape, dtype=dtypes.int32) + mat1 = ops.convert_to_tensor(mat1, dtype=output_type) + mat2 = ops.convert_to_tensor(mat2, dtype=output_type) + + mat1_shape = tensor_util.constant_value(array_ops.shape(mat1)) + mat2_shape = tensor_util.constant_value(array_ops.shape(mat2)) + + dense_rows = mat1_shape[-2] + dense_cols = mat2_shape[-1] + + output_shape = array_ops_stack.stack([dense_rows, dense_cols]) + condition = reduce_all(equal(dense_shape, output_shape)) + + # Use dense_shape to validate input matrix shapes. + if context.executing_eagerly(): + if not condition: + raise ValueError( + f"Dense shape: {dense_shape} does not match " + f"output shape: {output_shape}" + ) + else: # not context.executing_eagerly() + dense_shape_static = tensor_util.constant_value(dense_shape) + output_shape_static = tensor_util.constant_value(output_shape) + if dense_shape_static is not None and output_shape_static is not None: + condition_static = np.all( + np.equal(dense_shape_static, output_shape_static) + ) + if not condition_static: + raise ValueError( + f"Dense shape: {dense_shape} does not match " + f"output shape: {output_shape}" + ) + + data = [ + "Dense shape: ", + dense_shape, + " does not match output shape: ", + output_shape, + ] + + gen_logging_ops._assert(condition, data, None, name="Assert") + + # Extract row and column indices. + batch_indices = indices[..., :-2] + row_indices = indices[..., :-1] + col_indices = array_ops.concat([batch_indices, indices[..., -1:]], axis=-1) + + # Calculate batch dimensions. + rank = tensor_util.constant_value(array_ops.rank(mat1)) + batch_dims = rank - 2 + + # Extract rows and columns. + rows = array_ops.gather_nd(mat1, row_indices, batch_dims=batch_dims) + cols = array_ops.gather_nd( + array_ops.matrix_transpose(mat2), col_indices, batch_dims=batch_dims + ) + + # Calculate dot product for the extracted rows and columns. 
+ dot = reduce_sum(rows * cols, axis=-1) + return (indices, dot * alpha + values * beta, dense_shape) + + @tf_export("sparse.segment_sum", v1=[]) def sparse_segment_sum_v2( data, diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 1694537d68aaf6..08c16974e8c4ee 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -372,6 +372,120 @@ def testInvalidOutputTypeMatmul(self): self.evaluate(math_ops.matmul(a, b, output_type=dtypes.float32)) +@test_util.run_all_in_graph_and_eager_modes +class SampledADDMMTest(test_util.TensorFlowTestCase): + """Test for sampled_addmm.""" + + SUPPORTED_DTYPES = [ + dtypes.bfloat16, + dtypes.float16, + dtypes.float32, + dtypes.float64, + ] + + def sampledADDMMRef( + self, + indices, + values, + dense_shape, + mat1, + mat2, + beta=1.0, + alpha=1.0, + output_type=dtypes.float32, + ): + dense = math_ops.matmul(mat1, mat2, output_type=output_type) + dense_vals = array_ops.gather_nd(dense, indices, batch_dims=dense.ndim - 2) + return alpha * dense_vals + beta * values + + def testSampledADDMM2D(self): + for dtype in self.SUPPORTED_DTYPES: + indices = constant_op.constant([[0, 0], [1, 1]]) + values = constant_op.constant([0.5, 0.3], dtype=dtype) + dense_shape = constant_op.constant([2, 2]) + mat1 = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype) + mat2 = constant_op.constant( + [7, 8, 9, 10, 11, 12], shape=[3, 2], dtype=dtype + ) + alpha = 0.75 + beta = 0.25 + + _, res, _ = math_ops.sampled_addmm( + indices, + values, + dense_shape, + mat1, + mat2, + beta=beta, + alpha=alpha, + output_type=dtype, + ) + ref = self.sampledADDMMRef( + indices, + values, + dense_shape, + mat1, + mat2, + beta=beta, + alpha=alpha, + output_type=dtype, + ) + self.assertAllClose(res, ref, atol=1e-2) + + def testBatchSampledADDMM(self): + for dtype in self.SUPPORTED_DTYPES: + indices = constant_op.constant([[[0, 1], [1, 0]], [[0, 0], [1, 0]]]) + values = constant_op.constant([[3, 5], [2, 7]], dtype=dtype) + dense_shape = constant_op.constant([2, 2]) + mat1 = constant_op.constant( + np.arange(1, 13), shape=[2, 2, 3], dtype=dtype + ) + mat2 = constant_op.constant( + np.arange(13, 25), shape=[2, 3, 2], dtype=dtype + ) + alpha = 0.4 + beta = 0.6 + + _, res, _ = math_ops.sampled_addmm( + indices, + values, + dense_shape, + mat1, + mat2, + beta=beta, + alpha=alpha, + output_type=dtype, + ) + ref = self.sampledADDMMRef( + indices, + values, + dense_shape, + mat1, + mat2, + beta=beta, + alpha=alpha, + output_type=dtype, + ) + self.assertAllClose(res, ref, atol=1e-2) + + def testInvalidDenseShape(self): + for dtype in self.SUPPORTED_DTYPES: + indices = constant_op.constant([[[0, 1], [1, 0]], [[0, 0], [1, 0]]]) + values = constant_op.constant([[3, 5], [2, 7]], dtype=dtype) + dense_shape = constant_op.constant([1, 2]) + mat1 = constant_op.constant( + np.arange(1, 13), shape=[2, 2, 3], dtype=dtype + ) + mat2 = constant_op.constant( + np.arange(13, 25), shape=[2, 3, 2], dtype=dtype + ) + + with self.assertRaisesRegex(ValueError, "does not match output shape"): + math_ops.sampled_addmm( + indices, values, dense_shape, mat1, mat2, output_type=dtype + ) + + @test_util.run_all_in_graph_and_eager_modes class ModTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt index 8a08f208808cca..b5c536f9d062a3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt @@ -80,6 +80,10 @@ tf_module { name: "retain" argspec: "args=[\'sp_input\', \'to_retain\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "sampled_addmm" + argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'mat1\', \'mat2\', \'beta\', \'alpha\', \'output_type\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \"\"], " + } member_method { name: "segment_mean" argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 77e582f4150771..7de081f853b5f5 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -7242,7 +7242,12 @@ cc_library( name = "custom_call_sharding_helper", srcs = ["custom_call_sharding_helper.cc"], hdrs = ["custom_call_sharding_helper.h"], - deps = ["//xla/hlo/ir:hlo"], + deps = [ + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:logging", + ], ) tf_proto_library( diff --git a/third_party/xla/xla/service/custom_call_sharding_helper.cc b/third_party/xla/xla/service/custom_call_sharding_helper.cc index 8bc09a4ca79ae1..4492bf7c148993 100644 --- a/third_party/xla/xla/service/custom_call_sharding_helper.cc +++ b/third_party/xla/xla/service/custom_call_sharding_helper.cc @@ -19,6 +19,11 @@ limitations under the License. #include #include +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/synchronization/mutex.h" +#include "tsl/platform/logging.h" + namespace xla { HloSharding CustomCallShardingHelper::PropagateUserSharding( @@ -53,10 +58,13 @@ GetPartitioners() { std::unique_ptr>; return *out; } + +ABSL_CONST_INIT absl::Mutex partitioners_mutex(absl::kConstInit); } // namespace const CustomCallPartitioner* GetCustomCallPartitioner( const std::string& custom_call_target) { + absl::MutexLock partitioners_lock(&partitioners_mutex); auto& partitioners = GetPartitioners(); auto it = partitioners.find(custom_call_target); if (it == partitioners.end()) { @@ -68,8 +76,16 @@ const CustomCallPartitioner* GetCustomCallPartitioner( void RegisterCustomCallPartitioner( const std::string& custom_call_target, std::unique_ptr partitioner) { + absl::MutexLock partitioners_lock(&partitioners_mutex); auto& partitioners = GetPartitioners(); - partitioners.emplace(custom_call_target, std::move(partitioner)); + // Warn if something has already been registered. We prefer to keep the + // existing object as other threads are more likely to observe it. 
+ auto [it, did_insert] = + partitioners.try_emplace(custom_call_target, std::move(partitioner)); + if (!did_insert) { + LOG(ERROR) << "Failed to register custom call partitioner for " + << custom_call_target; + } } } // namespace xla diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2f80141ba54c92..f98586c922d9bd 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -935,7 +935,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_activation", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer_header", + "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/rocm:rocm_platform_id", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -5877,7 +5877,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/gpu:gpu_timer_header", + "//xla/stream_executor/gpu:gpu_timer", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:test_utils", diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 2d53b5252478ba..61312127483ae1 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -183,7 +183,7 @@ xla_cc_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor/gpu:gpu_init", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer_header", + "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/host:host_platform", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index a55f5bc68de494..234649a31ea65b 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -393,7 +393,7 @@ cuda_only_cc_library( "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor/gpu:gpu_timer_header", + "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/platform", "//xla/tsl/cuda:cudnn", "//xla/tsl/util:env_var", diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 05ed621d316876..93c06c0533e39b 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -319,49 +319,67 @@ gpu_only_cc_library( ) gpu_only_cc_library( - name = "gpu_timer_kernel_header", - hdrs = ["gpu_timer_kernel.h"], + name = "gpu_semaphore", + srcs = ["gpu_semaphore.cc"], + hdrs = ["gpu_semaphore.h"], + deps = [ + "//xla/stream_executor:device_memory", + "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:stream_executor_headers", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], ) gpu_kernel_library( - name = "gpu_timer_kernel", - srcs = if_gpu_is_configured(["gpu_timer_kernel.cu.cc"]), + name = "gpu_timer_kernel_cuda", + srcs = [ + "gpu_timer_kernel.h", + "gpu_timer_kernel_cuda.cu.cc", + ], + tags = ["manual"], deps = [ - ":gpu_timer_kernel_header", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - 
]), + ":gpu_driver_header", + ":gpu_executor_header", + ":gpu_semaphore", + ":gpu_stream", + "//xla/stream_executor", + "@com_google_absl//absl/status:statusor", + ], ) -gpu_only_cc_library( - name = "gpu_timer_header", - hdrs = ["gpu_timer.h"], +cc_library( + name = "gpu_timer_kernel_rocm", + srcs = [ + "gpu_timer_kernel.h", + "gpu_timer_kernel_rocm.cc", + ], + tags = ["manual"], deps = [ - ":gpu_executor_header", - ":gpu_timer_kernel_header", - ":gpu_types_header", + ":gpu_semaphore", + ":gpu_stream", + "//xla/stream_executor", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/time", ], ) gpu_only_cc_library( name = "gpu_timer", - srcs = ["gpu_timer.cc"], - hdrs = ["gpu_timer.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), + srcs = [ + "gpu_timer.cc", + "gpu_timer_kernel.h", + ], + hdrs = [ + "gpu_timer.h", + ], deps = [ ":gpu_driver_header", ":gpu_executor_header", + ":gpu_semaphore", ":gpu_stream", - ":gpu_timer_kernel_header", ":gpu_types_header", "//xla/stream_executor", - "//xla/stream_executor:stream_executor_interface", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -371,14 +389,11 @@ gpu_only_cc_library( "@com_google_absl//absl/time", "@com_google_absl//absl/utility", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - ] + if_gpu_is_configured([ - ":gpu_timer_kernel", - ]) + if_cuda_is_configured([ - "//xla/stream_executor/cuda:cuda_driver", + ] + if_cuda_is_configured([ + ":gpu_timer_kernel_cuda", ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_driver", + ":gpu_timer_kernel_rocm", ]), ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_semaphore.cc b/third_party/xla/xla/stream_executor/gpu/gpu_semaphore.cc new file mode 100644 index 00000000000000..d5e2706135d03c --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_semaphore.cc @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/gpu/gpu_semaphore.h" + +#include + +#include "absl/status/statusor.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor { +absl::StatusOr GpuSemaphore::Create(StreamExecutor* executor) { + // Allocate the value in pinned host memory that can be read from both + // host and device. + TF_ASSIGN_OR_RETURN(auto alloc, + executor->HostMemoryAllocate(sizeof(GpuSemaphoreState))); + return GpuSemaphore{std::move(alloc)}; +} + +DeviceMemory GpuSemaphore::device() { + // This assumes unified addressing, as we do not explicitly translate the + // host pointer into a device pointer. 
+ return DeviceMemory::MakeFromByteSize( + ptr_->opaque(), sizeof(GpuSemaphoreState)); +} +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_semaphore.h b/third_party/xla/xla/stream_executor/gpu/gpu_semaphore.h new file mode 100644 index 00000000000000..4436c97014d8a8 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_semaphore.h @@ -0,0 +1,56 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_SEMAPHORE_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_SEMAPHORE_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/stream_executor.h" + +namespace stream_executor { +enum struct GpuSemaphoreState { kHold, kRelease, kTimedOut }; + +// A basic semaphore that allows synchronization between host and GPU. +// It uses pinned host memory as the communication channel. +class GpuSemaphore { + public: + // Creates an invalid semaphore instance + GpuSemaphore() = default; + + // Creates a valid semaphore. Allocates some pinned host memory using + // `executor`. + static absl::StatusOr Create(StreamExecutor* executor); + + // Returns true if this semaphore is valid, otherwise false. + explicit operator bool() const { return bool{ptr_}; } + + GpuSemaphoreState& operator*() { + return *static_cast(ptr_->opaque()); + } + DeviceMemory device(); + + private: + explicit GpuSemaphore(std::unique_ptr alloc) + : ptr_{std::move(alloc)} {} + std::unique_ptr ptr_; +}; +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_SEMAPHORE_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc index e5928e7a7cc782..433424ab84b0bd 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc @@ -16,8 +16,10 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_timer.h" #include +#include #include #include +#include #include #include "absl/base/const_init.h" @@ -31,8 +33,11 @@ limitations under the License. #include "absl/utility/utility.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/gpu/gpu_timer_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/stream.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -92,25 +97,8 @@ GpuTimer::CreateIfNeeded(GpuStream* stream, bool is_needed) { return std::nullopt; } -/*static*/ absl::StatusOr -GpuTimer::GpuSemaphore::Create(StreamExecutor* executor) { - // Allocate the value in pinned host memory that can be read from both - // host and device. 
- TF_ASSIGN_OR_RETURN(auto alloc, - executor->HostMemoryAllocate(sizeof(GpuSemaphoreState))); - return GpuSemaphore{std::move(alloc)}; -} - -DeviceMemory GpuTimer::GpuSemaphore::device() { - // This assumes unified addressing, as we do not explicitly translate the - // host pointer into a device pointer. - return DeviceMemory::MakeFromByteSize( - ptr_->opaque(), sizeof(GpuSemaphoreState)); -} - /*static*/ absl::StatusOr GpuTimer::Create(Stream* real_stream, bool use_delay_kernel) { - StreamExecutor* executor = real_stream->parent(); GpuStream* stream = AsGpuStream(real_stream); GpuExecutor* parent = stream->parent(); GpuContext* context = parent->gpu_context(); @@ -126,37 +114,15 @@ DeviceMemory GpuTimer::GpuSemaphore::device() { LOG(WARNING) << "Skipping the delay kernel, measurement accuracy will be reduced"; } -#ifdef GOOGLE_CUDA + if (use_delay_kernel && ShouldLaunchDelayKernel()) { - // Check the assumption that this device supports unified addressing, - // otherwise skip the delay kernel - TF_ASSIGN_OR_RETURN(int status, GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, - parent->device())); - if (!status) { - LOG(WARNING) << "Skipping the delay kernel because the device does not " - "support unified addressing"; - } else { - // Allocate a semaphore value that will be used to signal to the delay - // kernel that it may exit. - TF_ASSIGN_OR_RETURN(semaphore, GpuSemaphore::Create(executor)); - *semaphore = GpuSemaphoreState::Hold; - // In principle the kernel could be loaded lazily and shared across - // multiple GpuTimer objects. - TF_ASSIGN_OR_RETURN( - auto kernel, - (TypedKernel, - GpuSemaphoreState>::Create(executor, "DelayKernel", - delay_kernel::kernel()))); - // Launch a delay kernel into this stream, which will spin until - // GetElapsedDuration() is called, the timer is destroyed, or the timeout - // in the kernel is reached. - TF_RETURN_IF_ERROR(real_stream->ThenLaunch( - ThreadDim(1, 1, 1), BlockDim(1, 1, 1), kernel, semaphore.device(), - GpuSemaphoreState::Release)); + TF_ASSIGN_OR_RETURN(bool is_supported, DelayKernelIsSupported(stream)); + + if (is_supported) { + TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(real_stream)); } } -#endif // GOOGLE_CUDA + // The start event goes after the delay kernel in the stream TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent->gpu_context(), start_event, stream->gpu_stream())); @@ -181,7 +147,7 @@ GpuTimer::~GpuTimer() { GpuContext* context = parent_->gpu_context(); if (semaphore_ && !is_stopped_) { // Signal the delay kernel that it can exit - *semaphore_ = GpuSemaphoreState::Release; + *semaphore_ = GpuSemaphoreState::kRelease; // Wait for the delay kernel to exit before destroying the value that it is // watching. absl::Status status = @@ -212,14 +178,14 @@ absl::StatusOr GpuTimer::GetElapsedDuration() { stream_->gpu_stream())); // If we launched the delay kernel then check if it already timed out. if (semaphore_) { - if (*semaphore_ == GpuSemaphoreState::TimedOut) { + if (*semaphore_ == GpuSemaphoreState::kTimedOut) { // The delay kernel did not achieve the intended result. LOG(ERROR) << "Delay kernel timed out: measured time has sub-optimal " "accuracy. 
There may be a missing warmup execution, please " "investigate in Nsight Systems."; } else { // Signal that the kernel can exit - *semaphore_ = GpuSemaphoreState::Release; + *semaphore_ = GpuSemaphoreState::kRelease; } } float elapsed_milliseconds = NAN; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h index 97124e611b4456..1cba7bfb9914ef 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_H_ -#include #include #include #include "absl/status/statusor.h" #include "absl/time/time.h" #include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/gpu_timer_kernel.h" +#include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/stream.h" namespace xla { namespace gpu { @@ -46,21 +46,6 @@ class GpuStream; // to be measured more accurately. class GpuTimer { public: - class GpuSemaphore { - public: - GpuSemaphore() = default; - static absl::StatusOr Create(StreamExecutor* executor); - explicit operator bool() const { return bool{ptr_}; } - GpuSemaphoreState& operator*() { - return *static_cast(ptr_->opaque()); - } - DeviceMemory device(); - - private: - explicit GpuSemaphore(std::unique_ptr alloc) - : ptr_{std::move(alloc)} {} - std::unique_ptr ptr_; - }; static absl::StatusOr Create(Stream* stream, bool use_delay_kernel); [[deprecated("Pass Stream* not GpuStream*")]] static absl::StatusOr Create(GpuStream* stream); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h index 2ac358b4ee56c5..cb0c5d1a3ccff3 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.h @@ -16,11 +16,21 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ +#include "absl/status/statusor.h" +#include "xla/stream_executor/gpu/gpu_semaphore.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/stream.h" + namespace stream_executor::gpu { -enum struct GpuSemaphoreState { Hold, Release, TimedOut }; -namespace delay_kernel { -void* kernel(); // returns a pointer to a CUDA C++ device function -} // namespace delay_kernel +// Returns true if the current backend and GPU supports the delay kernel for +// time measurement. It might return an error if checking for the support at +// runtime failed. +absl::StatusOr DelayKernelIsSupported(GpuStream* stream); + +// Launches the delay kernel on the given stream. The caller is responsible for +// keeping the returned semaphore alive until the kernel finished executing. +// Setting the semaphore to `kRelease` makes the kernel quit. 
+absl::StatusOr LaunchDelayKernel(Stream* stream); } // namespace stream_executor::gpu #endif // XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc similarity index 53% rename from third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc rename to third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc index 0ce4b1d9fbb323..b4af6019234d04 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel.cu.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc @@ -12,10 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/gpu/gpu_timer_kernel.h" #include +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_semaphore.h" +#include "xla/stream_executor/gpu/gpu_timer_kernel.h" + namespace stream_executor::gpu { namespace { // Wait for the value pointed to by `semaphore` to have value `target`, timing @@ -40,11 +44,48 @@ __global__ void DelayKernel(volatile GpuSemaphoreState* semaphore, if (target_not_reached) { // We are exiting due to the timeout. Signal this back to the host so that // we can emit a warning, as it probably indicates suboptimal usage. - *semaphore = GpuSemaphoreState::TimedOut; + *semaphore = GpuSemaphoreState::kTimedOut; } } } // namespace +absl::StatusOr LaunchDelayKernel(Stream* stream) { + StreamExecutor* executor = stream->parent(); + + // Allocate a semaphore value that will be used to signal to the delay + // kernel that it may exit. + TF_ASSIGN_OR_RETURN(auto semaphore, GpuSemaphore::Create(executor)); + *semaphore = GpuSemaphoreState::kHold; + // In principle the kernel could be loaded lazily and shared across + // multiple GpuTimer objects. + TF_ASSIGN_OR_RETURN( + auto kernel, + (TypedKernel, GpuSemaphoreState>::Create( + executor, "DelayKernel", reinterpret_cast(DelayKernel)))); + // Launch a delay kernel into this stream, which will spin until + // GetElapsedDuration() is called, the timer is destroyed, or the timeout + // in the kernel is reached. + TF_RETURN_IF_ERROR(stream->ThenLaunch(ThreadDim(1, 1, 1), BlockDim(1, 1, 1), + kernel, semaphore.device(), + GpuSemaphoreState::kRelease)); + + return semaphore; +} + +absl::StatusOr DelayKernelIsSupported(GpuStream* stream) { + // Check the assumption that this device supports unified addressing, + // otherwise skip the delay kernel + TF_ASSIGN_OR_RETURN(int status, GpuDriver::GetDeviceAttribute( + CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, + stream->parent()->device())); + if (!status) { + LOG(WARNING) << "Skipping the delay kernel because the device does not " + "support unified addressing"; + } + + return static_cast(status); +} + namespace delay_kernel { void* kernel() { return reinterpret_cast(DelayKernel); } } // namespace delay_kernel diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_rocm.cc b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_rocm.cc new file mode 100644 index 00000000000000..2ee3680fa3f757 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer_kernel_rocm.cc @@ -0,0 +1,30 @@ +/* Copyright 2024 The OpenXLA Authors. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "xla/stream_executor/gpu/gpu_semaphore.h"
+#include "xla/stream_executor/gpu/gpu_stream.h"
+#include "xla/stream_executor/stream.h"
+
+namespace stream_executor::gpu {
+
+absl::StatusOr<bool> DelayKernelIsSupported(GpuStream*) { return false; }
+
+absl::StatusOr<GpuSemaphore> LaunchDelayKernel(Stream* stream) {
+  return absl::UnimplementedError("Not implemented");
+}
+
+}  // namespace stream_executor::gpu
diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD
index 71ad43a60c0dcc..a629c4efbbbd68 100644
--- a/third_party/xla/xla/stream_executor/rocm/BUILD
+++ b/third_party/xla/xla/stream_executor/rocm/BUILD
@@ -256,7 +256,7 @@ cc_library(
         "//xla/stream_executor/gpu:gpu_activation",
         "//xla/stream_executor/gpu:gpu_helpers_header",
         "//xla/stream_executor/gpu:gpu_stream_header",
-        "//xla/stream_executor/gpu:gpu_timer_header",
+        "//xla/stream_executor/gpu:gpu_timer",
         "//xla/stream_executor/gpu:gpu_blas_lt",
         "//xla/stream_executor/platform",
         "//xla/stream_executor:blas",
@@ -342,7 +342,7 @@ cc_library(
         "//xla/stream_executor:plugin_registry",
         "//xla/stream_executor/gpu:gpu_activation_header",
         "//xla/stream_executor/gpu:gpu_stream_header",
-        "//xla/stream_executor/gpu:gpu_timer_header",
+        "//xla/stream_executor/gpu:gpu_timer",
         "//xla/stream_executor/gpu:gpu_types_header",
         "//xla/stream_executor:device_memory_allocator",
         "//xla/stream_executor/platform",
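
Usage sketch (illustrative, not part of the patch; assumes a TensorFlow build that already contains the tf.sparse.sampled_addmm change above). The inputs mirror the 2-D example in the docstring; the expected numbers are worked out by hand from those matrices, not copied from any test output.

import tensorflow as tf

# Inputs from the 2-D sampled_addmm docstring example.
indices = tf.constant([[0, 0], [1, 1]])
values = tf.constant([0.5, 0.3])
dense_shape = tf.constant([2, 2])
mat1 = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=tf.float32)
mat2 = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2], dtype=tf.float32)

# Returns the components of the result as (indices, values, dense_shape).
out_indices, out_values, out_shape = tf.sparse.sampled_addmm(
    indices, values, dense_shape, mat1, mat2, alpha=0.75, beta=0.25)

# Cross-check against a dense matmul: mat1 @ mat2 == [[58, 64], [139, 154]],
# so the sampled entries are 58 and 154, and the expected output values are
# 0.75 * 58 + 0.25 * 0.5 = 43.625 and 0.75 * 154 + 0.25 * 0.3 = 115.575.
reference = 0.75 * tf.gather_nd(tf.matmul(mat1, mat2), indices) + 0.25 * values
tf.debugging.assert_near(out_values, reference)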