fix: memory issue when pushing large bentos
xianml committed Sep 26, 2023
1 parent f3a63c4 commit 7b94cf7
Showing 2 changed files with 67 additions and 10 deletions.
54 changes: 54 additions & 0 deletions src/bentoml/_internal/cloud/base.py
@@ -5,6 +5,7 @@
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager
from tempfile import SpooledTemporaryFile

from rich.console import Group
from rich.panel import Panel
@@ -24,6 +25,59 @@
from ..tag import Tag

FILE_CHUNK_SIZE = 100 * 1024 * 1024  # 100 MB
SPOOLED_FILE_MAX_SIZE = 5 * 1024 * 1024 * 1024  # 5 GB


class CallbackSpooledTemporaryFileIO(SpooledTemporaryFile):
    """
    A SpooledTemporaryFile wrapper that invokes a callback whenever
    read() or write() is called, so transfer progress can be reported.
    """

    read_cb: t.Callable[[int], None] | None
    write_cb: t.Callable[[int], None] | None

    def __init__(
        self,
        max_size: int = 0,
        *,
        read_cb: t.Callable[[int], None] | None = None,
        write_cb: t.Callable[[int], None] | None = None,
    ):
        self.read_cb = read_cb
        self.write_cb = write_cb
        super().__init__(max_size)

    def read(self, *args):
        res = super().read(*args)
        if self.read_cb is not None:
            self.read_cb(len(res))
        return res

    def write(self, s):  # type: ignore  # Python's buffer types are too new to express something like Buffer + Sized as of now
        res = super().write(s)
        if self.write_cb is not None:
            if hasattr(s, "__len__"):
                self.write_cb(len(s))
        return res

    def size(self) -> int:
        """
        Get the total size of the file without disturbing the current position.
        """
        current_pos = self.tell()
        self.seek(0, 2)  # seek to the end of the file (os.SEEK_END)
        file_size = self.tell()
        self.seek(current_pos)  # restore the original position
        return file_size

    def chunk(self, start: int, end: int) -> bytes:
        """
        Return the byte slice [start, end); a negative end reads to EOF.
        """
        self.seek(start)
        if end < 0 or start > end:
            return self.read()
        return self.read(end - start)


class CallbackIOWrapper(io.BytesIO):
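
The fix works because SpooledTemporaryFile buffers data in memory only up to max_size and then transparently rolls over to a temporary file on disk, so the tar archive no longer has to fit entirely in RAM the way the old io.BytesIO-based wrapper required. A minimal sketch of that behaviour (illustrative, not part of this commit; _rolled is a private CPython attribute used here only for demonstration):

from tempfile import SpooledTemporaryFile

with SpooledTemporaryFile(max_size=1024) as f:
    f.write(b"x" * 512)
    print(f._rolled)  # False: data still lives in the in-memory buffer
    f.write(b"x" * 1024)
    print(f._rolled)  # True: data has spilled to a real file on disk

With spooling, resident memory for the archive is capped at SPOOLED_FILE_MAX_SIZE (5 GB here) instead of growing with the bento size.
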
23 changes: 13 additions & 10 deletions src/bentoml/_internal/cloud/bentocloud.py
@@ -26,7 +26,9 @@
from ..tag import Tag
from ..utils import calc_dir_size
from .base import FILE_CHUNK_SIZE
from .base import SPOOLED_FILE_MAX_SIZE
from .base import CallbackIOWrapper
from .base import CallbackSpooledTemporaryFileIO
from .base import CloudClient
from .config import get_rest_api_client
from .deployment import Deployment
@@ -667,7 +669,9 @@ def io_cb(x: int):
            with io_mutex:
                self.transmission_progress.update(upload_task_id, advance=x)

        with CallbackSpooledTemporaryFileIO(
            SPOOLED_FILE_MAX_SIZE, read_cb=io_cb
        ) as tar_io:
            with self.spin(text=f'Creating tar archive for model "{model.tag}"..'):
                with tarfile.open(fileobj=tar_io, mode="w:") as tar:
                    tar.add(model.path, arcname="./")
@@ -676,7 +680,7 @@ def io_cb(x: int):
        yatai_rest_client.start_upload_model(
            model_repository_name=model_repository.name, version=version
        )
        file_size = tar_io.size()
        self.transmission_progress.update(
            upload_task_id,
            description=f'Uploading model "{model.tag}"',
@@ -751,15 +755,14 @@ def chunk_upload(
                    text=f'({chunk_number}/{chunks_count}) Uploading chunk of model "{model.tag}"...'
                ):
                    chunk = (
                        tar_io.chunk(
                            (chunk_number - 1) * FILE_CHUNK_SIZE,
                            chunk_number * FILE_CHUNK_SIZE,
                        )
                        if chunk_number < chunks_count
                        else tar_io.chunk(
                            (chunk_number - 1) * FILE_CHUNK_SIZE, -1
                        )
                    )

                    with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
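
Putting the two pieces together, the upload path now streams the archive in FILE_CHUNK_SIZE slices via chunk() instead of materialising the whole buffer with getbuffer(). A condensed sketch of the flow (illustrative, not part of this commit; the model directory path and the progress callback are stand-ins for the real BentoCloud plumbing, and chunks_count is computed here for completeness):

import math
import tarfile

uploaded = {"bytes": 0}

def io_cb(x: int) -> None:
    uploaded["bytes"] += x  # drives the transmission progress bar

with CallbackSpooledTemporaryFileIO(SPOOLED_FILE_MAX_SIZE, read_cb=io_cb) as tar_io:
    with tarfile.open(fileobj=tar_io, mode="w:") as tar:
        tar.add("path/to/model", arcname="./")  # hypothetical model directory

    file_size = tar_io.size()
    chunks_count = math.ceil(file_size / FILE_CHUNK_SIZE)
    for chunk_number in range(1, chunks_count + 1):
        # the final chunk passes end=-1, which chunk() treats as "read to EOF"
        chunk = tar_io.chunk(
            (chunk_number - 1) * FILE_CHUNK_SIZE,
            chunk_number * FILE_CHUNK_SIZE if chunk_number < chunks_count else -1,
        )
        # each `chunk` is then wrapped in CallbackIOWrapper and uploaded as
        # one part of the multipart transfer
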
