Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamo] make bytecode of resume function resemble natural bytecode #126630

Closed
wants to merge 5 commits into from

Conversation

youkaichao
Copy link
Collaborator

@youkaichao youkaichao commented May 18, 2024

Copy link

pytorch-bot bot commented May 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126630

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit b7ad5f3 with merge base a8195f2 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@youkaichao
Copy link
Collaborator Author

When we deal with natural bytecode (Python bytecode that is generated by Python compiler, from source code), we can observe that co_freevars seem to be sorted, no matter how they are declared:

def helper_outer_function():
    data_ptr = None
    last_dim_size = None
    last_two_dims_size = None
    shape = None
    stride = None
    qkv_format = None
    def dummy1():
        nonlocal data_ptr, last_dim_size, last_two_dims_size, shape, stride, qkv_format
        # Access variables in the order they were declared
        _ = data_ptr
        _ = last_dim_size
        _ = last_two_dims_size
        _ = shape
        _ = stride
        _ = qkv_format
        print(data_ptr)
        print(last_dim_size)
        print(last_two_dims_size)
        print(shape)
        print(stride)
        print(qkv_format)
    dummy1()

print(helper_outer_function.__code__.co_consts[1].co_freevars)

The output is:

('data_ptr',
 'last_dim_size',
 'last_two_dims_size',
 'qkv_format',
 'shape',
 'stride')

Note that our declaration order is data_ptr, last_dim_size, last_two_dims_size, shape, stride, qkv_format , but Python sorts them.

If Dynamo generated bytecode does not obey this rule, it means we cannot generate source code that can compile the same bytecode as Dynamo, which makes it impossible to understand Dynamo bytecode by decompiling it into source code.

@youkaichao
Copy link
Collaborator Author

Note: we require the new bytecode has exactly the same co_freevar as the old one, in order to have faster new frame construction. It is proposed by me in #115062 .

@youkaichao
Copy link
Collaborator Author

reference:

https://github.com/python/cpython/blob/caf6064a1bc15ac344afd78b780188e60b9c628e/Python/compile.c#L530-L534

/* Sort the keys so that we have a deterministic order on the indexes
   saved in the returned dictionary.  These indexes are used as indexes
   into the free and cell var storage.  Therefore if they aren't
   deterministic, then the generated bytecode is not deterministic.
*/

indexes of free and cell var storage are sorted.

@youkaichao
Copy link
Collaborator Author

TODO: how to add test for this. We need to generate a resume function with freevars.

@ezyang
Copy link
Contributor

ezyang commented May 21, 2024

Trying @williamwen42 as reviewer, but this looks pretty harmless, shout if it gets lost

@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 21, 2024
Copy link
Member

@williamwen42 williamwen42 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - just add a test

@youkaichao
Copy link
Collaborator Author

@williamwen42 thanks for the reply. Can you give me some guidance on how to add tests? In particular, I don't know how to manually instruct Dynamo to generate a resume function as I want.

@williamwen42
Copy link
Member

You can compile a function with a graph break and then search globals/locals for the resume function:

import torch

def fn(x):
    x = x + 1
    torch._dynamo.graph_break()
    x = torch.sin(x)
    return x

opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(10))

for k, v in list(globals().items()):
    if k.startswith("__resume_at"):
        print(k)
        print(v)

@youkaichao
Copy link
Collaborator Author

This kind of resume function does not have freevars:

__resume_at_16_3.__code__.co_freevars == ()

@williamwen42
Copy link
Member

You can use a different function that has freevars - compile a function with a closure?

@youkaichao
Copy link
Collaborator Author

Still no freevars:

import torch

def fn(x):
    x = x + 1
    @torch.compile(backend="eager")
    def inner(x):
        x = x + 1
        torch._dynamo.graph_break()
        x = x * 2
        return x
    y = inner(torch.sin(x))
    return y

fn(torch.randn(10))

for k, v in list(globals().items()):
    if k.startswith("__resume_at"):
        print(k)
        print(v)

@youkaichao
Copy link
Collaborator Author

When will resume function contain free vars?

@williamwen42
Copy link
Member

williamwen42 commented May 21, 2024

Looks like resume functions with freevars are not exposed to the global scope, so we'll need to add something like

            # expose code object for debugging purposes
            self.output.install_global_unsafe(name, new_code)

before the cg.make_function_with_closure(name, new_code, True, stack_len) line in def create_call_resume_at (symbolic_convert.py).

Then a function like this should work:

import torch

def create():
    cl = 1
    def fn(x):
        x = x + 1
        torch._dynamo.graph_break()
        x = x + cl
        return x
    return fn

fn = create()
opt_fn = torch.compile(fn, backend="eager")
print(opt_fn(torch.randn(10)))

breakpoint()
for k, v in list(globals().items()):
    if k.startswith("__resume_at"):
        print(k)
        print(v)
        print(v.co_freevars)
        print(v.co_cellvars)

@youkaichao
Copy link
Collaborator Author

@williamwen42 thanks for the guidance! do you know if the test failures in the commit are related with the changes in this PR?

@williamwen42
Copy link
Member

They don't look related, but we'll see what CI shows.

@youkaichao
Copy link
Collaborator Author

@williamwen42 can you take a look at whether ci test failures are related?

@williamwen42
Copy link
Member

They don't look related - we're having some issues with CI atm.

@youkaichao
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 22, 2024
@youkaichao youkaichao added the topic: not user facing topic category label May 22, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@youkaichao
Copy link
Collaborator Author

@williamwen42 then how can we merge this? do we need to wait until the ci team fixes the issue?

@huydhn
Copy link
Contributor

huydhn commented May 22, 2024

@pytorchbot merge -r

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased youkaichao-patch-1 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout youkaichao-patch-1 && git pull --rebase)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@youkaichao youkaichao deleted the youkaichao-patch-1 branch May 23, 2024 05:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo oncall: pt2 open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants