
[Dynamo] Refactor get_wrapper and pickling compiled graph #438

Closed (wanted to merge 14 commits)

Conversation

@destefy (Contributor) commented Mar 13, 2024

The CentML compilation backend I am working on wants to wrap the CompiledGraph's forward function (the one returned by get_wrapper) in a torch.fx.GraphModule. This GraphModule would then be pickled and sent from a server to a client.

However, it isn't possible to pickle the lambda/local function returned by get_wrapper. Therefore, I am turning get_wrapper into a class CompiledForwardFunction whose forward function behaves like the wrapper returned by get_wrapper.

Additionally, in order to pickle CompiledForwardFunction, I have defined pickling and unpickling behaviour for CompiledGraph using __getstate__ and __setstate__ respectively. These just call CompiledGraph's existing save and load functions.
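The pattern described above can be sketched without torch or hidet. The names CompiledGraph, CompiledForwardFunction, save, and load come from the PR description, but the bodies below are illustrative stand-ins, not hidet's actual implementation: __getstate__ serializes by delegating to the existing save helper, and __setstate__ restores through load.

```python
import os
import pickle
import tempfile

class CompiledGraph:
    """Stand-in for hidet's CompiledGraph; save/load mimic its existing file-based API."""

    def __init__(self, weights):
        self.weights = weights

    def save(self, path):
        with open(path, 'wb') as f:
            pickle.dump(self.weights, f)

    @staticmethod
    def load(path):
        with open(path, 'rb') as f:
            return CompiledGraph(pickle.load(f))

    def __getstate__(self):
        # Pickling delegates to the existing save() helper: write to a
        # temporary file, then carry the raw bytes in the pickle state.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            path = f.name
        try:
            self.save(path)
            with open(path, 'rb') as f:
                return {'blob': f.read()}
        finally:
            os.remove(path)

    def __setstate__(self, state):
        # Unpickling delegates to load(): dump the bytes back to a file,
        # reload, and adopt the reloaded object's attributes.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(state['blob'])
            path = f.name
        try:
            loaded = CompiledGraph.load(path)
            self.__dict__.update(loaded.__dict__)
        finally:
            os.remove(path)

class CompiledForwardFunction:
    """Picklable replacement for the closure returned by get_wrapper (sketch)."""

    def __init__(self, cgraph):
        self.cgraph = cgraph

    def __call__(self, *args):
        # The real forward would execute the compiled graph; this stand-in
        # just exposes the graph's state to show the object round-trips.
        return self.cgraph.weights
```

Because CompiledForwardFunction holds only a CompiledGraph (which now defines its own pickle protocol), the whole wrapper can be pickled with the standard pickle module, which is impossible for a lambda or local function.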

vadiklyutiy and others added 6 commits March 5, 2024 11:57
A simple model with one conv2d failed.
- fix signatures for conv* ops to correspond to torch.nn.functional
- add missing padding normalization

After that the model works
Previously, if a performance regression job failed due to an exception, the
job that stops the runner VM instances would be skipped, leaving the
instances running. This change makes the stop_instances job run even when
previous jobs failed. Not sure if always() will override the
inputs.shutdown_instances flag; if it does, we can move it into the step
scope.
Module wrapper around groupnorm operator. Supports compiled app
development.
@destefy force-pushed the stefan/refactor_get_wrapper branch from 9555421 to 4132253 on March 13, 2024 19:38
@destefy force-pushed the stefan/refactor_get_wrapper branch from 4132253 to 13f83e3 on March 13, 2024 19:40
@yaoyaoding (Member) left a comment

Hi @destefy,

I don't like the idea of wrapping what we return as a torch.fx.GraphModule. If you want something that can be saved, you can use CompiledGraph, or make CompiledForwardFunction serializable.

Comment on lines 134 to 141:

    use_cuda_graph = dynamo_config['use_cuda_graph']
    if use_cuda_graph:
        try:
            runner = self.cgraph.cuda_graph()
        except CudaGraphCreationError:
            runner = self.cgraph
    else:
        runner = self.cgraph

We should not create the CUDA graph on every invocation.

    try:
        runner = cgraph.cuda_graph()
    except CudaGraphCreationError:
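The reviewer's point is that CUDA graph capture should be attempted once and the resulting runner cached, rather than re-deciding on every forward call. A minimal sketch of that pattern follows; CudaGraphCreationError and cuda_graph() are names from the source, while the stub classes and method bodies are hypothetical stand-ins.

```python
class CudaGraphCreationError(Exception):
    """Raised when CUDA graph capture is not possible (name from the source)."""

class CompiledGraphStub:
    """Stand-in for the compiled graph; the real one lives in hidet."""

    def cuda_graph(self):
        # Pretend capture fails here, to exercise the fallback path.
        raise CudaGraphCreationError

    def run(self, x):
        return x * 2

class CachedRunner:
    """Chooses the runner once, at construction, instead of on every call."""

    def __init__(self, cgraph, use_cuda_graph=True):
        self.runner = cgraph
        if use_cuda_graph:
            try:
                # Attempt CUDA graph capture exactly once.
                self.runner = cgraph.cuda_graph()
            except CudaGraphCreationError:
                pass  # fall back to the plain compiled graph

    def __call__(self, x):
        # Every invocation reuses the cached runner; no capture cost per call.
        return self.runner.run(x)
```

Moving the try/except into __init__ means the (potentially expensive) capture attempt and its exception handling happen once per compiled module, not once per inference call.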
class CompiledForwardFunction(torch.nn.Module):

I prefer the name TorchCompiledModule

@destefy closed this Mar 25, 2024
6 participants