Skip to content

Commit

Permalink
use httpx in cloud clients
Browse files Browse the repository at this point in the history
  • Loading branch information
sauyon committed Oct 10, 2023
1 parent d645d67 commit f6444f0
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 109 deletions.
87 changes: 46 additions & 41 deletions src/bentoml/_internal/cloud/bentocloud.py
Expand Up @@ -10,7 +10,7 @@
from tempfile import NamedTemporaryFile

import fs
import requests
import httpx
from rich.live import Live
from simple_di import Provide
from simple_di import inject
Expand Down Expand Up @@ -266,7 +266,7 @@ def filter_(
)
try:
if presigned_upload_url is not None:
resp = requests.put(presigned_upload_url, data=tar_io)
resp = httpx.put(presigned_upload_url, content=tar_io)
if resp.status_code != 200:
finish_req = FinishUploadBentoSchema(
status=BentoUploadStatus.FAILED,
Expand Down Expand Up @@ -321,8 +321,8 @@ def chunk_upload(
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
resp = requests.put(
remote_bento.presigned_upload_url, data=chunk_io
resp = httpx.put(
remote_bento.presigned_upload_url, content=chunk_io
)
if resp.status_code != 200:
return FinishUploadBentoSchema(
Expand Down Expand Up @@ -510,27 +510,28 @@ def pull_model(model_tag: Tag):
name, version
)
presigned_download_url = remote_bento.presigned_download_url
response = requests.get(presigned_download_url, stream=True)

if response.status_code != 200:
raise BentoMLException(
f'Failed to download bento "{_tag}": {response.text}'
)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
with NamedTemporaryFile() as tar_file:
self.transmission_progress.update(
download_task_id,
completed=0,
total=total_size_in_bytes,
visible=True,
)
self.transmission_progress.start_task(download_task_id)
for data in response.iter_content(block_size):
with httpx.stream("GET", presigned_download_url) as response:
if response.status_code != 200:
raise BentoMLException(
f'Failed to download bento "{_tag}": {response.text}'
)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
self.transmission_progress.update(
download_task_id, advance=len(data)
download_task_id,
completed=0,
total=total_size_in_bytes,
visible=True,
)
tar_file.write(data)
self.transmission_progress.start_task(download_task_id)
for data in response.iter_bytes(block_size):
self.transmission_progress.update(
download_task_id, advance=len(data)
)
tar_file.write(data)

self.log_progress.add_task(
f'[bold green]Finished downloading all bento "{_tag}" files'
)
Expand Down Expand Up @@ -707,7 +708,7 @@ def io_cb(x: int):
)
try:
if presigned_upload_url is not None:
resp = requests.put(presigned_upload_url, data=tar_io)
resp = httpx.put(presigned_upload_url, content=tar_io)
if resp.status_code != 200:
finish_req = FinishUploadModelSchema(
status=ModelUploadStatus.FAILED,
Expand Down Expand Up @@ -763,8 +764,8 @@ def chunk_upload(
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
resp = requests.put(
remote_model.presigned_upload_url, data=chunk_io
resp = httpx.put(
remote_model.presigned_upload_url, content=chunk_io
)
if resp.status_code != 200:
return FinishUploadModelSchema(
Expand Down Expand Up @@ -833,6 +834,7 @@ def chunk_upload(
version=version,
req=finish_req,
)

if finish_req.status != ModelUploadStatus.SUCCESS:
self.log_progress.add_task(
f'[bold red]Failed pushing model "{model.tag}" : {finish_req.reason}'
Expand Down Expand Up @@ -954,25 +956,28 @@ def _do_pull_model(
)
presigned_download_url = remote_model.presigned_download_url

response = requests.get(presigned_download_url, stream=True)
if response.status_code != 200:
raise BentoMLException(
f'Failed to download model "{_tag}": {response.text}'
with NamedTemporaryFile() as tar_file:
with httpx.stream("GET", presigned_download_url) as response:
if response.status_code != 200:
raise BentoMLException(
f'Failed to download model "{_tag}": {response.text}'
)

total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
self.transmission_progress.update(
download_task_id,
description=f'Downloading model "{_tag}"',
total=total_size_in_bytes,
visible=True,
)
self.transmission_progress.start_task(download_task_id)
for data in response.iter_bytes(block_size):
self.transmission_progress.update(
download_task_id, advance=len(data)
)
tar_file.write(data)

total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
with NamedTemporaryFile() as tar_file:
self.transmission_progress.update(
download_task_id,
description=f'Downloading model "{_tag}"',
total=total_size_in_bytes,
visible=True,
)
self.transmission_progress.start_task(download_task_id)
for data in response.iter_content(block_size):
self.transmission_progress.update(download_task_id, advance=len(data))
tar_file.write(data)
self.log_progress.add_task(
f'[bold green]Finished downloading model "{_tag}" files'
)
Expand Down
54 changes: 27 additions & 27 deletions src/bentoml/_internal/cloud/client.py
Expand Up @@ -4,7 +4,7 @@
import typing as t
from urllib.parse import urljoin

import requests
import httpx

from ...exceptions import CloudRESTApiClientError
from ..configuration import BENTOML_VERSION
Expand Down Expand Up @@ -41,7 +41,7 @@
class RestApiClient:
def __init__(self, endpoint: str, api_token: str) -> None:
self.endpoint = endpoint
self.session = requests.Session()
self.session = httpx.Client()
self.session.headers.update(
{
"X-YATAI-API-TOKEN": api_token,
Expand All @@ -50,15 +50,15 @@ def __init__(self, endpoint: str, api_token: str) -> None:
}
)

def _is_not_found(self, resp: requests.Response) -> bool:
def _is_not_found(self, resp: httpx.Response) -> bool:
# We used to return 400 for record not found, handle both cases
return (
resp.status_code == 404
or resp.status_code == 400
and "record not found" in resp.text
)

def _check_resp(self, resp: requests.Response) -> None:
def _check_resp(self, resp: httpx.Response) -> None:
if resp.status_code != 200:
raise CloudRESTApiClientError(
f"request failed with status code {resp.status_code}: {resp.text}"
Expand Down Expand Up @@ -96,7 +96,7 @@ def create_bento_repository(
self, req: CreateBentoRepositorySchema
) -> BentoRepositorySchema:
url = urljoin(self.endpoint, "/api/v1/bento_repositories")
resp = self.session.post(url, data=schema_to_json(req))
resp = self.session.post(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoRepositorySchema)

Expand All @@ -117,7 +117,7 @@ def create_bento(
url = urljoin(
self.endpoint, f"/api/v1/bento_repositories/{bento_repository_name}/bentos"
)
resp = self.session.post(url, data=schema_to_json(req))
resp = self.session.post(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand All @@ -128,7 +128,7 @@ def update_bento(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand Down Expand Up @@ -175,7 +175,7 @@ def presign_bento_multipart_upload_url(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}/presign_multipart_upload_url",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand All @@ -189,7 +189,7 @@ def complete_bento_multipart_upload(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}/complete_multipart_upload",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand All @@ -211,7 +211,7 @@ def finish_upload_bento(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}/finish_upload",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand All @@ -224,7 +224,7 @@ def upload_bento(
)
resp = self.session.put(
url,
data=data,
content=data,
headers=dict(
self.session.headers, **{"Content-Type": "application/octet-stream"}
),
Expand All @@ -234,14 +234,14 @@ def upload_bento(

def download_bento(
self, bento_repository_name: str, version: str
) -> requests.Response:
) -> httpx.Response:
url = urljoin(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}/download",
)
resp = self.session.get(url, stream=True)
self._check_resp(resp)
return resp
with self.session.stream("GET", url) as resp:
self._check_resp(resp)
return resp

def get_model_repository(
self, model_repository_name: str
Expand All @@ -259,7 +259,7 @@ def create_model_repository(
self, req: CreateModelRepositorySchema
) -> ModelRepositorySchema:
url = urljoin(self.endpoint, "/api/v1/model_repositories")
resp = self.session.post(url, data=schema_to_json(req))
resp = self.session.post(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelRepositorySchema)

Expand All @@ -280,7 +280,7 @@ def create_model(
url = urljoin(
self.endpoint, f"/api/v1/model_repositories/{model_repository_name}/models"
)
resp = self.session.post(url, data=schema_to_json(req))
resp = self.session.post(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelSchema)

Expand Down Expand Up @@ -327,7 +327,7 @@ def presign_model_multipart_upload_url(
self.endpoint,
f"/api/v1/model_repositories/{model_repository_name}/models/{version}/presign_multipart_upload_url",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelSchema)

Expand All @@ -341,7 +341,7 @@ def complete_model_multipart_upload(
self.endpoint,
f"/api/v1/model_repositories/{model_repository_name}/models/{version}/complete_multipart_upload",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelSchema)

Expand All @@ -363,7 +363,7 @@ def finish_upload_model(
self.endpoint,
f"/api/v1/model_repositories/{model_repository_name}/models/{version}/finish_upload",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelSchema)

Expand All @@ -376,7 +376,7 @@ def upload_model(
)
resp = self.session.put(
url,
data=data,
content=data,
headers=dict(
self.session.headers, **{"Content-Type": "application/octet-stream"}
),
Expand All @@ -386,14 +386,14 @@ def upload_model(

def download_model(
self, model_repository_name: str, version: str
) -> requests.Response:
) -> httpx.Response:
url = urljoin(
self.endpoint,
f"/api/v1/model_repositories/{model_repository_name}/models/{version}/download",
)
resp = self.session.get(url, stream=True)
self._check_resp(resp)
return resp
with self.session.stream("GET", url) as resp:
self._check_resp(resp)
return resp

def get_bento_repositories_list(
self, bento_repository_name: str
Expand Down Expand Up @@ -435,7 +435,7 @@ def create_deployment(
self, cluster_name: str, create_schema: CreateDeploymentSchema
) -> DeploymentSchema | None:
url = urljoin(self.endpoint, f"/api/v1/clusters/{cluster_name}/deployments")
resp = self.session.post(url, data=schema_to_json(create_schema))
resp = self.session.post(url, content=schema_to_json(create_schema))
self._check_resp(resp)
return schema_from_json(resp.text, DeploymentSchema)

Expand Down Expand Up @@ -463,7 +463,7 @@ def update_deployment(
self.endpoint,
f"/api/v1/clusters/{cluster_name}/namespaces/{kube_namespace}/deployments/{deployment_name}",
)
resp = self.session.patch(url, data=schema_to_json(update_schema))
resp = self.session.patch(url, content=schema_to_json(update_schema))
if self._is_not_found(resp):
return None
self._check_resp(resp)
Expand Down

0 comments on commit f6444f0

Please sign in to comment.