[XLA] Fix race condition in custom call partitioner registration.
The custom call partitioner registration is not thread-safe. This can lead to a race condition when multiple threads try to register the same partitioner. This CL fixes the race condition by adding a mutex to protect the registration process.

FUTURE_COPYBARA_INTEGRATE_REVIEW=#62750 from mattbahr:implement-sampled-addmm-v2 c295a0e
PiperOrigin-RevId: 630156374
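Conceptually, the fix serializes the registry's check-and-insert so that concurrent registrations cannot interleave. Below is a minimal sketch of the idea in Python; the actual change guards the C++ registry in `custom_call_sharding_helper.cc` with a mutex (per the new `@com_google_absl//absl/synchronization` dependency in the BUILD diff below), and the names here are hypothetical.

```python
import threading

# Hypothetical stand-in for the partitioner registry. Without the lock, two
# threads registering at the same time can interleave the lookup and the
# write, dropping one registration or exposing a half-initialized entry.
_registry_lock = threading.Lock()
_partitioners = {}


def register_custom_call_partitioner(name, partitioner):
  # Serialize the read-modify-write on the shared dict.
  with _registry_lock:
    _partitioners[name] = partitioner


def get_custom_call_partitioner(name):
  with _registry_lock:
    return _partitioners.get(name)
```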
majnemer authored and tensorflower-gardener committed May 2, 2024
1 parent 69b7c1c commit 153c351
Showing 18 changed files with 587 additions and 109 deletions.
1 change: 1 addition & 0 deletions tensorflow/python/ops/BUILD
@@ -2072,6 +2072,7 @@ py_strict_library(
        ":array_ops_stack",
        ":bitwise_ops_gen",
        ":data_flow_ops_gen",
        ":logging_ops_gen",
        ":math_ops_gen",
        ":nn_ops_gen",
        ":sparse_ops_gen",
195 changes: 195 additions & 0 deletions tensorflow/python/ops/math_ops.py
@@ -87,6 +87,7 @@
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_sparse_ops
@@ -4678,6 +4679,200 @@ def sparse_segment_sum(
)


@tf_export("sparse.sampled_addmm", v1=[])
def sampled_addmm(
indices,
values,
dense_shape,
mat1,
mat2,
beta=1.0,
alpha=1.0,
output_type=dtypes.float32,
):
"""Performs the sampled matrix multiplication of two dense matrices.
Multiplies matrix `mat1` by matrix `mat2` at the locations defined by
`indices`. The product is scaled and added to `values`,
producing `alpha` * (`mat1` @ `mat2`) * spy(`indices`) + `beta` * `values`.
The function `spy(indices)` is the sparsity pattern matrix derived from
`indices`.
The `mat1` and `mat2` inputs must be tensors of rank >= 2 where the inner 2
dimensions specify valid matrix multiplication dimensions, and any further
dimensions specify matching batch size.
The `indices`, `values`, and `dense_shape` inputs make up the components of a
`SparseTensor` which defines the sparsity pattern of the output. The sparsity
pattern has values of 1 at the positions defined by the `SparseTensor`, and 0
elsewhere.
The `alpha` and `beta` inputs are the scaling factors.
The supported types for `values`, `mat1`, and `mat2` are:
`bfloat16`, `float16`, `float32`, `float64`.
  A simple 2-D tensor operation:

  >>> indices = tf.constant([0, 0, 1, 1], shape=[2, 2])
  >>> indices
  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
  array([[0, 0],
         [1, 1]], dtype=int32)>
  >>> values = tf.constant([0.5, 0.3])
  >>> values
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.5, 0.3], dtype=float32)>
  >>> dense_shape = tf.constant([2, 2])
  >>> dense_shape
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
  >>> mat1 = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=tf.float32)
  >>> mat1
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[1., 2., 3.],
         [4., 5., 6.]], dtype=float32)>
  >>> mat2 = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2], dtype=tf.float32)
  >>> mat2
  <tf.Tensor: shape=(3, 2), dtype=float32, numpy=
  array([[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]], dtype=float32)>
  >>> tf.sparse.sampled_addmm(indices, values, dense_shape, mat1, mat2,
  ...                         alpha=0.75, beta=0.25)
  (<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
  array([[0, 0],
         [1, 1]], dtype=int32)>, <tf.Tensor: shape=(2,), dtype=float32, numpy=
  array([ 43.625, 115.575], dtype=float32)>,
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>)
  A batch operation:

  >>> indices = tf.constant([0, 1, 1, 0, 0, 0, 1, 0], shape=[2, 2, 2])
  >>> indices
  <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
  array([[[0, 1],
          [1, 0]],
         [[0, 0],
          [1, 0]]], dtype=int32)>
  >>> values = tf.constant([3, 5, 2, 7], shape=[2, 2], dtype=tf.float32)
  >>> values
  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
  array([[3., 5.],
         [2., 7.]], dtype=float32)>
  >>> dense_shape = tf.constant([2, 2])
  >>> dense_shape
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>
  >>> mat1 = tf.constant(np.arange(1, 13), shape=[2, 2, 3], dtype=tf.float32)
  >>> mat1
  <tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
  array([[[ 1.,  2.,  3.],
          [ 4.,  5.,  6.]],
         [[ 7.,  8.,  9.],
          [10., 11., 12.]]], dtype=float32)>
  >>> mat2 = tf.constant(np.arange(13, 25), shape=[2, 3, 2], dtype=tf.float32)
  >>> mat2
  <tf.Tensor: shape=(2, 3, 2), dtype=float32, numpy=
  array([[[13., 14.],
          [15., 16.],
          [17., 18.]],
         [[19., 20.],
          [21., 22.],
          [23., 24.]]], dtype=float32)>
  >>> tf.sparse.sampled_addmm(indices, values, dense_shape, mat1, mat2,
  ...                         alpha=0.75, beta=0.25)
  (<tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
  array([[[0, 1],
          [1, 0]],
         [[0, 0],
          [1, 0]]], dtype=int32)>, <tf.Tensor: shape=(2, 2), dtype=float32,
  numpy=array([[ 75.75, 173.  ],
               [381.5 , 524.5 ]], dtype=float32)>,
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 2], dtype=int32)>)
  Args:
    indices: `tf.Tensor` containing coordinates for the rows and columns to be
      multiplied. Must have rank > 1.
    values: `tf.Tensor` containing the values to be scaled and added to the
      sampled dot product.
    dense_shape: `tf.Tensor` defining the dense shape of the output.
    mat1: `tf.Tensor` to be multiplied. Must have rank > 1.
    mat2: `tf.Tensor` to be multiplied. Must have rank > 1.
    beta: Number to be multiplied with `values`. Defaults to 1.0.
    alpha: Number to be multiplied with the sampled dot product of `mat1` and
      `mat2`. Defaults to 1.0.
    output_type: The output datatype if needed. Defaults to `float32`.

  Returns:
    A tuple representing the `SparseTensor` components of the result of the
    operation.

  Raises:
    ValueError: If `dense_shape` does not match the shape of the product.
  """
  indices = ops.convert_to_tensor(indices)
  values = ops.convert_to_tensor(values, dtype=output_type)
  dense_shape = ops.convert_to_tensor(dense_shape, dtype=dtypes.int32)
  mat1 = ops.convert_to_tensor(mat1, dtype=output_type)
  mat2 = ops.convert_to_tensor(mat2, dtype=output_type)

  mat1_shape = tensor_util.constant_value(array_ops.shape(mat1))
  mat2_shape = tensor_util.constant_value(array_ops.shape(mat2))

  dense_rows = mat1_shape[-2]
  dense_cols = mat2_shape[-1]

  output_shape = array_ops_stack.stack([dense_rows, dense_cols])
  condition = reduce_all(equal(dense_shape, output_shape))

  # Use dense_shape to validate input matrix shapes.
  if context.executing_eagerly():
    if not condition:
      raise ValueError(
          f"Dense shape: {dense_shape} does not match "
          f"output shape: {output_shape}"
      )
  else:  # not context.executing_eagerly()
    dense_shape_static = tensor_util.constant_value(dense_shape)
    output_shape_static = tensor_util.constant_value(output_shape)
    if dense_shape_static is not None and output_shape_static is not None:
      condition_static = np.all(
          np.equal(dense_shape_static, output_shape_static)
      )
      if not condition_static:
        raise ValueError(
            f"Dense shape: {dense_shape} does not match "
            f"output shape: {output_shape}"
        )

  data = [
      "Dense shape: ",
      dense_shape,
      " does not match output shape: ",
      output_shape,
  ]

  gen_logging_ops._assert(condition, data, None, name="Assert")

  # Extract row and column indices.
  batch_indices = indices[..., :-2]
  row_indices = indices[..., :-1]
  col_indices = array_ops.concat([batch_indices, indices[..., -1:]], axis=-1)

  # Calculate batch dimensions.
  rank = tensor_util.constant_value(array_ops.rank(mat1))
  batch_dims = rank - 2

  # Extract rows and columns.
  rows = array_ops.gather_nd(mat1, row_indices, batch_dims=batch_dims)
  cols = array_ops.gather_nd(
      array_ops.matrix_transpose(mat2), col_indices, batch_dims=batch_dims
  )

  # Calculate dot product for the extracted rows and columns.
  dot = reduce_sum(rows * cols, axis=-1)
  return (indices, dot * alpha + values * beta, dense_shape)


@tf_export("sparse.segment_sum", v1=[])
def sparse_segment_sum_v2(
    data,
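As a sanity check on the formula `alpha` * (`mat1` @ `mat2`) * spy(`indices`) + `beta` * `values`, the 2-D doctest above can be reproduced with plain NumPy. This is a standalone sketch, independent of the op itself:

```python
import numpy as np

# Inputs from the 2-D docstring example.
mat1 = np.array([[1., 2., 3.], [4., 5., 6.]], dtype=np.float32)
mat2 = np.array([[7., 8.], [9., 10.], [11., 12.]], dtype=np.float32)
values = np.array([0.5, 0.3], dtype=np.float32)
indices = np.array([[0, 0], [1, 1]])
alpha, beta = 0.75, 0.25

dense = mat1 @ mat2  # [[ 58.,  64.], [139., 154.]]
# Sample the dense product only at the sparse coordinates.
sampled = dense[indices[:, 0], indices[:, 1]]  # [ 58., 154.]
print(alpha * sampled + beta * values)  # [ 43.625 115.575], matching the doctest
```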
114 changes: 114 additions & 0 deletions tensorflow/python/ops/math_ops_test.py
@@ -372,6 +372,120 @@ def testInvalidOutputTypeMatmul(self):
      self.evaluate(math_ops.matmul(a, b, output_type=dtypes.float32))


@test_util.run_all_in_graph_and_eager_modes
class SampledADDMMTest(test_util.TensorFlowTestCase):
  """Test for sampled_addmm."""

  SUPPORTED_DTYPES = [
      dtypes.bfloat16,
      dtypes.float16,
      dtypes.float32,
      dtypes.float64,
  ]

  def sampledADDMMRef(
      self,
      indices,
      values,
      dense_shape,
      mat1,
      mat2,
      beta=1.0,
      alpha=1.0,
      output_type=dtypes.float32,
  ):
    dense = math_ops.matmul(mat1, mat2, output_type=output_type)
    dense_vals = array_ops.gather_nd(dense, indices, batch_dims=dense.ndim - 2)
    return alpha * dense_vals + beta * values

  def testSampledADDMM2D(self):
    for dtype in self.SUPPORTED_DTYPES:
      indices = constant_op.constant([[0, 0], [1, 1]])
      values = constant_op.constant([0.5, 0.3], dtype=dtype)
      dense_shape = constant_op.constant([2, 2])
      mat1 = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
      mat2 = constant_op.constant(
          [7, 8, 9, 10, 11, 12], shape=[3, 2], dtype=dtype
      )
      alpha = 0.75
      beta = 0.25

      _, res, _ = math_ops.sampled_addmm(
          indices,
          values,
          dense_shape,
          mat1,
          mat2,
          beta=beta,
          alpha=alpha,
          output_type=dtype,
      )
      ref = self.sampledADDMMRef(
          indices,
          values,
          dense_shape,
          mat1,
          mat2,
          beta=beta,
          alpha=alpha,
          output_type=dtype,
      )
      self.assertAllClose(res, ref, atol=1e-2)

  def testBatchSampledADDMM(self):
    for dtype in self.SUPPORTED_DTYPES:
      indices = constant_op.constant([[[0, 1], [1, 0]], [[0, 0], [1, 0]]])
      values = constant_op.constant([[3, 5], [2, 7]], dtype=dtype)
      dense_shape = constant_op.constant([2, 2])
      mat1 = constant_op.constant(
          np.arange(1, 13), shape=[2, 2, 3], dtype=dtype
      )
      mat2 = constant_op.constant(
          np.arange(13, 25), shape=[2, 3, 2], dtype=dtype
      )
      alpha = 0.4
      beta = 0.6

      _, res, _ = math_ops.sampled_addmm(
          indices,
          values,
          dense_shape,
          mat1,
          mat2,
          beta=beta,
          alpha=alpha,
          output_type=dtype,
      )
      ref = self.sampledADDMMRef(
          indices,
          values,
          dense_shape,
          mat1,
          mat2,
          beta=beta,
          alpha=alpha,
          output_type=dtype,
      )
      self.assertAllClose(res, ref, atol=1e-2)

  def testInvalidDenseShape(self):
    for dtype in self.SUPPORTED_DTYPES:
      indices = constant_op.constant([[[0, 1], [1, 0]], [[0, 0], [1, 0]]])
      values = constant_op.constant([[3, 5], [2, 7]], dtype=dtype)
      dense_shape = constant_op.constant([1, 2])
      mat1 = constant_op.constant(
          np.arange(1, 13), shape=[2, 2, 3], dtype=dtype
      )
      mat2 = constant_op.constant(
          np.arange(13, 25), shape=[2, 3, 2], dtype=dtype
      )

      with self.assertRaisesRegex(ValueError, "does not match output shape"):
        math_ops.sampled_addmm(
            indices, values, dense_shape, mat1, mat2, output_type=dtype
        )


@test_util.run_all_in_graph_and_eager_modes
class ModTest(test_util.TensorFlowTestCase):

4 changes: 4 additions & 0 deletions tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
@@ -80,6 +80,10 @@ tf_module {
name: "retain"
argspec: "args=[\'sp_input\', \'to_retain\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "sampled_addmm"
argspec: "args=[\'indices\', \'values\', \'dense_shape\', \'mat1\', \'mat2\', \'beta\', \'alpha\', \'output_type\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "segment_mean"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
7 changes: 6 additions & 1 deletion third_party/xla/xla/service/BUILD
@@ -7242,7 +7242,12 @@ cc_library(
name = "custom_call_sharding_helper",
srcs = ["custom_call_sharding_helper.cc"],
hdrs = ["custom_call_sharding_helper.h"],
deps = ["//xla/hlo/ir:hlo"],
deps = [
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
"@local_tsl//tsl/platform:logging",
],
)

tf_proto_library(
