Skip to content

Commit

Permalink
refactor: Implement default response handler method and added test wh…
Browse files Browse the repository at this point in the history
…en JSON decode error occurs
  • Loading branch information
David Blain committed Apr 29, 2024
1 parent 28a240a commit e32422e
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 11 deletions.
10 changes: 9 additions & 1 deletion airflow/providers/microsoft/azure/hooks/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from __future__ import annotations

import json
from contextlib import suppress
from http import HTTPStatus
from io import BytesIO
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import quote, urljoin, urlparse

Expand Down Expand Up @@ -51,6 +53,12 @@
from airflow.models import Connection


def default_response_handler(response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None) -> Any:
with suppress(JSONDecodeError):
return response.json()
return response


class CallableResponseHandler(ResponseHandler):
"""
CallableResponseHandler executes the passed callable_function with response as parameter.
Expand Down Expand Up @@ -271,7 +279,7 @@ async def run(
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
] = default_response_handler,
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/microsoft/azure/operators/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook, default_response_handler
from airflow.providers.microsoft.azure.triggers.msgraph import (
MSGraphTrigger,
ResponseSerializer,
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
] = default_response_handler,
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/microsoft/azure/sensors/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import json
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook, default_response_handler
from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue

Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
] = default_response_handler,
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/microsoft/azure/triggers/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

import pendulum

from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook, default_response_handler
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.module_loading import import_string

Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
] = default_response_handler,
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
Expand Down
20 changes: 17 additions & 3 deletions tests/providers/microsoft/azure/hooks/test_msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
from __future__ import annotations

import asyncio
from json import JSONDecodeError
from unittest.mock import patch

import pytest
from httpx import Response
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from msgraph_core import APIVersion, NationalClouds

from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException
from airflow.providers.microsoft.azure.hooks.msgraph import CallableResponseHandler, KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.hooks.msgraph import CallableResponseHandler, KiotaRequestAdapterHook, \
default_response_handler
from tests.providers.microsoft.conftest import (
get_airflow_connection,
load_json,
Expand Down Expand Up @@ -95,19 +98,30 @@ def test_encoded_query_parameters(self):


class TestResponseHandler:
def test_handle_response_async_when_ok(self):
def test_default_response_handler_when_json(self):
users = load_json("resources", "users.json")
response = mock_json_response(200, users)

actual = asyncio.run(
CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async(
CallableResponseHandler(default_response_handler).handle_response_async(
response, None
)
)

assert isinstance(actual, dict)
assert actual == users

def test_default_response_handler_when_not_json(self):
response = mock_json_response(200, JSONDecodeError("", "", 0))

actual = asyncio.run(
CallableResponseHandler(default_response_handler).handle_response_async(
response, None
)
)

assert actual == response

def test_handle_response_async_when_bad_request(self):
response = mock_json_response(400, {})

Expand Down
6 changes: 5 additions & 1 deletion tests/providers/microsoft/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ def mock_json_response(status_code, *contents) -> Response:
response.status_code = status_code
if contents:
contents = list(contents)
response.json.side_effect = lambda: contents.pop(0)
side_effect = contents.pop(0)
if isinstance(side_effect, Exception):
response.json.side_effect = side_effect
else:
response.json.side_effect = lambda: side_effect
else:
response.json.return_value = None
return response
Expand Down

0 comments on commit e32422e

Please sign in to comment.