New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement sampled addmm #62750
Implement sampled addmm #62750
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See other comments
Hi @mattbahr Can you please check @cantonios's comments and keep us posted ? Thank you! |
@gbaned absolutely! |
…ated attribute calls with tf functions
😭😭😭😭😭
…On Sat, Jan 27, 2024, 1:39 AM Matt Bahr ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In tensorflow/python/ops/math_ops.py
<#62750 (comment)>
:
> + mat1 = ops.convert_to_tensor(mat1)
+ if not isinstance(mat2, tensor_lib.Tensor):
+ mat2 = ops.convert_to_tensor(mat2)
+
+ if values.dtype != output_type:
+ values = cast(values, output_type)
+ if mat1.dtype != output_type:
+ mat1 = cast(mat1, output_type)
+ if mat2.dtype != output_type:
+ mat2 = cast(mat2, output_type)
+
+ dense_rows = mat1.shape[-2]
+ dense_cols = mat2.shape[-1]
+
+ # TODO(mattbahr): use dense_shape to validate the shapes of mat1 and mat2
+ dense_shape = constant_op.constant([dense_rows, dense_cols])
Importing check_ops introduces a circular dependency. I tried using a
regular assert, but I kept getting AssertionErrors when I run my unit
tests.
—
Reply to this email directly, view it on GitHub
<#62750 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/A5CIWVH2QNOI4J7IY4FRHPTYQSOJTAVCNFSM6AAAAABBQB7C26VHI2DSMVQWIX3LMV43YUDVNRWFEZLROVSXG5CSMV3GSZLXHMYTQNBWHA4TINBQGY>
.
You are receiving this because you are subscribed to this thread.Message
ID: ***@***.***>
|
@cantonios When you get the chance, you mind taking a look at my solution for the dense shape validation? Appreciate it! |
0d88583
into
tensorflow:master
The custom call partitioner registration is not thread-safe. This can lead to a race condition when multiple threads try to register the same partitioner. This CL fixes the race condition by adding a mutex to protect the registration process. FUTURE_COPYBARA_INTEGRATE_REVIEW=#62750 from mattbahr:implement-sampled-addmm-v2 c295a0e PiperOrigin-RevId: 630156374
Relates to #56311