Skip to content

Commit

Permalink
feat: picklev5 serialization (#4594)
Browse files Browse the repository at this point in the history
* feat: picklev5 serialization

Signed-off-by: Frost Ming <me@frostming.com>

* fix generic serde

Signed-off-by: Frost Ming <me@frostming.com>

* fix client send

Signed-off-by: Frost Ming <me@frostming.com>

---------

Signed-off-by: Frost Ming <me@frostming.com>
Co-authored-by: bojiang <5886138+bojiang@users.noreply.github.com>
  • Loading branch information
frostming and bojiang committed Mar 20, 2024
1 parent cf8f663 commit 41623a5
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 90 deletions.
43 changes: 31 additions & 12 deletions src/_bentoml_impl/client/http.py
Expand Up @@ -26,6 +26,7 @@
from bentoml._internal.utils.uri import uri_to_path
from bentoml.exceptions import BentoMLException

from ..serde import Payload
from .base import AbstractClient
from .base import ClientEndpoint

Expand All @@ -37,6 +38,7 @@
from ..serde import Serde

T = t.TypeVar("T", bound="HTTPClient[t.Any]")
A = t.TypeVar("A")

C = t.TypeVar("C", httpx.Client, httpx.AsyncClient)
AnyClient = t.TypeVar("AnyClient", httpx.Client, httpx.AsyncClient)
Expand All @@ -48,6 +50,14 @@ def is_http_url(url: str) -> bool:
return urlparse(url).scheme in {"http", "https"}


def to_async_iterable(iterable: t.Iterable[A]) -> t.AsyncIterable[A]:
async def _gen() -> t.AsyncIterator[A]:
for item in iterable:
yield item

return _gen()


@attr.define
class HTTPClient(AbstractClient, t.Generic[C]):
client_cls: t.ClassVar[type[httpx.Client] | type[httpx.AsyncClient]]
Expand Down Expand Up @@ -198,11 +208,15 @@ def _build_request(
if model.multipart_fields and self.media_type == "application/json":
return self._build_multipart(endpoint, model, headers)
else:
payload = self.serde.serialize_model(model)
headers.update(payload.headers)
return self.client.build_request(
"POST",
endpoint.route,
headers=headers,
content=self.serde.serialize_model(model),
content=to_async_iterable(payload.data)
if self.client_cls is httpx.AsyncClient
else payload.data,
)

for name, value in zip(endpoint.input["properties"], args):
Expand Down Expand Up @@ -230,10 +244,14 @@ def _build_request(
)
if has_file and self.media_type == "application/json":
return self._build_multipart(endpoint, kwargs, headers)
payload = self.serde.serialize(kwargs, endpoint.input)
headers.update(payload.headers)
return self.client.build_request(
"POST",
endpoint.route,
content=self.serde.serialize(kwargs, endpoint.input),
content=to_async_iterable(payload.data)
if self.client_cls is httpx.AsyncClient
else payload.data,
headers=headers,
)

Expand Down Expand Up @@ -318,18 +336,19 @@ def is_file_field(k: str) -> bool:
"POST", endpoint.route, data=data, files=files, headers=headers
)

def _deserialize_output(self, data: bytes, endpoint: ClientEndpoint) -> t.Any:
def _deserialize_output(self, payload: Payload, endpoint: ClientEndpoint) -> t.Any:
data = iter(payload.data)
if endpoint.output_spec is not None:
model = self.serde.deserialize_model(data, endpoint.output_spec)
model = self.serde.deserialize_model(payload, endpoint.output_spec)
if isinstance(model, RootModel):
return model.root # type: ignore
return model
elif (ot := endpoint.output.get("type")) == "string":
return data.decode("utf-8")
return bytes(next(data)).decode("utf-8")
elif ot == "bytes":
return data
return bytes(next(data))
else:
return self.serde.deserialize(data, endpoint.output)
return self.serde.deserialize(payload, endpoint.output)

def call(self, __name: str, /, *args: t.Any, **kwargs: t.Any) -> t.Any:
try:
Expand Down Expand Up @@ -429,15 +448,15 @@ def _call(
self._opened_files.clear()

def _parse_response(self, endpoint: ClientEndpoint, resp: httpx.Response) -> t.Any:
data = resp.read()
return self._deserialize_output(data, endpoint)
payload = Payload((resp.read(),), resp.headers)
return self._deserialize_output(payload, endpoint)

def _parse_stream_response(
self, endpoint: ClientEndpoint, resp: httpx.Response
) -> t.Generator[t.Any, None, None]:
try:
for data in resp.iter_bytes():
yield self._deserialize_output(data, endpoint)
yield self._deserialize_output(Payload((data,), resp.headers), endpoint)
finally:
resp.close()

Expand Down Expand Up @@ -534,14 +553,14 @@ async def _parse_response(
self, endpoint: ClientEndpoint, resp: httpx.Response
) -> t.Any:
data = await resp.aread()
return self._deserialize_output(data, endpoint)
return self._deserialize_output(Payload((data,), resp.headers), endpoint)

async def _parse_stream_response(
self, endpoint: ClientEndpoint, resp: httpx.Response
) -> t.AsyncGenerator[t.Any, None]:
try:
async for data in resp.aiter_bytes():
yield self._deserialize_output(data, endpoint)
yield self._deserialize_output(Payload((data,), resp.headers), endpoint)
finally:
await resp.aclose()

Expand Down
144 changes: 84 additions & 60 deletions src/_bentoml_impl/serde.py
Expand Up @@ -9,6 +9,7 @@
from urllib.parse import unquote
from urllib.parse import urlparse

import attrs
from pydantic import BaseModel
from starlette.datastructures import Headers
from starlette.datastructures import UploadFile
Expand All @@ -24,53 +25,79 @@

from _bentoml_sdk import IODescriptor


T = t.TypeVar("T", bound="IODescriptor")


@attrs.frozen
class Payload:
data: t.Iterable[bytes | memoryview]
metadata: t.Mapping[str, str] = attrs.field(factory=dict)

def total_bytes(self) -> int:
return sum(len(d) for d in self.data)

@property
def headers(self) -> t.Mapping[str, str]:
return {"content-length": str(self.total_bytes()), **self.metadata}


@attrs.frozen
class SerializationInfo:
mode: str

def mode_is_json(self) -> bool:
return self.mode == "json"


class Serde(abc.ABC):
media_type: str

@abc.abstractmethod
def serialize_model(self, model: IODescriptor) -> bytes:
def serialize_model(self, model: IODescriptor) -> Payload:
...

@abc.abstractmethod
def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
def deserialize_model(self, payload: Payload, cls: type[T]) -> T:
...

@abc.abstractmethod
def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes:
def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> Payload:
...

@abc.abstractmethod
def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any:
def deserialize(self, payload: Payload, schema: dict[str, t.Any]) -> t.Any:
...

async def parse_request(self, request: Request, cls: type[T]) -> T:
"""Parse a input model from HTTP request"""
json_str = await request.body()
return self.deserialize_model(json_str, cls)
return self.deserialize_model(
Payload((json_str,), metadata=request.headers), cls
)


class GenericSerde:
def _encode(self, obj: t.Any, schema: dict[str, t.Any]) -> t.Any:
mode = "json" if isinstance(self, JSONSerde) else "python"
info = SerializationInfo(mode=mode)
if schema.get("type") == "tensor":
child_schema = TensorSchema(
format=schema.get("format", ""),
dtype=schema.get("dtype"),
shape=schema.get("shape"),
)
return child_schema.encode(child_schema.validate(obj))
return child_schema.encode(child_schema.validate(obj), info)
if schema.get("type") == "dataframe":
child_schema = DataframeSchema(
orient=schema.get("orient", "records"), columns=schema.get("columns")
)
return child_schema.encode(child_schema.validate(obj))
return child_schema.encode(child_schema.validate(obj), info)
if schema.get("type") == "array" and "items" in schema:
return [self._encode(v, schema["items"]) for v in obj]
if schema.get("type") == "object" and schema.get("properties"):
if isinstance(obj, BaseModel):
return obj.model_dump()
return obj.model_dump(mode=mode)
return {
k: self._encode(obj[k], child)
for k, child in schema["properties"].items()
Expand Down Expand Up @@ -105,35 +132,39 @@ def _decode(self, obj: t.Any, schema: dict[str, t.Any]) -> t.Any:
}
return obj

def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes:
def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> Payload:
return self.serialize_value(self._encode(obj, schema))

def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any:
return self._decode(self.deserialize_value(obj_bytes), schema)
def deserialize(self, payload: Payload, schema: dict[str, t.Any]) -> t.Any:
return self._decode(self.deserialize_value(payload), schema)

def serialize_value(self, obj: t.Any) -> bytes:
def serialize_value(self, obj: t.Any) -> Payload:
raise NotImplementedError

def deserialize_value(self, obj_bytes: bytes) -> t.Any:
def deserialize_value(self, payload: Payload) -> t.Any:
raise NotImplementedError


class JSONSerde(GenericSerde, Serde):
media_type = "application/json"

def serialize_model(self, model: IODescriptor) -> bytes:
return model.model_dump_json(
exclude=set(getattr(model, "multipart_fields", set()))
).encode("utf-8")
def serialize_model(self, model: IODescriptor) -> Payload:
return Payload(
(
model.model_dump_json(
exclude=set(getattr(model, "multipart_fields", set()))
).encode("utf-8"),
)
)

def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
return cls.model_validate_json(model_bytes)
def deserialize_model(self, payload: Payload, cls: type[T]) -> T:
return cls.model_validate_json(b"".join(payload.data) or b"{}")

def serialize_value(self, obj: t.Any) -> bytes:
return json.dumps(obj).encode("utf-8")
def serialize_value(self, obj: t.Any) -> Payload:
return Payload((json.dumps(obj).encode("utf-8"),))

def deserialize_value(self, obj_bytes: bytes) -> t.Any:
return json.loads(obj_bytes)
def deserialize_value(self, payload: Payload) -> t.Any:
return json.loads(b"".join(payload.data) or b"{}")


class MultipartSerde(JSONSerde):
Expand Down Expand Up @@ -181,52 +212,45 @@ async def parse_request(self, request: Request, cls: type[T]) -> T:
class PickleSerde(GenericSerde, Serde):
media_type = "application/vnd.bentoml+pickle"

def serialize_model(self, model: IODescriptor) -> bytes:
def serialize_model(self, model: IODescriptor) -> Payload:
model_data = model.model_dump()
return pickle.dumps(model_data)
return self.serialize_value(model_data)

def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
obj = pickle.loads(model_bytes)
def deserialize_model(self, payload: Payload, cls: type[T]) -> T:
obj = self.deserialize_value(payload)
if not isinstance(obj, cls):
obj = cls.model_validate(obj)
return obj

def serialize_value(self, obj: t.Any) -> bytes:
return pickle.dumps(obj)

def deserialize_value(self, obj_bytes: bytes) -> t.Any:
return pickle.loads(obj_bytes)


class ArrowSerde(Serde):
media_type = "application/vnd.bentoml+arrow"

def serialize_model(self, model: IODescriptor) -> bytes:
from .arrow import serialize_to_arrow

buffer = io.BytesIO()
serialize_to_arrow(model, buffer)
return buffer.getvalue()

def deserialize_model(self, model_bytes: bytes, cls: type[T]) -> T:
from .arrow import deserialize_from_arrow

buffer = io.BytesIO(model_bytes)
return deserialize_from_arrow(cls, buffer)

def serialize(self, obj: t.Any, schema: dict[str, t.Any]) -> bytes:
raise NotImplementedError(
"Serializing arbitrary object to Arrow is not supported"
)

def deserialize(self, obj_bytes: bytes, schema: dict[str, t.Any]) -> t.Any:
raise NotImplementedError(
"Deserializing arbitrary object from Arrow is not supported"
)
def serialize_value(self, obj: t.Any) -> Payload:
buffers: list[pickle.PickleBuffer] = []
main_bytes = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append)
data: list[bytes | memoryview] = [main_bytes]
lengths = [len(main_bytes)]
for buff in buffers:
data.append(buff.raw())
lengths.append(len(data[-1]))
buff.release()
metadata = {"buffer-lengths": ",".join(map(str, lengths))}
return Payload(data, metadata)

def deserialize_value(self, payload: Payload) -> t.Any:
if "buffer-lengths" not in payload.metadata:
return pickle.loads(b"".join(payload.data))
buffer_lengths = list(map(int, payload.metadata["buffer-lengths"].split(",")))
data_stream = b"".join(payload.data)
data = memoryview(data_stream)
start = buffer_lengths[0]
main_bytes = data[:start]
buffers: list[pickle.PickleBuffer] = []
for length in buffer_lengths[1:]:
buffers.append(pickle.PickleBuffer(data[start : start + length]))
start += length
return pickle.loads(main_bytes, buffers=buffers)


ALL_SERDE: t.Mapping[str, type[Serde]] = {
s.media_type: s for s in [JSONSerde, PickleSerde, ArrowSerde, MultipartSerde]
s.media_type: s for s in [JSONSerde, PickleSerde, MultipartSerde]
}
# Special case for application/x-www-form-urlencoded
ALL_SERDE["application/x-www-form-urlencoded"] = MultipartSerde

0 comments on commit 41623a5

Please sign in to comment.