Skip to content

Commit

Permalink
Allow to pass an instance of ClientSession as an argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kentzo committed Oct 23, 2017
1 parent 182c427 commit ada2ec4
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 57 deletions.
5 changes: 5 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
X.X.X
-----

- Allow to pass an instance of ClientSession as an argument

0.6.0
-----

Expand Down
18 changes: 7 additions & 11 deletions raven_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class AioHttpTransportBase(

def __init__(self, parsed_url=None, *, verify_ssl=True, resolve=True,
timeout=defaults.TIMEOUT,
keepalive=True, family=socket.AF_INET, loop=None):
keepalive=True, family=socket.AF_INET, client_session=None, loop=None):
self._resolve = resolve
self._keepalive = keepalive
self._family = family
Expand All @@ -53,7 +53,9 @@ def __init__(self, parsed_url=None, *, verify_ssl=True, resolve=True,
else:
super().__init__(parsed_url, timeout, verify_ssl)

if self.keepalive:
if client_session:
self._client_session = client_session
else:
self._client_session = self._client_session_factory()

self._closing = False
Expand All @@ -80,15 +82,10 @@ def _client_session_factory(self):

@asyncio.coroutine
def _do_send(self, url, data, headers, success_cb, failure_cb):
if self.keepalive:
session = self._client_session
else:
session = self._client_session_factory()

resp = None

try:
resp = yield from session.post(
resp = yield from self._client_session.post(
url,
data=data,
compress=False,
Expand Down Expand Up @@ -118,8 +115,6 @@ def _do_send(self, url, data, headers, success_cb, failure_cb):
finally:
if resp is not None:
resp.release()
if not self.keepalive:
yield from session.close()

@abc.abstractmethod
def _async_send(self, url, data, headers, success_cb, failure_cb): # pragma: no cover
Expand All @@ -146,8 +141,9 @@ def _close_coro(self, *, timeout=None):
except asyncio.TimeoutError:
pass
finally:
if self.keepalive:
if self._client_session:
yield from self._client_session.close()
self._client_session = None

def close(self, *, timeout=None):
if self._closing:
Expand Down
33 changes: 10 additions & 23 deletions tests/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial
from unittest import mock

import aiohttp
import pytest

from raven_aiohttp import QueuedAioHttpTransport
Expand All @@ -28,34 +29,20 @@ def test_basic(fake_server, raven_client, wait):


@asyncio.coroutine
def test_no_keepalive(fake_server, raven_client, wait):
transport = QueuedAioHttpTransport(keepalive=False)
assert not hasattr(transport, '_client_session')
yield from transport.close()

def test_custom_client_session(fake_server, raven_client, wait):
server = yield from fake_server()

client, transport = raven_client(server, QueuedAioHttpTransport)
transport._keepalive = False
session = transport._client_session

def _client_session_factory():
return session

with mock.patch(
'raven_aiohttp.QueuedAioHttpTransport._client_session_factory',
side_effect=_client_session_factory,
):
try:
1 / 0
except ZeroDivisionError:
client.captureException()
session = aiohttp.ClientSession()
client, transport = raven_client(server, partial(QueuedAioHttpTransport, client_session=session))

yield from wait(transport)
try:
1 / 0
except ZeroDivisionError:
client.captureException()

assert session.closed
yield from wait(transport)

assert server.hits[200] == 1
assert server.hits[200] == 1


@asyncio.coroutine
Expand Down
34 changes: 11 additions & 23 deletions tests/test_transport.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import logging
from functools import partial
from unittest import mock

import aiohttp
import pytest

from raven_aiohttp import AioHttpTransport
Expand All @@ -27,34 +29,20 @@ def test_basic(fake_server, raven_client, wait):


@asyncio.coroutine
def test_no_keepalive(fake_server, raven_client, wait):
transport = AioHttpTransport(keepalive=False)
assert not hasattr(transport, '_client_session')
yield from transport.close()

def test_custom_client_session(fake_server, raven_client, wait):
server = yield from fake_server()

client, transport = raven_client(server, AioHttpTransport)
transport._keepalive = False
session = transport._client_session

def _client_session_factory():
return session
session = aiohttp.ClientSession()
client, transport = raven_client(server, partial(AioHttpTransport, client_session=session))

with mock.patch(
'raven_aiohttp.AioHttpTransport._client_session_factory',
side_effect=_client_session_factory,
):
try:
1 / 0
except ZeroDivisionError:
client.captureException()

yield from wait(transport)
try:
1 / 0
except ZeroDivisionError:
client.captureException()

assert session.closed
yield from wait(transport)

assert server.hits[200] == 1
assert server.hits[200] == 1


@asyncio.coroutine
Expand Down

0 comments on commit ada2ec4

Please sign in to comment.