From 6c2ac389799b6f18ec529904139a09af5d0349f6 Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Tue, 20 Feb 2024 15:31:14 +0800 Subject: [PATCH] fix: bug: Dataframes not serializing correctly in the new API (#4491) Fixes #4489 Signed-off-by: Frost Ming --- src/_bentoml_impl/client/http.py | 4 +- src/_bentoml_impl/serde.py | 89 ++++++++++++++++++++++++++++---- src/_bentoml_sdk/io_models.py | 24 +++++---- src/_bentoml_sdk/validators.py | 19 +++---- src/bentoml/__init__.py | 13 ++++- 5 files changed, 111 insertions(+), 38 deletions(-) diff --git a/src/_bentoml_impl/client/http.py b/src/_bentoml_impl/client/http.py index e265b97f028..37e4b3f6723 100644 --- a/src/_bentoml_impl/client/http.py +++ b/src/_bentoml_impl/client/http.py @@ -227,7 +227,7 @@ def _build_request( return self.client.build_request( "POST", endpoint.route, - content=self.serde.serialize(kwargs), + content=self.serde.serialize(kwargs, endpoint.input), headers=headers, ) @@ -307,7 +307,7 @@ def _deserialize_output(self, data: bytes, endpoint: ClientEndpoint) -> t.Any: elif ot == "bytes": return data else: - return self.serde.deserialize(data) + return self.serde.deserialize(data, endpoint.output) def call(self, __name: str, /, *args: t.Any, **kwargs: t.Any) -> t.Any: try: diff --git a/src/_bentoml_impl/serde.py b/src/_bentoml_impl/serde.py index 95054dd3ef6..2d0938a4c00 100644 --- a/src/_bentoml_impl/serde.py +++ b/src/_bentoml_impl/serde.py @@ -9,12 +9,15 @@ from urllib.parse import unquote from urllib.parse import urlparse +from pydantic import BaseModel from starlette.datastructures import Headers from starlette.datastructures import UploadFile from typing_extensions import get_args from _bentoml_sdk.typing_utils import is_list_type from _bentoml_sdk.typing_utils import is_union_type +from _bentoml_sdk.validators import DataframeSchema +from _bentoml_sdk.validators import TensorSchema if t.TYPE_CHECKING: from starlette.requests import Request @@ -36,11 +39,11 @@ def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T: ... @abc.abstractmethod - def serialize(self, obj: t.Any) -> bytes: + def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes: ... @abc.abstractmethod - def deserialize(self, obj_bytes: bytes) -> t.Any: + def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any: ... async def parse_request(self, request: Request, cls: type[T]) -> T: @@ -49,7 +52,73 @@ async def parse_request(self, request: Request, cls: type[T]) -> T: return self.deserialize_model(json_str, cls) -class JSONSerde(Serde): +class GenericSerde: + def _encode(self, obj: t.Any, schema: dict[str, t.Any]) -> t.Any: + if schema.get("type") == "tensor": + child_schema = TensorSchema( + format=schema.get("format", ""), + dtype=schema.get("dtype"), + shape=schema.get("shape"), + ) + return child_schema.encode(child_schema.validate(obj)) + if schema.get("type") == "dataframe": + child_schema = DataframeSchema( + orient=schema.get("orient", "records"), columns=schema.get("columns") + ) + return child_schema.encode(child_schema.validate(obj)) + if schema.get("type") == "array" and "items" in schema: + return [self._encode(v, schema["items"]) for v in obj] + if schema.get("type") == "object" and schema.get("properties"): + if isinstance(obj, BaseModel): + return obj.model_dump() + return { + k: self._encode(obj[k], child) + for k, child in schema["properties"].items() + if k in obj + } + return obj + + def _decode(self, obj: t.Any, schema: dict[str, t.Any]) -> t.Any: + if schema.get("type") == "tensor": + child_schema = TensorSchema( + format=schema.get("format", ""), + dtype=schema.get("dtype"), + shape=schema.get("shape"), + ) + return child_schema.validate(obj) + if schema.get("type") == "dataframe": + child_schema = DataframeSchema( + orient=schema.get("orient", "records"), columns=schema.get("columns") + ) + return child_schema.validate(obj) + if schema.get("type") == "array" and "items" in schema: + return [self._decode(v, schema["items"]) for v in obj] + if ( + schema.get("type") == "object" + and schema.get("properties") + and isinstance(obj, t.Mapping) + ): + return { + k: self._decode(obj[k], child) + for k, child in schema["properties"].items() + if k in obj + } + return obj + + def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes: + return self.serialize_value(self._encode(obj, schema)) + + def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any: + return self._decode(self.deserialize_value(obj_bytes), schema) + + def serialize_value(self, obj: t.Any) -> bytes: + raise NotImplementedError + + def deserialize_value(self, obj_bytes: bytes) -> t.Any: + raise NotImplementedError + + +class JSONSerde(GenericSerde, Serde): media_type = "application/json" def serialize_model(self, model: IODescriptor) -> bytes: @@ -60,10 +129,10 @@ def serialize_model(self, model: IODescriptor) -> bytes: def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T: return cls.model_validate_json(model_bytes) - def serialize(self, obj: t.Any) -> bytes: + def serialize_value(self, obj: t.Any) -> bytes: return json.dumps(obj).encode("utf-8") - def deserialize(self, obj_bytes: bytes) -> t.Any: + def deserialize_value(self, obj_bytes: bytes) -> t.Any: return json.loads(obj_bytes) @@ -109,7 +178,7 @@ async def parse_request(self, request: Request, cls: type[T]) -> T: return cls.model_validate(data) -class PickleSerde(Serde): +class PickleSerde(GenericSerde, Serde): media_type = "application/vnd.bentoml+pickle" def serialize_model(self, model: IODescriptor) -> bytes: @@ -122,10 +191,10 @@ def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T: obj = cls.model_validate(obj) return obj - def serialize(self, obj: t.Any) -> bytes: + def serialize_value(self, obj: t.Any) -> bytes: return pickle.dumps(obj) - def deserialize(self, obj_bytes: bytes) -> t.Any: + def deserialize_value(self, obj_bytes: bytes) -> t.Any: return pickle.loads(obj_bytes) @@ -145,12 +214,12 @@ def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T: buffer = io.BytesIO(model_bytes) return deserialize_from_arrow(cls, buffer) - def serialize(self, obj: t.Any) -> bytes: + def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes: raise NotImplementedError( "Serializing arbitrary object to Arrow is not supported" ) - def deserialize(self, obj_bytes: bytes) -> t.Any: + def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any: raise NotImplementedError( "Deserializing arbitrary object from Arrow is not supported" ) diff --git a/src/_bentoml_sdk/io_models.py b/src/_bentoml_sdk/io_models.py index 71a34999fc0..49f71416cce 100644 --- a/src/_bentoml_sdk/io_models.py +++ b/src/_bentoml_sdk/io_models.py @@ -143,15 +143,6 @@ async def to_http_response(cls, obj: t.Any, serde: Serde) -> Response: structured_media_type = cls.media_type or serde.media_type - if not issubclass(cls, RootModel): - if cls.multipart_fields: - return Response( - "Multipart response is not supported yet", status_code=500 - ) - return Response( - content=serde.serialize_model(t.cast(IODescriptor, obj)), - media_type=structured_media_type, - ) if inspect.isasyncgen(obj): async def async_stream() -> t.AsyncGenerator[str | bytes, None]: @@ -159,7 +150,8 @@ async def async_stream() -> t.AsyncGenerator[str | bytes, None]: if isinstance(item, (str, bytes)): yield item else: - yield serde.serialize_model(t.cast(IODescriptor, cls(item))) + obj_item = cls(item) if issubclass(cls, RootModel) else item + yield serde.serialize_model(t.cast(IODescriptor, obj_item)) return StreamingResponse(async_stream(), media_type=cls.mime_type()) @@ -170,9 +162,19 @@ def content_stream() -> t.Generator[str | bytes, None, None]: if isinstance(item, (str, bytes)): yield item else: - yield serde.serialize_model(t.cast(IODescriptor, cls(item))) + obj_item = cls(item) if issubclass(cls, RootModel) else item + yield serde.serialize_model(t.cast(IODescriptor, obj_item)) return StreamingResponse(content_stream(), media_type=cls.mime_type()) + elif not issubclass(cls, RootModel): + if cls.multipart_fields: + return Response( + "Multipart response is not supported yet", status_code=500 + ) + return Response( + content=serde.serialize_model(t.cast(IODescriptor, obj)), + media_type=structured_media_type, + ) else: if is_file_type(type(obj)) and isinstance(serde, JSONSerde): if isinstance(obj, pathlib.PurePath): diff --git a/src/_bentoml_sdk/validators.py b/src/_bentoml_sdk/validators.py index 4cacb3ab810..54574716f24 100644 --- a/src/_bentoml_sdk/validators.py +++ b/src/_bentoml_sdk/validators.py @@ -207,12 +207,12 @@ def __get_pydantic_core_schema__( self, source_type: t.Any, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: return core_schema.no_info_after_validator_function( - self._validate, + self.validate, core_schema.any_schema(), serialization=core_schema.plain_serializer_function_ser_schema(self.encode), ) - def encode(self, arr: TensorType) -> bytes: + def encode(self, arr: TensorType) -> list[t.Any]: if self.format == "numpy-array": numpy_array = arr elif self.format == "tf-tensor": @@ -235,7 +235,7 @@ def framework_dtype(self) -> t.Any: else: return getattr(torch, dtype) - def _validate(self, obj: t.Any) -> t.Any: + def validate(self, obj: t.Any) -> t.Any: arr: t.Any if self.format == "numpy-array": arr = np.array(obj, dtype=self.framework_dtype) @@ -295,15 +295,8 @@ def __get_pydantic_core_schema__( self, source_type: t.Any, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: return core_schema.no_info_after_validator_function( - self._validate, - ( - core_schema.list_schema(core_schema.dict_schema()) - if self.orient == "records" - else core_schema.dict_schema( - keys_schema=core_schema.str_schema(), - values_schema=core_schema.list_schema(), - ) - ), + self.validate, + core_schema.any_schema(), serialization=core_schema.plain_serializer_function_ser_schema(self.encode), ) @@ -315,7 +308,7 @@ def encode(self, df: pd.DataFrame) -> list | dict: else: raise ValueError("Only 'records' and 'columns' are supported for orient") - def _validate(self, obj: t.Any) -> pd.DataFrame: + def validate(self, obj: t.Any) -> pd.DataFrame: return pd.DataFrame(obj, columns=self.columns) diff --git a/src/bentoml/__init__.py b/src/bentoml/__init__.py index abbf55b53dd..b849aa162f5 100644 --- a/src/bentoml/__init__.py +++ b/src/bentoml/__init__.py @@ -111,6 +111,7 @@ from _bentoml_sdk import service else: from ._internal.utils import LazyLoader as _LazyLoader + from ._internal.utils.pkg import pkg_version_info # ML Frameworks catboost = _LazyLoader("bentoml.catboost", globals(), "bentoml.catboost") @@ -172,11 +173,19 @@ _NEW_SDK_ATTRS = ["service", "runner_service", "api", "depends"] _NEW_CLIENTS = ["SyncHTTPClient", "AsyncHTTPClient"] + if (ver := pkg_version_info("pydantic")) >= (2,): + import _bentoml_sdk + else: + _bentoml_sdk = None + def __getattr__(name: str) -> Any: if name not in _NEW_SDK_ATTRS + _NEW_CLIENTS: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - import _bentoml_sdk + if _bentoml_sdk is None: + raise ImportError( + f"The new SDK runs on pydantic>=2.0.0, but the you have {'.'.join(map(str, ver))}. " + "Please upgrade it." + ) if name in _NEW_CLIENTS: import _bentoml_impl.client