fix: memory issue when pushing large bentos #4207

Open · wants to merge 4 commits into base: main
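Summary: this PR stages the tar archive built for upload in a SpooledTemporaryFile that spills to disk once it grows past a configurable cap, instead of always buffering the whole archive in RAM. The cap is exposed as a new -m/--max_memory option (in GB) on bentoml push and bentoml models push; the default of -1 keeps the previous all-in-memory behavior. For example, bentoml push my_bento:latest -m 4 (tag is a placeholder) caps staging memory at roughly 4 GB.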
86 changes: 86 additions & 0 deletions src/bentoml/_internal/cloud/base.py
@@ -5,6 +5,7 @@
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager
from tempfile import SpooledTemporaryFile

from rich.console import Group
from rich.panel import Panel
@@ -17,6 +18,7 @@
from rich.progress import TimeRemainingColumn
from rich.progress import TransferSpeedColumn

from ...exceptions import BentoMLException
from ..bento import Bento
from ..bento import BentoStore
from ..models import Model
@@ -26,6 +28,78 @@
FILE_CHUNK_SIZE = 100 * 1024 * 1024  # 100 MiB


def io_wrapper(
memory: int,
*,
read_cb: t.Callable[[int], None] | None = None,
write_cb: t.Callable[[int], None] | None = None,
) -> CallbackIOWrapper | CallbackSpooledTemporaryFileIO:
"""
io_wrapper is a wrapper for SpooledTemporaryFileIO and CallbackIOWrapper
"""
if memory == -1:
return CallbackIOWrapper(read_cb=read_cb, write_cb=write_cb)
elif memory > 0:
return CallbackSpooledTemporaryFileIO(
memory * 1024**3, read_cb=read_cb, write_cb=write_cb
)
else:
raise BentoMLException(f"Option max_memory must be -1 or > 0, got {memory}")


class CallbackSpooledTemporaryFileIO(SpooledTemporaryFile):
"""
A SpooledTemporaryFile wrapper that calls
a callback when read/write is called
"""

read_cb: t.Callable[[int], None] | None
write_cb: t.Callable[[int], None] | None

def __init__(
self,
max_size: int = 0,
*,
read_cb: t.Callable[[int], None] | None = None,
write_cb: t.Callable[[int], None] | None = None,
):
self.read_cb = read_cb
self.write_cb = write_cb
super().__init__(max_size)

def read(self, *args):
res = super().read(*args)
if self.read_cb is not None:
self.read_cb(len(res))
return res

def write(self, s):
res = super().write(s)
if self.write_cb is not None:
if hasattr(s, "__len__"):
self.write_cb(len(s))
return res

def size(self) -> int:
"""
get the size of the file
"""
current_pos = self.tell()
self.seek(0, 2)
file_size = self.tell()
self.seek(current_pos)
return file_size

def chunk(self, start: int, end: int) -> bytes:
"""
chunk the file slice of [start, end)
"""
self.seek(start)
if end < 0 or start > end:
return self.read()
return self.read(end - start)


class CallbackIOWrapper(io.BytesIO):
read_cb: t.Callable[[int], None] | None
write_cb: t.Callable[[int], None] | None
@@ -57,6 +131,18 @@ def write(self, data: bytes) -> t.Any: # type: ignore
self.write_cb(len(data))
return res

def size(self) -> int:
"""
get the size of the buffer
"""
return super().getbuffer().nbytes

def chunk(self, start: int, end: int) -> bytes:
"""
chunk the buffer slice of [start, end)
"""
return super().getbuffer()[start:end]


class CloudClient(ABC):
log_progress = Progress(TextColumn("{task.description}"))
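For orientation, a minimal sketch of how these wrappers are driven (assuming the definitions above are importable from bentoml._internal.cloud.base; on_read is a stand-in for the rich progress callback the real client passes):

    from bentoml._internal.cloud.base import io_wrapper

    read_total = 0

    def on_read(n: int) -> None:
        # stand-in for io_cb in bentocloud.py, which advances a progress bar
        global read_total
        read_total += n

    # memory == -1 selects the in-memory CallbackIOWrapper; a positive value
    # caps RAM at that many GB and spills to a temp file past the cap.
    with io_wrapper(2, read_cb=on_read) as tar_io:
        tar_io.write(b"x" * 1024)      # archive bytes would be written here
        total = tar_io.size()          # 1024; the current position is preserved
        head = tar_io.chunk(0, 512)    # bytes [0, 512)
        tail = tar_io.chunk(512, -1)   # a negative end reads through EOF
        assert read_total == 1024      # both chunk() reads hit the callback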
40 changes: 28 additions & 12 deletions src/bentoml/_internal/cloud/bentocloud.py
@@ -28,6 +28,7 @@
from .base import FILE_CHUNK_SIZE
from .base import CallbackIOWrapper
from .base import CloudClient
from .base import io_wrapper
from .config import get_rest_api_client
from .deployment import Deployment
from .schemas import BentoApiSchema
@@ -70,13 +71,19 @@ def push_bento(
force: bool = False,
threads: int = 10,
context: str | None = None,
max_memory: int = -1,
):
with Live(self.progress_group):
upload_task_id = self.transmission_progress.add_task(
f'Pushing Bento "{bento.tag}"', start=False, visible=False
)
self._do_push_bento(
bento, upload_task_id, force=force, threads=threads, context=context
bento,
upload_task_id,
force=force,
threads=threads,
context=context,
max_memory=max_memory,
)

@inject
@@ -88,6 +95,7 @@ def _do_push_bento(
force: bool = False,
threads: int = 10,
context: str | None = None,
max_memory: int = -1,
model_store: ModelStore = Provide[BentoMLContainer.model_store],
):
yatai_rest_client = get_rest_api_client(context)
@@ -113,6 +121,7 @@ def push_model(model: Model) -> None:
force=force,
threads=threads,
context=context,
max_memory=max_memory,
)

futures: t.Iterator[None] = executor.map(push_model, models)
@@ -570,13 +579,19 @@ def push_model(
force: bool = False,
threads: int = 10,
context: str | None = None,
max_memory: int = -1,
):
with Live(self.progress_group):
upload_task_id = self.transmission_progress.add_task(
f'Pushing model "{model.tag}"', start=False, visible=False
)
self._do_push_model(
model, upload_task_id, force=force, threads=threads, context=context
model,
upload_task_id,
force=force,
threads=threads,
context=context,
max_memory=max_memory,
)

def _do_push_model(
@@ -587,6 +602,7 @@ def _do_push_model(
force: bool = False,
threads: int = 10,
context: str | None = None,
max_memory: int = -1,
):
yatai_rest_client = get_rest_api_client(context)
name = model.tag.name
@@ -667,7 +683,8 @@ def io_cb(x: int):
with io_mutex:
self.transmission_progress.update(upload_task_id, advance=x)

with CallbackIOWrapper(read_cb=io_cb) as tar_io:
# limit max memory usage when uploading the model
with io_wrapper(max_memory, read_cb=io_cb) as tar_io:
with self.spin(text=f'Creating tar archive for model "{model.tag}"..'):
with tarfile.open(fileobj=tar_io, mode="w:") as tar:
tar.add(model.path, arcname="./")
@@ -676,7 +693,7 @@
yatai_rest_client.start_upload_model(
model_repository_name=model_repository.name, version=version
)
file_size = tar_io.getbuffer().nbytes
file_size = tar_io.size()
self.transmission_progress.update(
upload_task_id,
description=f'Uploading model "{model.tag}"',
@@ -751,15 +768,14 @@ def chunk_upload(
text=f'({chunk_number}/{chunks_count}) Uploading chunk of model "{model.tag}"...'
):
chunk = (
tar_io.getbuffer()[
(chunk_number - 1)
* FILE_CHUNK_SIZE : chunk_number
* FILE_CHUNK_SIZE
]
tar_io.chunk(
(chunk_number - 1) * FILE_CHUNK_SIZE,
chunk_number * FILE_CHUNK_SIZE,
)
if chunk_number < chunks_count
else tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE :
]
else tar_io.chunk(
(chunk_number - 1) * FILE_CHUNK_SIZE, -1
)
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
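The boundaries passed to chunk() above follow simple arithmetic: chunk i (1-based) covers [(i - 1) * FILE_CHUNK_SIZE, i * FILE_CHUNK_SIZE), and the final chunk passes end = -1 so it reads through EOF whatever its length. A standalone sketch of that arithmetic (chunks_count is computed outside the visible diff; the ceiling division below is a plausible stand-in):

    FILE_CHUNK_SIZE = 100 * 1024 * 1024  # 100 MiB, as in base.py
    file_size = 250 * 1024 * 1024        # a hypothetical 250 MiB archive

    # plausible stand-in for the client's chunk count (ceiling division)
    chunks_count = -(-file_size // FILE_CHUNK_SIZE)  # -> 3

    for chunk_number in range(1, chunks_count + 1):
        start = (chunk_number - 1) * FILE_CHUNK_SIZE
        end = chunk_number * FILE_CHUNK_SIZE if chunk_number < chunks_count else -1
        print(chunk_number, start, end)
    # 1: [0, 100 MiB)   2: [100 MiB, 200 MiB)   3: [200 MiB, EOF)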
9 changes: 8 additions & 1 deletion src/bentoml_cli/bentos.py
@@ -281,8 +281,14 @@ def pull(shared_options: SharedOptions, bento_tag: str, force: bool) -> None:
default=10,
help="Number of threads to use for upload",
)
@click.option(
"-m",
"--max_memory",
default=-1,
help="max memory usage in GB when pushing, default -1 means no limit",
)
@click.pass_obj
def push(shared_options: SharedOptions, bento_tag: str, force: bool, threads: int) -> None: # type: ignore (not accessed)
def push(shared_options: SharedOptions, bento_tag: str, force: bool, threads: int, max_memory: int) -> None: # type: ignore (not accessed)
"""Push Bento to a remote Bento store server."""
bento_obj = bento_store.get(bento_tag)
if not bento_obj:
@@ -292,6 +298,7 @@ def push(shared_options: SharedOptions, bento_tag: str, force: bool, threads: in
force=force,
threads=threads,
context=shared_options.cloud_context,
max_memory=max_memory,
)

@cli.command()
9 changes: 8 additions & 1 deletion src/bentoml_cli/models.py
@@ -308,8 +308,14 @@ def pull(ctx: click.Context, model_tag: str | None, force: bool, bentofile: str)
default=10,
help="Number of threads to use for upload",
)
@click.option(
"-m",
"--max_memory",
default=-1,
help="max memory usage in GB when pushing, default -1 means no limit",
)
@click.pass_obj
def push(shared_options: SharedOptions, model_tag: str, force: bool, threads: int): # type: ignore (not accessed)
def push(shared_options: SharedOptions, model_tag: str, force: bool, threads: int, max_memory: int): # type: ignore (not accessed)
"""Push Model to a remote model store."""
model_obj = model_store.get(model_tag)
if not model_obj:
@@ -319,4 +325,5 @@ def push(shared_options: SharedOptions, model_tag: str, force: bool, threads: in
force=force,
threads=threads,
context=shared_options.cloud_context,
max_memory=max_memory,
)
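One note on validation: the max_memory value is only checked when io_wrapper runs, so a bad value fails at upload time rather than at argument parsing. A minimal sketch of that behavior (assuming the base.py definitions above):

    from bentoml._internal.cloud.base import io_wrapper
    from bentoml.exceptions import BentoMLException

    try:
        io_wrapper(0)  # neither -1 nor a positive GB value
    except BentoMLException as err:
        print(err)  # Option max_memory must be -1 or > 0, got 0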