Reduce time complexity of node replacement in PyTorch frontend #1

Open · wants to merge 4 commits into base: main
Conversation

@marty1885 commented Mar 26, 2024

Hi,

This is part of the joint work with @JushBJJ and @JonathanALevine to get LLMs running in BUDA (and the bounties). This PR reduces the time complexity of the node replacement step, which makes the compile time for large models tolerable and is pretty much required when compiling really large models like RWKVv4, which has 600K nodes.

Ref. tenstorrent/tt-buda-demos#22

The old code loops through all nodes for each node it wants to replace, which makes it an O(N^2) algorithm. The new code uses a hash map and reduces the complexity to O(N).
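For illustration only, here is a minimal sketch of the idea; the graph model and helper names below (ops carrying an inputs list, replace_nodes_quadratic, replace_nodes_linear) are hypothetical and not the actual frontend code:

from collections import defaultdict

def replace_nodes_quadratic(ops, replacements):
    # Old approach: scan every op for every node being replaced -> O(N^2).
    for orig_node, new_node in replacements.items():
        for op in ops:
            op.inputs = [new_node if inp is orig_node else inp for inp in op.inputs]

def replace_nodes_linear(ops, replacements):
    # New approach: build the node -> consuming-ops map once, then only touch
    # the ops that actually reference each replaced node.
    node_inputs_map = defaultdict(list)
    for op in ops:
        for inp in op.inputs:
            node_inputs_map[inp].append(op)
    for orig_node, new_node in replacements.items():
        for op in node_inputs_map.get(orig_node, ()):
            op.inputs = [new_node if inp is orig_node else inp for inp in op.inputs]

The actual change additionally keeps each consumer list ordered by node index and binary-searches into it, as the reviewed hunks below show.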

@marty1885 changed the title from "Teduce time complexity of node replacment in PyTorch frontend" to "Reduce time complexity of node replacement in PyTorch frontend" on Mar 26, 2024
@@ -5871,7 +5871,19 @@ def export_c_graph(location, graph):
    fname = os.path.join(location, f"tvm_exported_c_graph_{time_stamp}.txt")
    with open(f"{fname}", "w") as f:
        f.write(str(graph))


def _binray_search(lst, func):


Correct spelling to _binary_search(lst, func):

Author: Thanks for finding it. Fixed.

python/tvm/relay/frontend/pytorch.py

relevant_ops = node_inputs_map[orig_node]
begin_idx = _binary_search(relevant_ops, lambda x: x[0] > node_idx)
for idx, node in relevant_ops[begin_idx:]:
Contributor: No need for idx, can you replace it with _?
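Concretely, the suggestion amounts to something like the following; this is a sketch based only on the snippet above, with the rest of the surrounding function omitted:

relevant_ops = node_inputs_map[orig_node]
# Skip to the first (index, op) entry recorded after node_idx.
begin_idx = _binary_search(relevant_ops, lambda x: x[0] > node_idx)
for _, node in relevant_ops[begin_idx:]:  # the stored index itself is unused here
    ...  # replace orig_node in node's inputs (body omitted)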

@@ -5871,7 +5871,19 @@ def export_c_graph(location, graph):
    fname = os.path.join(location, f"tvm_exported_c_graph_{time_stamp}.txt")
    with open(f"{fname}", "w") as f:
        f.write(str(graph))


def _binary_search(lst, func):
Contributor: Instead of calling the arg func, let's use condition for easier understanding.
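For reference, a predicate-based lower-bound search is the conventional shape for such a helper; the sketch below is an assumption about what it looks like (using the suggested condition name), not the exact code from the PR:

def _binary_search(lst, condition):
    # Return the index of the first element of `lst` for which `condition` is
    # true, assuming `condition` is False for a prefix of `lst` and True for
    # the rest. Returns len(lst) when no element satisfies it.
    lo, hi = 0, len(lst)
    while lo < hi:
        mid = (lo + hi) // 2
        if condition(lst[mid]):
            hi = mid
        else:
            lo = mid + 1
    return lo

Under that reading, _binary_search(relevant_ops, lambda x: x[0] > node_idx) returns the position of the first consumer recorded after node_idx, so the replacement loop skips everything before it.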

@marty1885 (Author)

@nvukobratTT Poke on this. I've applied the requested changes.

@nvukobratTT (Contributor)

> @nvukobratTT Poke on this. I've applied the requested changes.

Hey @marty1885, sorry for the delayed response. We have to do a few internal checks before final approval and merge. As soon as I get those I'll approve this MR :)

Until then, feel free to continue with model bringup with these changes cherry-picked. I hope this isn't causing too many issues on your end.

@marty1885 (Author)

@nvukobratTT Poke again. Is there anything I can do to push this PR forward? I feel this patch would be helpful to every user of BUDA and want it to be included in future releases. I understand your schedule is busy, but more than a month to merge a widely beneficial patch seems a bit long.
