
Commit

fix(server): clean the request resources after the response is consumed (#4481)

* fix(server): clean the request resources after the response is consumed

Signed-off-by: Frost Ming <me@frostming.com>

* fix: remove unnecessary attr params

Signed-off-by: Frost Ming <me@frostming.com>

* fix: set attrs minimum version

Signed-off-by: Frost Ming <me@frostming.com>

---------

Signed-off-by: Frost Ming <me@frostming.com>
frostming committed Feb 5, 2024
1 parent a13a913 commit e9e40b5
Showing 3 changed files with 55 additions and 54 deletions.
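The change in a nutshell: the handler used to await request.close() in a finally block, which runs as soon as api_endpoint returns and therefore before a streaming response body has been consumed. Attaching the cleanup as a Starlette BackgroundTask instead defers it until the response has been fully sent. Below is a minimal standalone sketch of that Starlette pattern, not BentoML code; the endpoint, route, and chunked body are illustrative assumptions.

import asyncio
import typing as t

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.routing import Route


async def stream_endpoint(request: Request) -> StreamingResponse:
    # Hypothetical endpoint: stream a body in chunks.
    async def body() -> t.AsyncIterator[bytes]:
        for chunk in (b"hello ", b"world"):
            await asyncio.sleep(0)  # stand-in for per-chunk work
            yield chunk

    # request.close() releases request-scoped resources such as parsed form data.
    # Registered as a BackgroundTask, it runs only after the last chunk has been
    # sent, so the streaming body can still rely on those resources.
    return StreamingResponse(body(), background=BackgroundTask(request.close))


app = Starlette(routes=[Route("/", stream_endpoint)])

Starlette invokes response.background only after the response body has been written out, which is why the commit can simply assign response.background = BackgroundTask(request.close) instead of closing the request inline.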
2 changes: 1 addition & 1 deletion pdm.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
     "Jinja2>=3.0.1",
     "PyYAML>=5.0",
     "aiohttp",
-    "attrs>=21.1.0",
+    "attrs>=22.2.0",
     "cattrs>=22.1.0,<23.2.0",
     "circus>=0.17.0,!=0.17.2",
     "click>=7.0",
105 changes: 53 additions & 52 deletions src/_bentoml_impl/server/app.py
@@ -394,6 +394,8 @@ async def inner_infer(
         )(value)
 
     async def api_endpoint(self, name: str, request: Request) -> Response:
+        from starlette.background import BackgroundTask
+
         from _bentoml_sdk.io_models import ARGS
         from _bentoml_sdk.io_models import KWARGS
         from bentoml._internal.container import BentoMLContainer
@@ -409,58 +411,57 @@ async def api_endpoint(self, name: str, request: Request) -> Response:
         method = self.service.apis[name]
         func = getattr(self._service_instance, name)
         ctx = self.service.context
-        try:
-            serde = ALL_SERDE[media_type]()
-            input_data = await method.input_spec.from_http_request(request, serde)
-            input_args: tuple[t.Any, ...] = ()
-            input_params = {k: getattr(input_data, k) for k in input_data.model_fields}
-            if method.ctx_param is not None:
-                input_params[method.ctx_param] = ctx
-            if ARGS in input_params:
-                input_args = tuple(input_params.pop(ARGS))
-            if KWARGS in input_params:
-                input_params.update(input_params.pop(KWARGS))
-
-            original_func = get_original_func(func)
-
-            if method.batchable:
-                output = await self.batch_infer(name, input_args, input_params)
-            elif inspect.iscoroutinefunction(original_func):
-                output = await func(*input_args, **input_params)
-            elif inspect.isasyncgenfunction(original_func):
-                output = func(*input_args, **input_params)
-            elif inspect.isgeneratorfunction(original_func):
-
-                async def inner() -> t.AsyncGenerator[t.Any, None]:
-                    gen = func(*input_args, **input_params)
-                    while True:
-                        try:
-                            yield await self._to_thread(next, gen)
-                        except StopIteration:
-                            break
-                        except RuntimeError as e:
-                            if "StopIteration" in str(e):
-                                break
-                            raise
-
-                output = inner()
-            else:
-                output = await self._to_thread(func, *input_args, **input_params)
-
-            response = await method.output_spec.to_http_response(output, serde)
-            response.headers.update({"Server": f"BentoML Service/{self.service.name}"})
-
-            if method.ctx_param is not None:
-                response.status_code = ctx.response.status_code
-                response.headers.update(ctx.response.metadata)
-                set_cookies(response, ctx.response.cookies)
-            if trace_context.request_id is not None:
-                response.headers["X-BentoML-Request-ID"] = str(trace_context.request_id)
-            if (
-                BentoMLContainer.http.response.trace_id.get()
-                and trace_context.trace_id is not None
-            ):
-                response.headers["X-BentoML-Trace-ID"] = str(trace_context.trace_id)
-        finally:
-            await request.close()
+        serde = ALL_SERDE[media_type]()
+        input_data = await method.input_spec.from_http_request(request, serde)
+        input_args: tuple[t.Any, ...] = ()
+        input_params = {k: getattr(input_data, k) for k in input_data.model_fields}
+        if method.ctx_param is not None:
+            input_params[method.ctx_param] = ctx
+        if ARGS in input_params:
+            input_args = tuple(input_params.pop(ARGS))
+        if KWARGS in input_params:
+            input_params.update(input_params.pop(KWARGS))
+
+        original_func = get_original_func(func)
+
+        if method.batchable:
+            output = await self.batch_infer(name, input_args, input_params)
+        elif inspect.iscoroutinefunction(original_func):
+            output = await func(*input_args, **input_params)
+        elif inspect.isasyncgenfunction(original_func):
+            output = func(*input_args, **input_params)
+        elif inspect.isgeneratorfunction(original_func):
+
+            async def inner() -> t.AsyncGenerator[t.Any, None]:
+                gen = func(*input_args, **input_params)
+                while True:
+                    try:
+                        yield await self._to_thread(next, gen)
+                    except StopIteration:
+                        break
+                    except RuntimeError as e:
+                        if "StopIteration" in str(e):
+                            break
+                        raise
+
+            output = inner()
+        else:
+            output = await self._to_thread(func, *input_args, **input_params)
+
+        response = await method.output_spec.to_http_response(output, serde)
+        response.headers.update({"Server": f"BentoML Service/{self.service.name}"})
+
+        if method.ctx_param is not None:
+            response.status_code = ctx.response.status_code
+            response.headers.update(ctx.response.metadata)
+            set_cookies(response, ctx.response.cookies)
+        if trace_context.request_id is not None:
+            response.headers["X-BentoML-Request-ID"] = str(trace_context.request_id)
+        if (
+            BentoMLContainer.http.response.trace_id.get()
+            and trace_context.trace_id is not None
+        ):
+            response.headers["X-BentoML-Trace-ID"] = str(trace_context.trace_id)
+        # clean the request resources after the response is consumed.
+        response.background = BackgroundTask(request.close)
         return response
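Worth noting in the context lines above: the inner() wrapper turns a synchronous generator endpoint into an async generator by pulling each item on a worker thread, treating both StopIteration and the RuntimeError it can be converted into at an await boundary as the end of iteration. The sketch below shows the same idea in isolation; it is an illustrative assumption, not BentoML code, and it uses asyncio.to_thread (Python 3.9+) in place of the app's _to_thread helper, with a sentinel exception (similar to Starlette's own thread-pool iteration helper) so the end of the generator never has to cross the await as a StopIteration.

import asyncio
import typing as t


class _EndOfIteration(Exception):
    # Sentinel: StopIteration cannot propagate cleanly out of a future or coroutine.
    pass


def _next(it: t.Iterator[t.Any]) -> t.Any:
    # Runs on the worker thread; converts exhaustion into the sentinel above.
    try:
        return next(it)
    except StopIteration:
        raise _EndOfIteration from None


async def iterate_in_thread(it: t.Iterator[t.Any]) -> t.AsyncGenerator[t.Any, None]:
    # Pull items from a blocking generator on a worker thread so the event loop
    # stays responsive while each item is produced.
    while True:
        try:
            yield await asyncio.to_thread(_next, it)
        except _EndOfIteration:
            break


async def _demo() -> None:
    def numbers() -> t.Iterator[int]:
        yield from range(3)

    async for item in iterate_in_thread(numbers()):
        print(item)


if __name__ == "__main__":
    asyncio.run(_demo())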
