
feat: support aten._cdist_forward converter #2726

Merged
merged 5 commits into main from aten_cdist_forward_converter on May 18, 2024

Conversation

chohk88
Collaborator

@chohk88 chohk88 commented Apr 4, 2024

Description

New feature to support the aten._cdist_forward converter. I have added test cases ensuring compatibility with both matching and broadcasting input shapes.

Fixes # (issue)

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@chohk88 chohk88 self-assigned this Apr 4, 2024
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 4, 2024
@chohk88 chohk88 linked an issue Apr 4, 2024 that may be closed by this pull request
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
py/torch_tensorrt/dynamo/conversion/impl/linear.py
x1: TRTTensor,
x2: TRTTensor,
p: float,
compute_mode: Optional[int] = None,
Collaborator:

It seems that compute_mode is not used below. Is there any difficulty in supporting it? Ideally, we should mimic all of PyTorch's behaviors (i.e., receive the same args and return the same outputs as PyTorch, based on the schema).
If compute_mode cannot be supported for some reason, we need to use a capability_validator to make sure such cases are not converted and instead fall back to PyTorch.

Collaborator Author:

compute_mode does not influence the output values. It is designed to optimize computational efficiency when calculating the Euclidean distance (p=2) for large inputs.

I have made modifications to clarify:

  1. A note has been added that compute_mode does not alter the computational path in our implementation.
  2. A default value of 0 is used for compute_mode when it is unspecified, maintaining consistency with PyTorch's behavior.
  3. A warning is included for situations where an optimized path for p=2 might be expected but is not used, due to our unified element-wise approach.
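The default-and-warn handling described above might be sketched roughly like this (the function and parameter names are illustrative, not the actual converter code):

```python
import logging

logger = logging.getLogger(__name__)

def normalize_compute_mode(compute_mode, p):
    # Default to 0 when compute_mode is unspecified, matching PyTorch's
    # behavior for aten._cdist_forward.
    if compute_mode is None:
        compute_mode = 0
    # Warn when a caller explicitly requests the matmul-optimized path
    # (mode 1) for p=2 but the implementation computes distances
    # element-wise regardless of the requested mode.
    if p == 2 and compute_mode == 1:
        logger.warning(
            "compute_mode=1 requests the matmul path for p=2, "
            "but this implementation uses a single element-wise path."
        )
    return compute_mode
```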

Collaborator:

Thanks for the explanation! I'm still wondering whether using matmul would be faster for the previous two scenarios. Considering that we have an out-of-the-box matmul, could you compare their time costs if possible?

Collaborator Author:

Thank you for your valuable comment. In this update, I've implemented a matrix-multiplication-based approach to compute the distance for p=2. In tests on my local PC with inputs x1=(150,100,50,50) and x2=(150,100,30,50), the TRT run time improved by approximately 6.3x. Here are the relevant logs from those tests:

INFO:harness:FX graph= graph():
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %_cdist_forward_default : [num_users=1] = call_function[target=torch.ops.aten._cdist_forward.default](args = (%x1, %x2, 2, 0), kwargs = {})
    return _cdist_forward_default
===============based on matmul : ==================
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.007999
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:33.710717
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 244801024 bytes of Memory
INFO:harness:Interpreter run time(s): 33.719337898997765
INFO:harness:TRT run time(s)= 0.0026535038948059084
.INFO:harness:FX graph= graph():
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %_cdist_forward_default : [num_users=1] = call_function[target=torch.ops.aten._cdist_forward.default](args = (%x1, %x2, 2, 1), kwargs = {})
    return _cdist_forward_default
===============based on elementwise pow for diff : ==================
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.003120
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:43.902498
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 4590000640 bytes of Memory
INFO:harness:Interpreter run time(s): 43.90606521600421
INFO:harness:TRT run time(s)= 0.016830976486206056
.
----------------------------------------------------------------------
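The matmul-based p=2 path discussed above relies on the identity ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, which replaces the large broadcasted difference tensor with one batched matrix multiply. A NumPy sketch of that expansion (illustrative only, not the converter code):

```python
import numpy as np

def cdist_p2_matmul(x1, x2):
    # x1: (..., M, D), x2: (..., N, D) -> pairwise distances (..., M, N)
    x1_sq = np.sum(x1 * x1, axis=-1)[..., :, None]      # (..., M, 1)
    x2_sq = np.sum(x2 * x2, axis=-1)[..., None, :]      # (..., 1, N)
    cross = x1 @ np.swapaxes(x2, -1, -2)                # (..., M, N)
    # Clamp tiny negatives caused by floating-point rounding before sqrt.
    sq = np.maximum(x1_sq - 2.0 * cross + x2_sq, 0.0)
    return np.sqrt(sq)
```

Compared with the element-wise formulation, this never materializes the (..., M, N, D) difference tensor, which is consistent with the large memory gap seen in the logs above.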

@zewenli98 (Collaborator) commented Apr 29, 2024:

It seems matmul runs faster in both the compilation and running stages, and saves a lot of memory.

Collaborator Author:

You're right. While always using matrix multiplication might seem more efficient given TensorRT's capabilities, I've aligned the implementation with the compute_mode semantics of the ATen operation for consistency. Thank you for pointing this out!

tests/py/dynamo/conversion/test_cdist_aten.py
print("x1 : ", x1)
print("x2 : ", x2)

return torch.ops.aten._cdist_forward.default(x1, x2, p, None)
Collaborator:

compute_mode values other than None should also be tested.

Collaborator Author:

I have made modifications so that different compute_mode values can be tested via the test-case inputs.
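A sketch of how such a parameter grid might be built (the names are illustrative; the PR's actual test file defines its own parameterized cases):

```python
import itertools

# Hypothetical grid covering several p values and compute_mode choices,
# including None, which PyTorch treats as the default mode.
p_values = [0.0, 1.0, 2.0, 3.0]
compute_modes = [None, 0, 1, 2]

cases = [
    (f"p_{p}_mode_{mode}", p, mode)
    for p, mode in itertools.product(p_values, compute_modes)
]
# Each (name, p, compute_mode) tuple would drive one converter test.
```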

tests/py/dynamo/conversion/test_cdist_aten.py
@gs-olive (Collaborator) left a comment:

Overall looks good - added a few suggestions

x1=args[0],
x2=args[1],
p=args[2],
compute_mode=args[3],
Collaborator:

Based on the schema having int? as the type for compute_mode, this should be compute_mode = args_bounds_check(args, 2, None)

Collaborator Author:

Thank you for the review! I've made the requested changes (compute_mode = args_bounds_check(args, 3, None))
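The behavior being requested — reading an optional trailing schema argument with a default — can be illustrated with a minimal re-implementation of the helper (a sketch of its observable behavior, not the torch_tensorrt source):

```python
def args_bounds_check(args, i, replacement=None):
    # Return args[i] if the caller supplied it, otherwise the replacement
    # default — this is how optional schema arguments such as
    # compute_mode (typed int? in the schema) are read safely.
    return args[i] if len(args) > i else replacement

# compute_mode is the 4th positional arg (index 3) of aten._cdist_forward;
# when the caller omits it, the default None is returned.
args = ("x1", "x2", 2.0)
compute_mode = args_bounds_check(args, 3, None)
```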

x1: TRTTensor,
x2: TRTTensor,
p: float,
compute_mode: int,
Collaborator:

Switch to Optional[int], since None is a valid choice

Collaborator Author:

Thank you for your comment! I've made the requested changes.

Comment on lines 627 to 628
else:
raise NotImplementedError(f"Currently, p={p} is not implemented.")
Collaborator:

This can be removed, since it seems you have covered all of the valid intervals by Torch's definition of cdist, in $[0,\infty]$, as per this

Collaborator Author:

Thank you for the review! I've made the requested changes.
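The point above — that Torch's cdist is defined for every p in [0, ∞], so no unreachable else branch is needed — can be sketched with a small reference implementation over pairwise difference vectors (illustrative NumPy, not the converter code):

```python
import numpy as np

def minkowski_dist(diff, p):
    # diff: (..., D) pairwise difference vectors; reduces the last axis.
    # Covers the full interval [0, inf] that torch.cdist accepts.
    if p == 0:
        # p=0: count of coordinates that differ (Hamming-like).
        return np.sum(diff != 0, axis=-1).astype(np.float64)
    if np.isinf(p):
        # p=inf: Chebyshev distance (max absolute coordinate difference).
        return np.max(np.abs(diff), axis=-1)
    # General Minkowski distance for 0 < p < inf.
    return np.sum(np.abs(diff) ** p, axis=-1) ** (1.0 / p)
```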

@github-actions bot left a comment:

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py	2024-04-29 14:43:08.889448+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py	2024-04-29 14:44:58.684466+00:00
@@ -549,11 +549,11 @@
    - If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20].
      This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2.
    - For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features),
      since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2.

-    Note: The `compute_mode` parameter is designed to optimize the performance of the Euclidean distance calculation, especially useful when working with large datasets. 
+    Note: The `compute_mode` parameter is designed to optimize the performance of the Euclidean distance calculation, especially useful when working with large datasets.
    This parameter allows you to control how the distances are computed, with different modes available to leverage matrix multiplication for speed improvements.
    """
    if compute_mode is None:
        compute_mode = 0

@zewenli98 (Collaborator):

@chohk88 The CI failed. Can you take a look?

@zewenli98 (Collaborator):

@chohk88 PR #2819 was merged to the main. Can you rebase and check if the error in CI is fixed?

@chohk88 chohk88 force-pushed the aten_cdist_forward_converter branch from 1962cf8 to 9f901f6 Compare May 9, 2024 04:33
@zewenli98 (Collaborator) left a comment:

@chohk88 It seems the tests use too large a dataset, which caused an OOM error in CI. My advice is to scale them down while making sure every condition path is still covered.

@chohk88 (Collaborator Author) commented May 11, 2024:

@zewenli98 Thank you for your comments! The CI/CD failures due to the dataset size have been resolved. However, while the Windows wheel build completed successfully, the Linux wheel build is still failing.

@zewenli98 (Collaborator):

> @zewenli98 Thank you for your comments! The issue with CI/CD failures due to the dataset size has been resolved. However, while the Windows wheel build has completed successfully, the Linux wheel build is still failing.

It seems to work now. LGTM.

@zewenli98 zewenli98 merged commit 6fcccf1 into main May 18, 2024
50 of 51 checks passed
@zewenli98 zewenli98 deleted the aten_cdist_forward_converter branch May 18, 2024 01:31
Successfully merging this pull request may close these issues.

aten._cdist_forward