
support methods for torch.compiler.allow_in_graph #125244

Open
kilianyp opened this issue Apr 30, 2024 · 2 comments
Labels
module: dynamo oncall: pt2 triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

kilianyp commented Apr 30, 2024

🚀 The feature, motivation and pitch

This is a bug/feature report.

When torch._dynamo.allow_in_graph is applied to a method, it is not respected during compilation.

import torch

def custom_backend(gm, _):
    gm.graph.print_tabular()
    return gm.forward


def test_allow_in_graph_fn():
    torch._dynamo.reset()

    @torch._dynamo.allow_in_graph
    def add_fn(a, b):
        return a + b

    class Model(torch.nn.Module):
        def forward(self, x):
            return add_fn(x, x)

    model = Model()
    model = torch.compile(model, backend=custom_backend)
    x = torch.rand(10)
    model(x)

This correctly creates a node for add_fn:

placeholder    l_x_    L_x_                                                                 ()            {}
call_function  add_fn  <function test_allow_in_graph_fn.<locals>.add_fn at 0x7f5c5646dea0>  (l_x_, l_x_)  {}
output         output  output                                                               ((add_fn,),)  {}

def test_allow_in_graph_class_fn():
    torch._dynamo.reset()

    class Foo:
        @torch._dynamo.allow_in_graph
        def add_fn(self, a, b):
            return a + b

    class Model(torch.nn.Module):
        def forward(self, x):
            foo = Foo()
            print(id(foo.add_fn))
            return foo.add_fn(x, x)

    print(id(Foo.add_fn))
    model = Model()
    model = torch.compile(model, backend=custom_backend)
    x = torch.rand(10)
    model(x)

This instead shows a node for the built-in add:

placeholder    l_a_    L_a_                     ()            {}
call_function  add     <built-in function add>  (l_a_, l_a_)  {}
output         output  output                   ((add,),)     {}

This is most likely because allow_in_graph relies on the function's id; in Python, however, accessing a method on an instance creates a new bound-method object, so the id no longer matches after the instance is created.
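
For illustration (my own snippet, not from the issue), this is the Python behaviour in question: each attribute access builds a fresh bound method, so an id check against it fails even though the underlying function is the one that was decorated:

class Foo:
    def add_fn(self, a, b):
        return a + b

f = Foo()
print(Foo.add_fn is f.add_fn)           # False: attribute access creates a new bound method
print(id(Foo.add_fn) == id(f.add_fn))   # False: the ids differ while both objects are alive
print(f.add_fn.__func__ is Foo.add_fn)  # True: the underlying function object is shared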

The ideal solution would be for this to actually be supported; alternatively, I think some warning/error when allow_in_graph is used on a method could make sense.

Alternatives

torch.compiler.disable is respected, but it creates graph breaks, which have been causing downstream issues for our custom backend, so the allow_in_graph behaviour is preferred.
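
As a workaround sketch (my own, not from the issue; it relies only on the plain-function case shown above working): keep the decorated implementation at module level, where its id is stable, and delegate to it from the method:

import torch

@torch._dynamo.allow_in_graph
def _add_impl(a, b):  # module-level helper; the name is illustrative
    return a + b

class Foo:
    def add_fn(self, a, b):
        # The method is traced and inlined as usual; only _add_impl,
        # whose id was registered, shows up as an opaque graph node.
        return _add_impl(a, b)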

Additional context

No response

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

yanboliang (Contributor) commented

This is because the allow_in_graph mechanism is based on a function id match; you are using a different function object in the second case, so it won't be allowed in the graph correctly.

yanboliang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Apr 30, 2024
kilianyp (Author) commented May 2, 2024

> This is because the allow_in_graph mechanism is based on a function id match; you are using a different function object in the second case, so it won't be allowed in the graph correctly.

Yes that's how I understood it as well.
The feature request was to either warn or raise an error if the decorator is used on a method (for example detected via inspect), or, even better, to support it by using a different mechanism than relying on the id of the function. I'm not sure which mechanisms exist there, though.
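
As a sketch of the warning idea (a hypothetical wrapper, not a PyTorch API): at class-definition time the decorator sees a plain function, so inspect.ismethod is False there; the qualified name is one available heuristic:

import warnings
import torch

def checked_allow_in_graph(fn):
    # Hypothetical: warn when fn looks like it was defined inside a class
    # body, where the id-based match will not fire for bound methods.
    parts = getattr(fn, "__qualname__", "").split(".")
    if len(parts) > 1 and parts[-2] != "<locals>":
        warnings.warn(
            f"allow_in_graph on {fn.__qualname__!r} looks like a method; "
            "dynamo matches by function id and will not recognize the "
            "bound method created at attribute access."
        )
    return torch._dynamo.allow_in_graph(fn)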

To give some more context: our custom backend is supposed to ignore some specific functions defined by the user.
Currently, we are using torch.compiler.disable and torch.compiler.allow_in_graph to exclude them from the backend.

Actually, it would be nice if we could differentiate between the decorators that are used to make the model dynamo-traceable and the ones that just pass information to the compiler backend.

Example:

def custom_backend(gm, values):
    gm.graph.print_tabular()
    return gm.forward

def test_decorators():

    class Model(torch.nn.Module):
        def own_fn(self, x):
            return x + x
        def forward(self, x):
            return self.own_fn(x)

    model = Model()
    x = torch.rand(10)
    model = torch.compile(model, backend=custom_backend)
    model(x)

We want to be able to indicate that own_fn should be ignored by the custom backend.
allow_in_graph would be one option, but it doesn't work because own_fn is a method.

torch.compiler.disable works, but for us it led to some side effects, and it feels like a misuse of that decorator: the code itself traces fine under dynamo, it should just be ignored by the backend.

So a decorator/context manager that still traces the code but sets some information in the node data would be useful:

class Model(torch.nn.Module):
    @dynamo.mark("custom_attribute", True)
    def own_fn(self, x):
        return x + x

    def forward(self, x):
        return self.own_fn(x)

This would add an additional entry custom_attribute with the value True to node.meta.
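
A backend could then consume that metadata roughly like this (a sketch assuming the hypothetical dynamo.mark above populated node.meta; node.meta itself is an existing torch.fx dict):

def custom_backend(gm, example_inputs):
    marked = [
        node for node in gm.graph.nodes
        # "custom_attribute" is the hypothetical key set by dynamo.mark
        if node.meta.get("custom_attribute")
    ]
    print("nodes the backend should ignore:", marked)
    return gm.forward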

Alternatively, one could pass the function name to the custom backend, but then one would have to recover it from stack_trace, which is in string format, so that would be a bit more error-prone IMO.
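
For comparison, a sketch of the stack-trace route being argued against (node.stack_trace is an existing torch.fx attribute, a plain string when recorded):

def custom_backend(gm, example_inputs):
    for node in gm.graph.nodes:
        # Substring matching on a string trace is fragile: it breaks on
        # renames and matches any frame that happens to contain "own_fn".
        if "own_fn" in (node.stack_trace or ""):
            print("node originating from own_fn:", node)
    return gm.forward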
