-
Notifications
You must be signed in to change notification settings - Fork 12
/
assigning_subscriber_test.py
127 lines (104 loc) · 5.25 KB
/
assigning_subscriber_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import asyncio
from typing import Callable, Set
from asynctest.mock import MagicMock, call
import pytest
from google.api_core.exceptions import FailedPrecondition
from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage
from google.cloud.pubsublite.cloudpubsub.internal.assigning_subscriber import AssigningSubscriber
from google.cloud.pubsublite.cloudpubsub.subscriber import AsyncSubscriber
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite.testing.test_utils import make_queue_waiter, wire_queues
# Apply the ``asyncio`` marker to every test in this module, so each
# ``async def`` test below is run as a coroutine by pytest-asyncio.
pytestmark = [pytest.mark.asyncio]
def mock_async_context_manager(cm):
    """Make a mock yield itself when used in ``async with``.

    Sets ``cm.__aenter__.return_value`` to ``cm``, so that
    ``async with cm as x`` binds ``x`` to the mock itself.

    Args:
      cm: A mock object whose ``__aenter__`` attribute has a settable
        ``return_value``.

    Returns:
      The same ``cm`` object, so the call can be used inline.
    """
    enter_mock = cm.__aenter__
    enter_mock.return_value = cm
    return cm
@pytest.fixture()
def assigner():
    """Provide an ``Assigner``-spec'd mock that can be used with ``async with``."""
    assigner_mock = MagicMock(spec=Assigner)
    return mock_async_context_manager(assigner_mock)
@pytest.fixture()
def subscriber_factory():
    """Provide a mock for a ``Partition -> AsyncSubscriber`` factory callable."""
    factory_spec = Callable[[Partition], AsyncSubscriber]
    return MagicMock(spec=factory_spec)
@pytest.fixture()
def subscriber(assigner, subscriber_factory):
    """Provide the ``AssigningSubscriber`` under test, built from the mock fixtures."""
    under_test = AssigningSubscriber(assigner, subscriber_factory)
    return under_test
async def test_init(subscriber, assigner):
    """Entering enters the assigner and requests an assignment; exiting exits it."""
    queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        assigner.__aenter__.assert_called_once()
        # Wait until the subscriber has called get_assignment for the first time.
        await queues.called.get()
        assigner.get_assignment.assert_called_once()
    assigner.__aexit__.assert_called_once()
async def test_initial_assignment(subscriber, assigner, subscriber_factory):
    """The first assignment creates and enters one subscriber per partition.

    Both per-partition subscribers must also be exited when the assigning
    subscriber is closed.
    """
    assignment_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assignment_queues.called.get()
        first = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        second = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))

        def make_subscriber(partition):
            return first if partition == Partition(1) else second

        subscriber_factory.side_effect = make_subscriber
        await assignment_queues.results.put({Partition(1), Partition(2)})
        # Wait for the next get_assignment call; presumably it is only issued
        # after the previous assignment has been applied.
        await assignment_queues.called.get()
        subscriber_factory.assert_has_calls(
            [call(Partition(1)), call(Partition(2))], any_order=True
        )
        first.__aenter__.assert_called_once()
        second.__aenter__.assert_called_once()
    first.__aexit__.assert_called_once()
    second.__aexit__.assert_called_once()
async def test_assigner_failure(subscriber, assigner, subscriber_factory):
    """An error returned by the assigner is raised from ``subscriber.read()``."""
    assignment_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assignment_queues.called.get()
        failure = FailedPrecondition("bad assign")
        await assignment_queues.results.put(failure)
        with pytest.raises(FailedPrecondition):
            await subscriber.read()
async def test_assignment_change(subscriber, assigner, subscriber_factory):
    """A reassignment starts added partitions, stops removed ones, and keeps the rest.

    The first assignment is {1, 2}; the second is {1, 3}. So partition 3 is
    added, partition 2 is removed, and partition 1 is unchanged. The subscriber
    for partition 1 must be reused rather than recreated.
    """
    assign_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assign_queues.called.get()
        sub1 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub2 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        sub3 = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        subscriber_factory.side_effect = lambda partition: (
            sub1 if partition == Partition(1)
            else sub2 if partition == Partition(2)
            else sub3
        )

        # Initial assignment: partitions 1 and 2.
        await assign_queues.results.put({Partition(1), Partition(2)})
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls(
            [call(Partition(1)), call(Partition(2))], any_order=True
        )
        sub1.__aenter__.assert_called_once()
        sub2.__aenter__.assert_called_once()

        # Reassignment: partition 2 removed, partition 3 added, partition 1 kept.
        await assign_queues.results.put({Partition(1), Partition(3)})
        await assign_queues.called.get()
        subscriber_factory.assert_has_calls(
            [call(Partition(1)), call(Partition(2)), call(Partition(3))],
            any_order=True,
        )
        # `assert_has_calls(..., any_order=True)` alone would also pass if the
        # factory were called again for partition 1. Pin the exact call count,
        # and check that sub1 was neither re-entered nor exited, so the
        # unchanged partition's subscriber is known to be kept.
        assert subscriber_factory.call_count == 3
        sub1.__aenter__.assert_called_once()
        sub1.__aexit__.assert_not_called()
        sub3.__aenter__.assert_called_once()
        sub2.__aexit__.assert_called_once()
    sub1.__aexit__.assert_called_once()
    # Still once: the removed subscriber must not be exited a second time on close.
    sub2.__aexit__.assert_called_once()
    sub3.__aexit__.assert_called_once()
async def test_subscriber_failure(subscriber, assigner, subscriber_factory):
    """An error from a partition subscriber's ``read()`` surfaces from ``subscriber.read()``."""
    assignment_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assignment_queues.called.get()
        partition_sub = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        read_queues = wire_queues(partition_sub.read)
        subscriber_factory.return_value = partition_sub
        await assignment_queues.results.put({Partition(1)})
        # Wait until the partition subscriber's read() has been called, then
        # make that read fail.
        await read_queues.called.get()
        await read_queues.results.put(FailedPrecondition("sub failed"))
        with pytest.raises(FailedPrecondition):
            await subscriber.read()
async def test_delivery_from_multiple(subscriber, assigner, subscriber_factory):
    """Messages from every assigned partition are delivered through ``subscriber.read()``."""
    assignment_queues = wire_queues(assigner.get_assignment)
    async with subscriber:
        await assignment_queues.called.get()
        first = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        second = mock_async_context_manager(MagicMock(spec=AsyncSubscriber))
        first_reads = wire_queues(first.read)
        second_reads = wire_queues(second.read)

        def make_subscriber(partition):
            return first if partition == Partition(1) else second

        subscriber_factory.side_effect = make_subscriber
        await assignment_queues.results.put({Partition(1), Partition(2)})
        await first_reads.results.put(
            Message(PubsubMessage(message_id="1")._pb, "", 0, None)
        )
        await second_reads.results.put(
            Message(PubsubMessage(message_id="2")._pb, "", 0, None)
        )
        # The test does not rely on the delivery order across partitions,
        # so the received IDs are compared as a set.
        received: Set[str] = set()
        for _ in range(2):
            received.add((await subscriber.read()).message_id)
        assert received == {"1", "2"}