feat: torch 2.0 #3682

Open · wants to merge 1 commit into base: main

Changed file: src/bentoml/_internal/frameworks/pytorch.py
43 changes: 33 additions & 10 deletions src/bentoml/_internal/frameworks/pytorch.py
@@ -5,18 +5,19 @@
from types import ModuleType
from typing import TYPE_CHECKING
from pathlib import Path
from functools import partial

import cloudpickle

import bentoml
from bentoml import Tag

from ..tag import Tag
from ..types import LazyType
from ..models import Model
from ..utils.pkg import get_pkg_version
from ...exceptions import NotFound
from ..models.model import ModelContext
from ..models.model import PartialKwargsModelOptions as ModelOptions
from ..models.model import PartialKwargsModelOptions
from .common.pytorch import torch
from .common.pytorch import PyTorchTensorContainer

@@ -43,9 +44,19 @@ def get(tag_like: str | Tag) -> Model:
return model


class ModelOptions(PartialKwargsModelOptions):
fullgraph: bool = False
dynamic: bool = False
backend: t.Union[str, t.Callable[..., t.Any]] = "inductor"
mode: t.Optional[str] = None
options: t.Optional[t.Dict[str, t.Union[str, int, bool]]] = None
disable: bool = False


Comment on lines +47 to +55
Member commented:

First, let’s call this PytorchOptions and only rename it to ModelOptions when it is imported into bentoml.pytorch.

Second, I think it may be better to have a single compile_kwargs dict in PytorchOptions instead of polluting its namespace. Maybe we would have only two entries (see the sketch after this list):

  • enable_compile
  • compile_kwargs

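A minimal sketch of the shape this comment proposes, assuming the same PartialKwargsModelOptions base class used in the diff above. The class name PytorchOptions and the fields enable_compile / compile_kwargs come from the review comment; they are not an existing BentoML API, and the exact field types are an assumption.

```python
import typing as t

# Same base class the diff above imports via ``from ..models.model import ...``
from bentoml._internal.models.model import PartialKwargsModelOptions


class PytorchOptions(PartialKwargsModelOptions):
    # Proposed: a single opt-in flag plus one dict forwarded to torch.compile,
    # instead of mirroring every torch.compile keyword as a separate field.
    enable_compile: bool = False
    compile_kwargs: t.Optional[t.Dict[str, t.Any]] = None
```

bentoml.pytorch could then re-export this class as ModelOptions, per the first point in the comment.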
def load_model(
bentoml_model: str | Tag | Model,
device_id: t.Optional[str] = "cpu",
**compile_kwargs: t.Any,

Member commented:

I don’t think we need to compile at the load_model level. Let’s just save the original model and run torch.compile when initializing the runner (if the user sets enable_compile=True).

) -> torch.nn.Module:
"""
Load a model from a BentoML Model with given name.
@@ -76,13 +87,15 @@ def load_model(

weight_file = bentoml_model.path_of(MODEL_FILENAME)
with Path(weight_file).open("rb") as file:
model: "torch.nn.Module" = torch.load(file, map_location=device_id)
model: torch.nn.Module = torch.load(file, map_location=device_id)
if get_pkg_version("torch") >= "2.0.0":
return t.cast("torch.nn.Module", torch.compile(model, **compile_kwargs))
return model

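For illustration, a hypothetical call to the patched load_model (the tag name is a placeholder): on torch >= 2.0 the extra keyword arguments are forwarded to torch.compile, per the diff above, while older torch versions return the uncompiled module.

```python
import bentoml

# "demo_resnet:latest" is a placeholder tag for a model previously saved
# with bentoml.pytorch.save_model.
model = bentoml.pytorch.load_model(
    "demo_resnet:latest",
    device_id="cpu",
    mode="reduce-overhead",  # forwarded to torch.compile on torch >= 2.0
    dynamic=False,           # forwarded to torch.compile as well
)
```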

def save_model(
name: str,
model: "torch.nn.Module",
model: torch.nn.Module,
*,
signatures: ModelSignaturesType | None = None,
labels: t.Dict[str, str] | None = None,
@@ -195,15 +208,25 @@ def get_runnable(bento_model: Model):
from .common.pytorch import PytorchModelRunnable
from .common.pytorch import make_pytorch_runnable_method

partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs # type: ignore
opts = t.cast(ModelOptions, bento_model.info.options)
if get_pkg_version("torch") >= "2.0.0":
_load_model = partial(
load_model,
fullgraph=opts.fullgraph,
dynamic=opts.dynamic,
backend=opts.backend,
mode=opts.mode,
options=opts.options,
disable=opts.disable,
)
else:
_load_model = load_model
Comment on lines +211 to +223
@larme (Member) commented on Mar 19, 2023:

I think we can just do model = load_model(…) here, and then call torch.compile(model, **compile_kwargs) only if enable_compile is True.


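A minimal sketch of this suggestion, not the code in this PR: keep load_model free of compilation and compile inside get_runnable only when the user opts in. enable_compile and compile_kwargs are the option names proposed in the earlier comment, not an existing BentoML API; the sketch assumes load_model returns the uncompiled module and relies on names already imported in this module (load_model, partial_class, PytorchModelRunnable, Model, torch, get_pkg_version).

```python
def _load_compiled(bento_model: Model, device_id: str = "cpu") -> torch.nn.Module:
    # Load the model exactly as saved; no compilation at save/load time.
    model = load_model(bento_model, device_id=device_id)
    opts = bento_model.info.options
    # Compile only when the user explicitly opted in and torch 2.x is installed.
    if getattr(opts, "enable_compile", False) and get_pkg_version("torch") >= "2.0.0":
        model = torch.compile(model, **(getattr(opts, "compile_kwargs", None) or {}))
    return model


runnable_class = partial_class(
    PytorchModelRunnable, bento_model=bento_model, loader=_load_compiled
)
```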
runnable_class: type[PytorchModelRunnable] = partial_class(
PytorchModelRunnable,
bento_model=bento_model,
loader=load_model,
runnable_class = partial_class(
PytorchModelRunnable, bento_model=bento_model, loader=_load_model
)
for method_name, options in bento_model.info.signatures.items():
method_partial_kwargs = partial_kwargs.get(method_name)
method_partial_kwargs = opts.partial_kwargs.get(method_name)
runnable_class.add_method(
make_pytorch_runnable_method(method_name, method_partial_kwargs),
name=method_name,