Skip to content

Commit

Permalink
use httpx in cloud clients
Browse files Browse the repository at this point in the history
Co-Authored-By: Judah Rand <17158624+judahrand@users.noreply.github.com>
  • Loading branch information
2 people authored and Haivilo committed Dec 25, 2023
1 parent 5e57543 commit 752c1c2
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 112 deletions.
87 changes: 46 additions & 41 deletions src/bentoml/_internal/cloud/bentocloud.py
Expand Up @@ -10,7 +10,7 @@
from tempfile import NamedTemporaryFile

import fs
import requests
import httpx
from rich.live import Live
from simple_di import Provide
from simple_di import inject
Expand Down Expand Up @@ -264,7 +264,7 @@ def filter_(
)
try:
if presigned_upload_url is not None:
resp = requests.put(presigned_upload_url, data=tar_io)
resp = httpx.put(presigned_upload_url, content=tar_io)
if resp.status_code != 200:
finish_req = FinishUploadBentoSchema(
status=BentoUploadStatus.FAILED,
Expand Down Expand Up @@ -321,8 +321,8 @@ def chunk_upload(
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
resp = requests.put(
remote_bento.presigned_upload_url, data=chunk_io
resp = httpx.put(
remote_bento.presigned_upload_url, content=chunk_io
)
if resp.status_code != 200:
return FinishUploadBentoSchema(
Expand Down Expand Up @@ -510,27 +510,28 @@ def pull_model(model_tag: Tag):
name, version
)
presigned_download_url = remote_bento.presigned_download_url
response = requests.get(presigned_download_url, stream=True)

if response.status_code != 200:
raise BentoMLException(
f'Failed to download bento "{_tag}": {response.text}'
)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
with NamedTemporaryFile() as tar_file:
self.transmission_progress.update(
download_task_id,
completed=0,
total=total_size_in_bytes,
visible=True,
)
self.transmission_progress.start_task(download_task_id)
for data in response.iter_content(block_size):
with httpx.stream("GET", presigned_download_url) as response:
if response.status_code != 200:
raise BentoMLException(
f'Failed to download bento "{_tag}": {response.text}'
)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
self.transmission_progress.update(
download_task_id, advance=len(data)
download_task_id,
completed=0,
total=total_size_in_bytes,
visible=True,
)
tar_file.write(data)
self.transmission_progress.start_task(download_task_id)
for data in response.iter_bytes(block_size):
self.transmission_progress.update(
download_task_id, advance=len(data)
)
tar_file.write(data)

self.log_progress.add_task(
f'[bold green]Finished downloading all bento "{_tag}" files'
)
Expand Down Expand Up @@ -707,7 +708,7 @@ def io_cb(x: int):
)
try:
if presigned_upload_url is not None:
resp = requests.put(presigned_upload_url, data=tar_io)
resp = httpx.put(presigned_upload_url, content=tar_io)
if resp.status_code != 200:
finish_req = FinishUploadModelSchema(
status=ModelUploadStatus.FAILED,
Expand Down Expand Up @@ -765,8 +766,8 @@ def chunk_upload(
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
resp = requests.put(
remote_model.presigned_upload_url, data=chunk_io
resp = httpx.put(
remote_model.presigned_upload_url, content=chunk_io
)
if resp.status_code != 200:
return FinishUploadModelSchema(
Expand Down Expand Up @@ -835,6 +836,7 @@ def chunk_upload(
version=version,
req=finish_req,
)

if finish_req.status != ModelUploadStatus.SUCCESS:
self.log_progress.add_task(
f'[bold red]Failed pushing model "{model.tag}" : {finish_req.reason}'
Expand Down Expand Up @@ -958,25 +960,28 @@ def _do_pull_model(
)
presigned_download_url = remote_model.presigned_download_url

response = requests.get(presigned_download_url, stream=True)
if response.status_code != 200:
raise BentoMLException(
f'Failed to download model "{_tag}": {response.text}'
with NamedTemporaryFile() as tar_file:
with httpx.stream("GET", presigned_download_url) as response:
if response.status_code != 200:
raise BentoMLException(
f'Failed to download model "{_tag}": {response.text}'
)

total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
self.transmission_progress.update(
download_task_id,
description=f'Downloading model "{_tag}"',
total=total_size_in_bytes,
visible=True,
)
self.transmission_progress.start_task(download_task_id)
for data in response.iter_bytes(block_size):
self.transmission_progress.update(
download_task_id, advance=len(data)
)
tar_file.write(data)

total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
with NamedTemporaryFile() as tar_file:
self.transmission_progress.update(
download_task_id,
description=f'Downloading model "{_tag}"',
total=total_size_in_bytes,
visible=True,
)
self.transmission_progress.start_task(download_task_id)
for data in response.iter_content(block_size):
self.transmission_progress.update(download_task_id, advance=len(data))
tar_file.write(data)
self.log_progress.add_task(
f'[bold green]Finished downloading model "{_tag}" files'
)
Expand Down
60 changes: 30 additions & 30 deletions src/bentoml/_internal/cloud/client.py
Expand Up @@ -3,8 +3,8 @@
import logging
import typing as t
from urllib.parse import urljoin
import httpx

import requests

from ...exceptions import CloudRESTApiClientError
from ..configuration import BENTOML_VERSION
Expand Down Expand Up @@ -43,19 +43,19 @@


class BaseRestApiClient:
def __init__(self, endpoint: str, session: requests.Session) -> None:
def __init__(self, endpoint: str, session: httpx.Client) -> None:
self.endpoint = endpoint
self.session = session

def _is_not_found(self, resp: requests.Response) -> bool:
def _is_not_found(self, resp: httpx.Response) -> bool:
# We used to return 400 for record not found, handle both cases
return (
resp.status_code == 404
or resp.status_code == 400
and "record not found" in resp.text
)

def _check_resp(self, resp: requests.Response) -> None:
def _check_resp(self, resp: httpx.Response) -> None:
if resp.status_code != 200:
raise CloudRESTApiClientError(
f"request failed with status code {resp.status_code}: {resp.text}"
Expand Down Expand Up @@ -95,7 +95,7 @@ def create_bento_repository(
self, req: CreateBentoRepositorySchema
) -> BentoRepositorySchema:
url = urljoin(self.endpoint, "/api/v1/bento_repositories")
resp = self.session.post(url, data=schema_to_json(req))
resp = self.session.post(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoRepositorySchema)

Expand All @@ -116,7 +116,7 @@ def create_bento(
url = urljoin(
self.endpoint, f"/api/v1/bento_repositories/{bento_repository_name}/bentos"
)
resp = self.session.post(url, data=schema_to_json(req))
resp = self.session.post(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand All @@ -127,7 +127,7 @@ def update_bento(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand Down Expand Up @@ -174,7 +174,7 @@ def presign_bento_multipart_upload_url(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}/presign_multipart_upload_url",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand All @@ -188,7 +188,7 @@ def complete_bento_multipart_upload(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}/complete_multipart_upload",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand All @@ -210,7 +210,7 @@ def finish_upload_bento(
self.endpoint,
f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}/finish_upload",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, BentoSchema)

Expand All @@ -223,7 +223,7 @@ def upload_bento(
)
resp = self.session.put(
url,
data=data,
content=data,
headers=dict(
self.session.headers, **{"Content-Type": "application/octet-stream"}
),
Expand All @@ -233,14 +233,14 @@ def upload_bento(

def download_bento(
    self, bento_repository_name: str, version: str
) -> "httpx.Response":
    """Start a streaming download of a bento and return the open response.

    The response is returned un-read and still open so the caller can
    iterate over the body; the caller is responsible for calling
    ``close()`` on it once finished.

    Args:
        bento_repository_name: Name of the bento repository.
        version: Version of the bento inside that repository.

    Returns:
        The open streaming response for the download request.

    Raises:
        CloudRESTApiClientError: Raised by ``_check_resp`` when the status
            code is not 200.
    """
    url = urljoin(
        self.endpoint,
        f"/api/v1/bento_repositories/{bento_repository_name}/bentos/{version}/download",
    )
    # Do not use ``with self.session.stream(...)`` here: leaving the ``with``
    # block closes the response, so the caller would receive a closed stream.
    request = self.session.build_request("GET", url)
    resp = self.session.send(request, stream=True)
    try:
        if resp.status_code != 200:
            # ``_check_resp`` formats ``resp.text`` into its error message,
            # which is only available once the streamed body has been read.
            resp.read()
        self._check_resp(resp)
    except BaseException:
        # Release the connection before propagating any failure.
        resp.close()
        raise
    return resp

def get_model_repository(
self, model_repository_name: str
Expand All @@ -258,7 +258,7 @@ def create_model_repository(
self, req: CreateModelRepositorySchema
) -> ModelRepositorySchema:
url = urljoin(self.endpoint, "/api/v1/model_repositories")
resp = self.session.post(url, data=schema_to_json(req))
resp = self.session.post(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelRepositorySchema)

Expand All @@ -279,7 +279,7 @@ def create_model(
url = urljoin(
self.endpoint, f"/api/v1/model_repositories/{model_repository_name}/models"
)
resp = self.session.post(url, data=schema_to_json(req))
resp = self.session.post(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelSchema)

Expand Down Expand Up @@ -326,7 +326,7 @@ def presign_model_multipart_upload_url(
self.endpoint,
f"/api/v1/model_repositories/{model_repository_name}/models/{version}/presign_multipart_upload_url",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelSchema)

Expand All @@ -340,7 +340,7 @@ def complete_model_multipart_upload(
self.endpoint,
f"/api/v1/model_repositories/{model_repository_name}/models/{version}/complete_multipart_upload",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelSchema)

Expand All @@ -362,7 +362,7 @@ def finish_upload_model(
self.endpoint,
f"/api/v1/model_repositories/{model_repository_name}/models/{version}/finish_upload",
)
resp = self.session.patch(url, data=schema_to_json(req))
resp = self.session.patch(url, content=schema_to_json(req))
self._check_resp(resp)
return schema_from_json(resp.text, ModelSchema)

Expand All @@ -375,7 +375,7 @@ def upload_model(
)
resp = self.session.put(
url,
data=data,
content=data,
headers=dict(
self.session.headers, **{"Content-Type": "application/octet-stream"}
),
Expand All @@ -385,14 +385,14 @@ def upload_model(

def download_model(
    self, model_repository_name: str, version: str
) -> "httpx.Response":
    """Start a streaming download of a model and return the open response.

    The response is returned un-read and still open so the caller can
    iterate over the body; the caller is responsible for calling
    ``close()`` on it once finished.

    Args:
        model_repository_name: Name of the model repository.
        version: Version of the model inside that repository.

    Returns:
        The open streaming response for the download request.

    Raises:
        CloudRESTApiClientError: Raised by ``_check_resp`` when the status
            code is not 200.
    """
    url = urljoin(
        self.endpoint,
        f"/api/v1/model_repositories/{model_repository_name}/models/{version}/download",
    )
    # Do not use ``with self.session.stream(...)`` here: leaving the ``with``
    # block closes the response, so the caller would receive a closed stream.
    request = self.session.build_request("GET", url)
    resp = self.session.send(request, stream=True)
    try:
        if resp.status_code != 200:
            # ``_check_resp`` formats ``resp.text`` into its error message,
            # which is only available once the streamed body has been read.
            resp.read()
        self._check_resp(resp)
    except BaseException:
        # Release the connection before propagating any failure.
        resp.close()
        raise
    return resp

def get_bento_repositories_list(
self, bento_repository_name: str
Expand Down Expand Up @@ -444,7 +444,7 @@ def create_deployment(
self, cluster_name: str, create_schema: CreateDeploymentSchemaV1
) -> DeploymentFullSchema | None:
url = urljoin(self.endpoint, f"/api/v1/clusters/{cluster_name}/deployments")
resp = self.session.post(url, data=schema_to_json(create_schema))
resp = self.session.post(url, content=schema_to_json(create_schema))
self._check_resp(resp)
return schema_from_json(resp.text, DeploymentFullSchema)

Expand Down Expand Up @@ -472,7 +472,7 @@ def update_deployment(
self.endpoint,
f"/api/v1/clusters/{cluster_name}/namespaces/{kube_namespace}/deployments/{deployment_name}",
)
resp = self.session.patch(url, data=schema_to_json(update_schema))
resp = self.session.patch(url, content=schema_to_json(update_schema))
if self._is_not_found(resp):
return None
self._check_resp(resp)
Expand Down Expand Up @@ -544,7 +544,7 @@ def create_deployment(
) -> DeploymentFullSchemaV2:
url = urljoin(self.endpoint, "/api/v2/deployments")
resp = self.session.post(
url, data=schema_to_json(create_schema), params={"cluster": cluster_name}
url, content=schema_to_json(create_schema), params={"cluster": cluster_name}
)
self._check_resp(resp)
return schema_from_json(resp.text, DeploymentFullSchemaV2)
Expand All @@ -560,7 +560,7 @@ def update_deployment(
f"/api/v2/deployments/{deployment_name}",
)
data = schema_to_json(update_schema)
resp = self.session.put(url, data=data, params={"cluster": cluster_name})
resp = self.session.put(url, content=data, params={"cluster": cluster_name})
if self._is_not_found(resp):
return None
self._check_resp(resp)
Expand Down Expand Up @@ -632,7 +632,7 @@ def delete_deployment(

class RestApiClient:
def __init__(self, endpoint: str, api_token: str) -> None:
self.session = requests.Session()
self.session = httpx.Client()
self.session.headers.update(
{
"X-YATAI-API-TOKEN": api_token,
Expand Down

0 comments on commit 752c1c2

Please sign in to comment.