diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index a4784b386..817176bef 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -340,17 +340,19 @@ def __init__( self._default_host = default_host if auth_request is None: - auth_request_session = requests.Session() + self._auth_request_session = requests.Session() # Using an adapter to make HTTP requests robust to network errors. # This adapter retrys HTTP requests when network errors occur # and the requests seems safely retryable. retry_adapter = requests.adapters.HTTPAdapter(max_retries=3) - auth_request_session.mount("https://", retry_adapter) + self._auth_request_session.mount("https://", retry_adapter) # Do not pass `self` as the session here, as it can lead to # infinite recursion. - auth_request = Request(auth_request_session) + auth_request = Request(self._auth_request_session) + else: + self._auth_request_session = None # Request instance used by internal methods (for example, # credentials.refresh). @@ -533,3 +535,8 @@ def request( def is_mtls(self): """Indicates if the created SSL channel is mutual TLS.""" return self._is_mtls + + def close(self): + if self._auth_request_session is not None: + self._auth_request_session.close() + super(AuthorizedSession, self).close() diff --git a/system_tests/system_tests_sync/test_requests.py b/system_tests/system_tests_sync/test_requests.py index 3ac9179b5..28004848b 100644 --- a/system_tests/system_tests_sync/test_requests.py +++ b/system_tests/system_tests_sync/test_requests.py @@ -32,8 +32,10 @@ def test_authorized_session_with_service_account_and_self_signed_jwt(): # List Pub/Sub Topics through the REST API # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.topics/list - response = session.get("https://pubsub.googleapis.com/v1/projects/{}/topics".format(project_id)) - response.raise_for_status() + url = "https://pubsub.googleapis.com/v1/projects/{}/topics".format(project_id) + with session: + response = session.get(url) + response.raise_for_status() # Check that self-signed JWT was created and is being used assert credentials._jwt_credentials is not None diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index f494c1443..ed9300d76 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -213,7 +213,7 @@ def test_constructor_with_auth_request(self): mock.sentinel.credentials, auth_request=auth_request ) - assert authed_session._auth_request == auth_request + assert authed_session._auth_request is auth_request def test_request_default_timeout(self): credentials = mock.Mock(wraps=CredentialsStub()) @@ -504,3 +504,22 @@ def test_configure_mtls_channel_without_client_cert_env( auth_session.configure_mtls_channel(mock_callback) assert not auth_session.is_mtls mock_callback.assert_not_called() + + def test_close_wo_passed_in_auth_request(self): + authed_session = google.auth.transport.requests.AuthorizedSession( + mock.sentinel.credentials + ) + authed_session._auth_request_session = mock.Mock(spec=["close"]) + + authed_session.close() + + authed_session._auth_request_session.close.assert_called_once_with() + + def test_close_w_passed_in_auth_request(self): + http = mock.create_autospec(requests.Session) + auth_request = google.auth.transport.requests.Request(http) + authed_session = google.auth.transport.requests.AuthorizedSession( + mock.sentinel.credentials, auth_request=auth_request + ) + + authed_session.close() # no raise