[Dynamo] Refactor get_wrapper and pickling compiled graph #438
Conversation
Simple model with one conv2d failed. - fix signature for conv* ops to correspond to torch.nn.functional - add missing padding normalization. After that, the model works.
Partial changes related to hidet-org#18
Add .vscode to .gitignore
Previously, if a performance regression run failed due to an exception, the job that stops the runner VM instances would be skipped, leaving the instances running. This change makes the stop_instances job run even when previous jobs fail. Not sure whether always() overrides the inputs.shutdown_instances flag; if it does, we can move the check into the step scope.
Module wrapper around groupnorm operator. Supports compiled app development.
Force-pushed from 9555421 to 4132253 (Format, Small typo)
Force-pushed from 4132253 to 13f83e3
Hi @destefy,
I don't like the idea of wrapping what we have returned as a `torch.fx.GraphModule`. If you want to have something that can be saved, you can use `CompiledGraph`, or make the `CompiledForwardFunction` serializable.
```python
use_cuda_graph = dynamo_config['use_cuda_graph']
if use_cuda_graph:
    try:
        runner = self.cgraph.cuda_graph()
    except CudaGraphCreationError:
        runner = self.cgraph
else:
    runner = self.cgraph
```
We should not create the CUDA graph on every invocation.
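A minimal sketch of one way to address this: cache the constructed runner so the CUDA graph is captured at most once, on first use. `FakeCompiledGraph`, the build counter, and `_get_runner` are hypothetical stand-ins for illustration, not hidet's actual API:

```python
class CudaGraphCreationError(Exception):
    pass


class FakeCompiledGraph:
    """Stand-in for hidet's CompiledGraph (illustrative only)."""

    def __init__(self, supports_cuda_graph: bool):
        self.supports_cuda_graph = supports_cuda_graph
        self.cuda_graph_builds = 0  # counts how often the CUDA graph is built

    def cuda_graph(self):
        self.cuda_graph_builds += 1
        if not self.supports_cuda_graph:
            raise CudaGraphCreationError('cannot capture this graph')
        return self  # the captured graph acts as the runner in this sketch


class CompiledForwardFunction:
    def __init__(self, cgraph, use_cuda_graph: bool):
        self.cgraph = cgraph
        self.use_cuda_graph = use_cuda_graph
        self._runner = None  # cached runner, built lazily on first call

    def _get_runner(self):
        # Build the runner once; later calls reuse the cached object
        if self._runner is None:
            if self.use_cuda_graph:
                try:
                    self._runner = self.cgraph.cuda_graph()
                except CudaGraphCreationError:
                    self._runner = self.cgraph
            else:
                self._runner = self.cgraph
        return self._runner


fn = CompiledForwardFunction(FakeCompiledGraph(supports_cuda_graph=True), use_cuda_graph=True)
fn._get_runner()
fn._get_runner()
print(fn.cgraph.cuda_graph_builds)  # the CUDA graph was built only once
```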
```python
try:
    runner = cgraph.cuda_graph()
except CudaGraphCreationError:
```

```python
class CompiledForwardFunction(torch.nn.Module):
```
I prefer the name TorchCompiledModule
The CentML compilation backend I am working on wants to wrap the `CompiledGraph`'s forward function (the one returned by `get_wrapper`) in a `torch.fx.GraphModule`. This `GraphModule` would then be pickled and sent from a server to a client. However, it isn't possible to pickle the lambda/local function returned by `get_wrapper`. Therefore, I am turning `get_wrapper` into a class `CompiledForwardFunction` whose `forward` function behaves like the wrapper returned by `get_wrapper`.

Additionally, in order to pickle `CompiledForwardFunction`, I have defined pickling and unpickling behaviour for `CompiledGraph` using `__getstate__` and `__setstate__` respectively. These just call `CompiledGraph`'s existing `save` and `load` functions.
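As a rough sketch of the `__getstate__`/`__setstate__` approach described above, where pickling delegates to existing save/load routines. `FakeCompiledGraph` and the in-memory byte buffer standing in for a file are illustrative assumptions, not hidet's real `CompiledGraph`:

```python
import io
import pickle


class FakeCompiledGraph:
    """Stand-in for hidet's CompiledGraph (illustrative only)."""

    def __init__(self, weights=None):
        self.weights = weights

    def save(self, f):
        # Stand-in for CompiledGraph.save: serialize state to a file-like object
        f.write(pickle.dumps(self.weights))

    @staticmethod
    def load(f):
        # Stand-in for CompiledGraph.load: rebuild the graph from a file-like object
        return FakeCompiledGraph(pickle.loads(f.read()))

    def __getstate__(self):
        # Pickling: capture the graph as the bytes produced by save()
        buf = io.BytesIO()
        self.save(buf)
        return buf.getvalue()

    def __setstate__(self, state):
        # Unpickling: rebuild via load() and adopt the loaded object's state
        loaded = FakeCompiledGraph.load(io.BytesIO(state))
        self.__dict__.update(loaded.__dict__)


g = FakeCompiledGraph(weights=[1.0, 2.0])
g2 = pickle.loads(pickle.dumps(g))  # round-trips through save/load
print(g2.weights)  # [1.0, 2.0]
```

With this in place, any object holding a `FakeCompiledGraph` (such as a wrapper module) becomes picklable for free, which is the point of routing `__getstate__`/`__setstate__` through the existing serialization code.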