fix: bug: Dataframes not serializing correctly in the new API (#4491)
Fixes #4489

Signed-off-by: Frost Ming <me@frostming.com>
frostming committed Feb 20, 2024
1 parent ad0d485 · commit 6c2ac38
Showing 5 changed files with 111 additions and 38 deletions.
4 changes: 2 additions & 2 deletions src/_bentoml_impl/client/http.py
@@ -227,7 +227,7 @@ def _build_request(
return self.client.build_request(
"POST",
endpoint.route,
content=self.serde.serialize(kwargs),
content=self.serde.serialize(kwargs, endpoint.input),
headers=headers,
)

@@ -307,7 +307,7 @@ def _deserialize_output(self, data: bytes, endpoint: ClientEndpoint) -> t.Any:
elif ot == "bytes":
return data
else:
return self.serde.deserialize(data)
return self.serde.deserialize(data, endpoint.output)

def call(self, __name: str, /, *args: t.Any, **kwargs: t.Any) -> t.Any:
try:
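
The client change above threads the endpoint's input and output schemas into every serialize/deserialize call. A minimal sketch (not BentoML's implementation) of why the schema has to travel with the payload: plain JSON encoding cannot handle a pandas DataFrame argument, while a schema-aware encoder knows to convert it first.

import json

import pandas as pd

df = pd.DataFrame({"a": [1, 2]})

# Plain JSON serialization of a DataFrame argument fails outright.
try:
    json.dumps({"df": df})
except TypeError as exc:
    print("plain JSON fails:", exc)

# With the endpoint's input schema in hand, the encoder can see that the
# field is a dataframe and convert it to a JSON-friendly payload first.
input_schema = {"type": "object", "properties": {"df": {"type": "dataframe", "orient": "records"}}}
field_schema = input_schema["properties"]["df"]
if field_schema["type"] == "dataframe":
    payload = json.dumps({"df": df.to_dict(orient=field_schema.get("orient", "records"))})
print(payload)  # {"df": [{"a": 1}, {"a": 2}]}
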
89 changes: 79 additions & 10 deletions src/_bentoml_impl/serde.py
@@ -9,12 +9,15 @@
from urllib.parse import unquote
from urllib.parse import urlparse

from pydantic import BaseModel
from starlette.datastructures import Headers
from starlette.datastructures import UploadFile
from typing_extensions import get_args

from _bentoml_sdk.typing_utils import is_list_type
from _bentoml_sdk.typing_utils import is_union_type
from _bentoml_sdk.validators import DataframeSchema
from _bentoml_sdk.validators import TensorSchema

if t.TYPE_CHECKING:
from starlette.requests import Request
@@ -36,11 +39,11 @@ def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
...

@abc.abstractmethod
def serialize(self, obj: t.Any) -> bytes:
def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes:
...

@abc.abstractmethod
def deserialize(self, obj_bytes: bytes) -> t.Any:
def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any:
...

async def parse_request(self, request: Request, cls: type[T]) -> T:
@@ -49,7 +52,73 @@ async def parse_request(self, request: Request, cls: type[T]) -> T:
return self.deserialize_model(json_str, cls)


class JSONSerde(Serde):
class GenericSerde:
def _encode(self, obj: t.Any, schema: dict[str, t.Any]) -> t.Any:
if schema.get("type") == "tensor":
child_schema = TensorSchema(
format=schema.get("format", ""),
dtype=schema.get("dtype"),
shape=schema.get("shape"),
)
return child_schema.encode(child_schema.validate(obj))
if schema.get("type") == "dataframe":
child_schema = DataframeSchema(
orient=schema.get("orient", "records"), columns=schema.get("columns")
)
return child_schema.encode(child_schema.validate(obj))
if schema.get("type") == "array" and "items" in schema:
return [self._encode(v, schema["items"]) for v in obj]
if schema.get("type") == "object" and schema.get("properties"):
if isinstance(obj, BaseModel):
return obj.model_dump()
return {
k: self._encode(obj[k], child)
for k, child in schema["properties"].items()
if k in obj
}
return obj

def _decode(self, obj: t.Any, schema: dict[str, t.Any]) -> t.Any:
if schema.get("type") == "tensor":
child_schema = TensorSchema(
format=schema.get("format", ""),
dtype=schema.get("dtype"),
shape=schema.get("shape"),
)
return child_schema.validate(obj)
if schema.get("type") == "dataframe":
child_schema = DataframeSchema(
orient=schema.get("orient", "records"), columns=schema.get("columns")
)
return child_schema.validate(obj)
if schema.get("type") == "array" and "items" in schema:
return [self._decode(v, schema["items"]) for v in obj]
if (
schema.get("type") == "object"
and schema.get("properties")
and isinstance(obj, t.Mapping)
):
return {
k: self._decode(obj[k], child)
for k, child in schema["properties"].items()
if k in obj
}
return obj

def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes:
return self.serialize_value(self._encode(obj, schema))

def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any:
return self._decode(self.deserialize_value(obj_bytes), schema)

def serialize_value(self, obj: t.Any) -> bytes:
raise NotImplementedError

def deserialize_value(self, obj_bytes: bytes) -> t.Any:
raise NotImplementedError


class JSONSerde(GenericSerde, Serde):
media_type = "application/json"

def serialize_model(self, model: IODescriptor) -> bytes:
@@ -60,10 +129,10 @@ def serialize_model(self, model: IODescriptor) -> bytes:
def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
return cls.model_validate_json(model_bytes)

def serialize(self, obj: t.Any) -> bytes:
def serialize_value(self, obj: t.Any) -> bytes:
return json.dumps(obj).encode("utf-8")

def deserialize(self, obj_bytes: bytes) -> t.Any:
def deserialize_value(self, obj_bytes: bytes) -> t.Any:
return json.loads(obj_bytes)


@@ -109,7 +178,7 @@ async def parse_request(self, request: Request, cls: type[T]) -> T:
return cls.model_validate(data)


class PickleSerde(Serde):
class PickleSerde(GenericSerde, Serde):
media_type = "application/vnd.bentoml+pickle"

def serialize_model(self, model: IODescriptor) -> bytes:
@@ -122,10 +191,10 @@ def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
obj = cls.model_validate(obj)
return obj

def serialize(self, obj: t.Any) -> bytes:
def serialize_value(self, obj: t.Any) -> bytes:
return pickle.dumps(obj)

def deserialize(self, obj_bytes: bytes) -> t.Any:
def deserialize_value(self, obj_bytes: bytes) -> t.Any:
return pickle.loads(obj_bytes)


@@ -145,12 +214,12 @@ def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
buffer = io.BytesIO(model_bytes)
return deserialize_from_arrow(cls, buffer)

def serialize(self, obj: t.Any) -> bytes:
def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes:
raise NotImplementedError(
"Serializing arbitrary object to Arrow is not supported"
)

def deserialize(self, obj_bytes: bytes) -> t.Any:
def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any:
raise NotImplementedError(
"Deserializing arbitrary object from Arrow is not supported"
)
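
The new GenericSerde mixin above walks the endpoint schema recursively, converting nested tensors and dataframes into JSON-friendly values before the bytes are produced and rebuilding them on the way back in. A rough, self-contained sketch of that round trip, with plain functions standing in for DataframeSchema.encode()/validate() (tensor handling and column selection omitted):

import json
import typing as t

import pandas as pd


def encode(obj: t.Any, schema: dict[str, t.Any]) -> t.Any:
    # Simplified mirror of GenericSerde._encode, restricted to the dataframe case.
    kind = schema.get("type")
    if kind == "dataframe":
        return pd.DataFrame(obj).to_dict(orient=schema.get("orient", "records"))
    if kind == "array" and "items" in schema:
        return [encode(v, schema["items"]) for v in obj]
    if kind == "object" and schema.get("properties"):
        return {k: encode(obj[k], s) for k, s in schema["properties"].items() if k in obj}
    return obj


def decode(obj: t.Any, schema: dict[str, t.Any]) -> t.Any:
    # Simplified mirror of GenericSerde._decode: rebuild DataFrames from the raw payload.
    kind = schema.get("type")
    if kind == "dataframe":
        return pd.DataFrame(obj)
    if kind == "array" and "items" in schema:
        return [decode(v, schema["items"]) for v in obj]
    if kind == "object" and schema.get("properties"):
        return {k: decode(obj[k], s) for k, s in schema["properties"].items() if k in obj}
    return obj


schema = {"type": "object", "properties": {"df": {"type": "dataframe", "orient": "records"}}}
payload = json.dumps(encode({"df": pd.DataFrame({"a": [1, 2]})}, schema)).encode("utf-8")
restored = decode(json.loads(payload), schema)
print(restored["df"])  # a DataFrame again, rebuilt from the records payload
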
24 changes: 13 additions & 11 deletions src/_bentoml_sdk/io_models.py
@@ -143,23 +143,15 @@ async def to_http_response(cls, obj: t.Any, serde: Serde) -> Response:

structured_media_type = cls.media_type or serde.media_type

if not issubclass(cls, RootModel):
if cls.multipart_fields:
return Response(
"Multipart response is not supported yet", status_code=500
)
return Response(
content=serde.serialize_model(t.cast(IODescriptor, obj)),
media_type=structured_media_type,
)
if inspect.isasyncgen(obj):

async def async_stream() -> t.AsyncGenerator[str | bytes, None]:
async for item in obj:
if isinstance(item, (str, bytes)):
yield item
else:
yield serde.serialize_model(t.cast(IODescriptor, cls(item)))
obj_item = cls(item) if issubclass(cls, RootModel) else item
yield serde.serialize_model(t.cast(IODescriptor, obj_item))

return StreamingResponse(async_stream(), media_type=cls.mime_type())

@@ -170,9 +162,19 @@ def content_stream() -> t.Generator[str | bytes, None, None]:
if isinstance(item, (str, bytes)):
yield item
else:
yield serde.serialize_model(t.cast(IODescriptor, cls(item)))
obj_item = cls(item) if issubclass(cls, RootModel) else item
yield serde.serialize_model(t.cast(IODescriptor, obj_item))

return StreamingResponse(content_stream(), media_type=cls.mime_type())
elif not issubclass(cls, RootModel):
if cls.multipart_fields:
return Response(
"Multipart response is not supported yet", status_code=500
)
return Response(
content=serde.serialize_model(t.cast(IODescriptor, obj)),
media_type=structured_media_type,
)
else:
if is_file_type(type(obj)) and isinstance(serde, JSONSerde):
if isinstance(obj, pathlib.PurePath):
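
The io_models change reorders the dispatch in to_http_response(): generator outputs are handled before the plain structured (non-RootModel) case, and each streamed item is wrapped in cls(...) only when cls is a RootModel. A hedged sketch of that branch order, using simplified stand-ins rather than the BentoML classes:

import inspect
import typing as t

from pydantic import BaseModel, RootModel


class IntOutput(RootModel[int]):
    pass


class DictOutput(BaseModel):
    value: int


def render(cls: type[BaseModel], obj: t.Any) -> str:
    if inspect.isgenerator(obj):  # streaming branch is checked first now
        def items() -> t.Iterator[str]:
            for item in obj:
                # Wrap raw items only for RootModel outputs; other items are
                # assumed to already be cls instances.
                model = cls(item) if issubclass(cls, RootModel) else item
                yield model.model_dump_json()
        return "stream: " + ", ".join(items())
    elif not issubclass(cls, RootModel):  # structured, non-root response
        return "plain: " + obj.model_dump_json()
    else:  # root-model response
        return "root: " + cls(obj).model_dump_json()


print(render(IntOutput, (i for i in range(3))))  # stream: 0, 1, 2
print(render(DictOutput, DictOutput(value=1)))   # plain: {"value":1}
print(render(IntOutput, 42))                     # root: 42
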
19 changes: 6 additions & 13 deletions src/_bentoml_sdk/validators.py
@@ -207,12 +207,12 @@ def __get_pydantic_core_schema__(
self, source_type: t.Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
self._validate,
self.validate,
core_schema.any_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(self.encode),
)

def encode(self, arr: TensorType) -> bytes:
def encode(self, arr: TensorType) -> list[t.Any]:
if self.format == "numpy-array":
numpy_array = arr
elif self.format == "tf-tensor":
@@ -235,7 +235,7 @@ def framework_dtype(self) -> t.Any:
else:
return getattr(torch, dtype)

def _validate(self, obj: t.Any) -> t.Any:
def validate(self, obj: t.Any) -> t.Any:
arr: t.Any
if self.format == "numpy-array":
arr = np.array(obj, dtype=self.framework_dtype)
@@ -295,15 +295,8 @@ def __get_pydantic_core_schema__(
self, source_type: t.Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
self._validate,
(
core_schema.list_schema(core_schema.dict_schema())
if self.orient == "records"
else core_schema.dict_schema(
keys_schema=core_schema.str_schema(),
values_schema=core_schema.list_schema(),
)
),
self.validate,
core_schema.any_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(self.encode),
)

@@ -315,7 +308,7 @@ def encode(self, df: pd.DataFrame) -> list | dict:
else:
raise ValueError("Only 'records' and 'columns' are supported for orient")

def _validate(self, obj: t.Any) -> pd.DataFrame:
def validate(self, obj: t.Any) -> pd.DataFrame:
return pd.DataFrame(obj, columns=self.columns)


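
The validators change above renames _validate to validate (so the serde layer can call it directly), has TensorSchema.encode return a list rather than bytes, and relaxes DataframeSchema's pydantic core schema to any_schema(), leaving coercion to validate(). A usage sketch of the validate()/encode() pair against a stand-in class with the semantics shown in the diff, not the class imported from _bentoml_sdk:

import typing as t
from dataclasses import dataclass

import pandas as pd


@dataclass
class DataframeSchemaSketch:
    orient: str = "records"
    columns: t.Optional[list] = None

    def validate(self, obj: t.Any) -> pd.DataFrame:
        # Accept any DataFrame-constructible payload (records, dict of columns, ...).
        return pd.DataFrame(obj, columns=self.columns)

    def encode(self, df: pd.DataFrame) -> t.Union[list, dict]:
        if self.orient == "records":
            return df.to_dict(orient="records")
        if self.orient == "columns":
            return df.to_dict(orient="list")
        raise ValueError("Only 'records' and 'columns' are supported for orient")


schema = DataframeSchemaSketch(orient="records")
df = schema.validate([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
print(schema.encode(df))  # [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
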
13 changes: 11 additions & 2 deletions src/bentoml/__init__.py
@@ -111,6 +111,7 @@
from _bentoml_sdk import service
else:
from ._internal.utils import LazyLoader as _LazyLoader
from ._internal.utils.pkg import pkg_version_info

# ML Frameworks
catboost = _LazyLoader("bentoml.catboost", globals(), "bentoml.catboost")
@@ -172,11 +173,19 @@
_NEW_SDK_ATTRS = ["service", "runner_service", "api", "depends"]
_NEW_CLIENTS = ["SyncHTTPClient", "AsyncHTTPClient"]

if (ver := pkg_version_info("pydantic")) >= (2,):
import _bentoml_sdk
else:
_bentoml_sdk = None

def __getattr__(name: str) -> Any:
if name not in _NEW_SDK_ATTRS + _NEW_CLIENTS:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

import _bentoml_sdk
if _bentoml_sdk is None:
raise ImportError(
f"The new SDK runs on pydantic>=2.0.0, but the you have {'.'.join(map(str, ver))}. "
"Please upgrade it."
)

if name in _NEW_CLIENTS:
import _bentoml_impl.client
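
The __init__ change imports the new SDK only when pydantic>=2 is installed and turns attribute access into a clear upgrade hint otherwise. A hedged sketch of the same gating pattern; pkg_version_info here is a stand-in built on importlib.metadata rather than BentoML's internal helper:

from importlib.metadata import version


def pkg_version_info(name: str) -> tuple:
    # Stand-in: parse "2.6.1" into (2, 6, 1) for tuple comparison.
    return tuple(int(part) for part in version(name).split(".")[:3] if part.isdigit())


if (ver := pkg_version_info("pydantic")) >= (2,):
    NEW_SDK_AVAILABLE = True  # the real module imports _bentoml_sdk here
else:
    NEW_SDK_AVAILABLE = False


def resolve_new_sdk_attr(name: str) -> str:
    # Mirrors the __getattr__ guard: fail with an actionable message instead
    # of an opaque ImportError when pydantic is too old.
    if not NEW_SDK_AVAILABLE:
        raise ImportError(
            f"The new SDK requires pydantic>=2.0.0, but you have {'.'.join(map(str, ver))}. "
            "Please upgrade it."
        )
    return name


print(resolve_new_sdk_attr("SyncHTTPClient"))
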
