Skip to content

Commit

Permalink
fix: incorrect URL joining. (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
zoido committed Feb 8, 2023
1 parent 86ecbf9 commit 27cd79a
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 78 deletions.
2 changes: 0 additions & 2 deletions src/h2o_discovery/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ class AsyncClient:
"""

def __init__(self, uri: str):
self._uri = uri

self._environment_uri = client.get_environment_uri(uri)
self._services_uri = client.get_services_uri(uri)
self._clients_uri = client.get_clients_uri(uri)
Expand Down
14 changes: 6 additions & 8 deletions src/h2o_discovery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@
from h2o_discovery import model


ENVIRONMENT_ENDPOINT = "/v1/environment"
SERVICES_ENDPOINT = "/v1/services"
CLIENTS_ENDPOINT = "/v1/clients"
ENVIRONMENT_ENDPOINT = "v1/environment"
SERVICES_ENDPOINT = "v1/services"
CLIENTS_ENDPOINT = "v1/clients"


def get_environment_uri(uri: str) -> str:
return urllib.parse.urljoin(uri, ENVIRONMENT_ENDPOINT)
return urllib.parse.urljoin(uri + "/", ENVIRONMENT_ENDPOINT)


def get_services_uri(uri: str) -> str:
return urllib.parse.urljoin(uri, SERVICES_ENDPOINT)
return urllib.parse.urljoin(uri + "/", SERVICES_ENDPOINT)


def get_clients_uri(uri: str) -> str:
return urllib.parse.urljoin(uri, CLIENTS_ENDPOINT)
return urllib.parse.urljoin(uri + "/", CLIENTS_ENDPOINT)


class Client:
Expand All @@ -31,8 +31,6 @@ class Client:
"""

def __init__(self, uri: str):
self._uri = uri

self._environment_uri = get_environment_uri(uri)
self._services_uri = get_services_uri(uri)
self._clients_uri = get_clients_uri(uri)
Expand Down
8 changes: 4 additions & 4 deletions src/h2o_discovery/lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def determine_uri(
raise ValueError("cannot specify both discovery and environment")

if discovery_address is not None:
return discovery_address
return discovery_address.rstrip("/")

if environment is not None:
return _discovery_uri_from_environment(environment)

discovery_address = os.environ.get("H2O_CLOUD_DISCOVERY")
if discovery_address is not None:
return discovery_address
return discovery_address.rstrip("/")

environment = os.environ.get("H2O_CLOUD_ENVIRONMENT")
if environment is not None:
Expand All @@ -35,5 +35,5 @@ def determine_uri(
)


def _discovery_uri_from_environment(environment):
return urllib.parse.urljoin(environment, _WELL_KNOWN_PATH)
def _discovery_uri_from_environment(environment: str):
return urllib.parse.urljoin(environment + "/", _WELL_KNOWN_PATH)
164 changes: 130 additions & 34 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,28 @@


@respx.mock
def test_client_get_environment():
def test_client_get_environment_internal():
# Given
route = respx.get("https://test.example.com/v1/environment").respond(
route = respx.get("http://test.example.com:1234/v1/environment").respond(
json=ENVIRONMENT_JSON
)
cl = client.Client("https://test.example.com")
cl = client.Client("http://test.example.com:1234")

# When
env = cl.get_environment()

# Then
assert route.called
assert env == EXPECTED_ENVIRONMENT_DATA


@respx.mock
def test_client_get_environment_public():
# Given
route = respx.get(
"https://test.example.com/.ai.h2o.cloud.discovery/v1/environment"
).respond(json=ENVIRONMENT_JSON)
cl = client.Client("https://test.example.com/.ai.h2o.cloud.discovery")

# When
env = cl.get_environment()
Expand All @@ -25,12 +41,12 @@ def test_client_get_environment():

@respx.mock
@pytest.mark.asyncio
async def test_async_client_get_environment():
async def test_async_client_get_environment_internal():
# Given
route = respx.get("https://test.example.com/v1/environment").respond(
route = respx.get("http://test.example.com:1234/v1/environment").respond(
json=ENVIRONMENT_JSON
)
cl = async_client.AsyncClient("https://test.example.com")
cl = async_client.AsyncClient("http://test.example.com:1234")

# When
env = await cl.get_environment()
Expand All @@ -41,83 +57,156 @@ async def test_async_client_get_environment():


@respx.mock
def test_client_list_services():
@pytest.mark.asyncio
async def test_async_client_get_environment_public():
# Given
route = respx.get("https://test.example.com/v1/services")
route = respx.get(
"https://test.example.com/.ai.h2o.cloud.discovery/v1/environment"
).respond(json=ENVIRONMENT_JSON)
cl = async_client.AsyncClient("https://test.example.com/.ai.h2o.cloud.discovery")

# When
env = await cl.get_environment()

# Then
assert route.called
assert env == EXPECTED_ENVIRONMENT_DATA


@respx.mock
def test_client_list_services_internal():
# Given
route = respx.get("http://test.example.com:1234/v1/services")
route.side_effect = SERVICES_RESPONSES

cl = client.Client("https://test.example.com")
cl = client.Client("http://test.example.com:1234")

# When
services = cl.list_services()

# Then
assert services == EXPECTED_SERVICES_RECORDS
assert route.call_count == 3
assert not route.calls[0].request.url.query
assert route.calls[1].request.url.query == b"pageToken=next-page-token-1"
assert route.calls[2].request.url.query == b"pageToken=next-page-token-2"
_assert_pagination_api_calls(route)


@respx.mock
def test_client_list_services_public():
# Given
route = respx.get("https://test.example.com/.ai.h2o.cloud.discovery/v1/services")
route.side_effect = SERVICES_RESPONSES

cl = client.Client("https://test.example.com/.ai.h2o.cloud.discovery")

# When
services = cl.list_services()

# Then
assert services == EXPECTED_SERVICES_RECORDS
_assert_pagination_api_calls(route)


@respx.mock
@pytest.mark.asyncio
async def test_async_client_list_services():
async def test_async_client_list_services_internal():
# Given
route = respx.get("https://test.example.com/v1/services")
route = respx.get("http://test.example.com:1234/v1/services")
route.side_effect = SERVICES_RESPONSES

cl = async_client.AsyncClient("https://test.example.com")
cl = async_client.AsyncClient("http://test.example.com:1234")

# When
services = await cl.list_services()

# Then
assert services == EXPECTED_SERVICES_RECORDS
assert route.call_count == 3
assert not route.calls[0].request.url.query
assert route.calls[1].request.url.query == b"pageToken=next-page-token-1"
assert route.calls[2].request.url.query == b"pageToken=next-page-token-2"
_assert_pagination_api_calls(route)


@respx.mock
def test_client_list_clients():
@pytest.mark.asyncio
async def test_async_client_list_services_pubclic():
# Given
route = respx.get("https://test.example.com/v1/clients")
route = respx.get("https://test.example.com/.ai.h2o.cloud.discovery/v1/services")
route.side_effect = SERVICES_RESPONSES

cl = async_client.AsyncClient("https://test.example.com/.ai.h2o.cloud.discovery")

# When
services = await cl.list_services()

# Then
assert services == EXPECTED_SERVICES_RECORDS
_assert_pagination_api_calls(route)


@respx.mock
def test_client_list_clients_internal():
# Given
route = respx.get("http://test.example.com:1234/v1/clients")
route.side_effect = CLIENTS_RESPONSES

cl = client.Client("https://test.example.com")
cl = client.Client("http://test.example.com:1234")

# When
clients = cl.list_clients()

# Then

assert clients == EXPECTED_CLIENTS_RECORDS
assert route.call_count == 3
assert not route.calls[0].request.url.query
assert route.calls[1].request.url.query == b"pageToken=next-page-token-1"
assert route.calls[2].request.url.query == b"pageToken=next-page-token-2"
_assert_pagination_api_calls(route)


@respx.mock
def test_client_list_clients_public():
# Given
route = respx.get("https://test.example.com/.ai.h2o.cloud.discovery/v1/clients")
route.side_effect = CLIENTS_RESPONSES

cl = client.Client("https://test.example.com/.ai.h2o.cloud.discovery")

# When
clients = cl.list_clients()

# Then

assert clients == EXPECTED_CLIENTS_RECORDS
_assert_pagination_api_calls(route)


@respx.mock
@pytest.mark.asyncio
async def test_async_client_list_clients():
async def test_async_client_list_clients_internal():
# Given
route = respx.get("https://test.example.com/v1/clients")
route = respx.get("https://test.example.com:1234/v1/clients")
route.side_effect = CLIENTS_RESPONSES

cl = async_client.AsyncClient("https://test.example.com")
cl = async_client.AsyncClient("https://test.example.com:1234")

# When
clients = await cl.list_clients()

# Then

assert clients == EXPECTED_CLIENTS_RECORDS
assert route.call_count == 3
assert not route.calls[0].request.url.query
assert route.calls[1].request.url.query == b"pageToken=next-page-token-1"
assert route.calls[2].request.url.query == b"pageToken=next-page-token-2"
_assert_pagination_api_calls(route)


@respx.mock
@pytest.mark.asyncio
async def test_async_client_list_clients_public():
# Given
route = respx.get("https://test.example.com/.ai.h2o.cloud.discovery/v1/clients")
route.side_effect = CLIENTS_RESPONSES

cl = async_client.AsyncClient("https://test.example.com/.ai.h2o.cloud.discovery")

# When
clients = await cl.list_clients()

# Then

assert clients == EXPECTED_CLIENTS_RECORDS
_assert_pagination_api_calls(route)


@respx.mock
Expand Down Expand Up @@ -174,6 +263,13 @@ async def test_async_client_list_clients_can_handle_empty_response():
assert services == []


def _assert_pagination_api_calls(route):
assert route.call_count == 3
assert not route.calls[0].request.url.query
assert route.calls[1].request.url.query == b"pageToken=next-page-token-1"
assert route.calls[2].request.url.query == b"pageToken=next-page-token-2"


ENVIRONMENT_JSON = {
"environment": {
"h2oCloudEnvironment": "https://cloud.fbi.com",
Expand Down

0 comments on commit 27cd79a

Please sign in to comment.