New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: torch 2.0 #3682
base: main
Are you sure you want to change the base?
feat: torch 2.0 #3682
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,18 +5,19 @@ | |
from types import ModuleType | ||
from typing import TYPE_CHECKING | ||
from pathlib import Path | ||
from functools import partial | ||
|
||
import cloudpickle | ||
|
||
import bentoml | ||
from bentoml import Tag | ||
|
||
from ..tag import Tag | ||
from ..types import LazyType | ||
from ..models import Model | ||
from ..utils.pkg import get_pkg_version | ||
from ...exceptions import NotFound | ||
from ..models.model import ModelContext | ||
from ..models.model import PartialKwargsModelOptions as ModelOptions | ||
from ..models.model import PartialKwargsModelOptions | ||
from .common.pytorch import torch | ||
from .common.pytorch import PyTorchTensorContainer | ||
|
||
|
@@ -43,9 +44,19 @@ def get(tag_like: str | Tag) -> Model: | |
return model | ||
|
||
|
||
class ModelOptions(PartialKwargsModelOptions): | ||
fullgraph: bool = False | ||
dynamic: bool = False | ||
backend: t.Union[str, t.Callable[..., t.Any]] = "inductor" | ||
mode: t.Optional[str] = None | ||
options: t.Optional[t.Dict[str, t.Union[str, int, bool]]] = None | ||
disable: bool = False | ||
|
||
|
||
def load_model( | ||
bentoml_model: str | Tag | Model, | ||
device_id: t.Optional[str] = "cpu", | ||
**compile_kwargs: t.Any, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don’t think we need to do compile at |
||
) -> torch.nn.Module: | ||
""" | ||
Load a model from a BentoML Model with given name. | ||
|
@@ -76,13 +87,15 @@ def load_model( | |
|
||
weight_file = bentoml_model.path_of(MODEL_FILENAME) | ||
with Path(weight_file).open("rb") as file: | ||
model: "torch.nn.Module" = torch.load(file, map_location=device_id) | ||
model: torch.nn.Module = torch.load(file, map_location=device_id) | ||
if get_pkg_version("torch") >= "2.0.0": | ||
return t.cast("torch.nn.Module", torch.compile(model, **compile_kwargs)) | ||
return model | ||
|
||
|
||
def save_model( | ||
name: str, | ||
model: "torch.nn.Module", | ||
model: torch.nn.Module, | ||
*, | ||
signatures: ModelSignaturesType | None = None, | ||
labels: t.Dict[str, str] | None = None, | ||
|
@@ -195,15 +208,25 @@ def get_runnable(bento_model: Model): | |
from .common.pytorch import PytorchModelRunnable | ||
from .common.pytorch import make_pytorch_runnable_method | ||
|
||
partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs # type: ignore | ||
opts = t.cast(ModelOptions, bento_model.info.options) | ||
if get_pkg_version("torch") >= "2.0.0": | ||
_load_model = partial( | ||
load_model, | ||
fullgraph=opts.fullgraph, | ||
dynamic=opts.dynamic, | ||
backend=opts.backend, | ||
mode=opts.mode, | ||
options=opts.options, | ||
disable=opts.disable, | ||
) | ||
else: | ||
_load_model = load_model | ||
Comment on lines
+211
to
+223
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can just |
||
|
||
runnable_class: type[PytorchModelRunnable] = partial_class( | ||
PytorchModelRunnable, | ||
bento_model=bento_model, | ||
loader=load_model, | ||
runnable_class = partial_class( | ||
PytorchModelRunnable, bento_model=bento_model, loader=_load_model | ||
) | ||
for method_name, options in bento_model.info.signatures.items(): | ||
method_partial_kwargs = partial_kwargs.get(method_name) | ||
method_partial_kwargs = opts.partial_kwargs.get(method_name) | ||
runnable_class.add_method( | ||
make_pytorch_runnable_method(method_name, method_partial_kwargs), | ||
name=method_name, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First, let’s call this
PytorchOptions
and only when it is imported into bentoml.pytorch
do we rename it to ModelOptions
. Second, I think maybe it’s better to have a single
compile_kwargs
dict in PytorchOptions
instead of polluting the namespace of PytorchOptions
. Maybe we will have only 2 entries: enable_compile
and compile_kwargs