-
Notifications
You must be signed in to change notification settings - Fork 12
/
assigner_impl_test.py
183 lines (139 loc) · 6.49 KB
/
assigner_impl_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import asyncio
from unittest.mock import call
from collections import defaultdict
from typing import Dict, Set
from asynctest.mock import MagicMock, CoroutineMock
import pytest
from google.cloud.pubsublite.internal.wire.assigner import Assigner
from google.cloud.pubsublite.internal.wire.assigner_impl import AssignerImpl
from google.cloud.pubsublite.internal.wire.connection import Connection, ConnectionFactory
from google.api_core.exceptions import InternalServerError
from google.cloud.pubsublite.partition import Partition
from google.cloud.pubsublite_v1.types.subscriber import PartitionAssignmentRequest, InitialPartitionAssignmentRequest, \
PartitionAssignment, PartitionAssignmentAck
from google.cloud.pubsublite.testing.test_utils import make_queue_waiter
from google.cloud.pubsublite.internal.wire.retrying_connection import _MIN_BACKOFF_SECS
# Module-level marker: pytest applies it to every test in this file, so each
# `async def test_*` below runs under pytest-asyncio without its own decorator.
pytestmark = pytest.mark.asyncio
@pytest.fixture()
def default_connection():
    """Mock assignment-stream Connection that is its own async context manager."""
    mocked = MagicMock(spec=Connection[PartitionAssignmentRequest, PartitionAssignment])
    # `async with conn as c` must hand back this very mock, so calls made
    # through `c` are recorded on the object the tests assert against.
    mocked.__aenter__.return_value = mocked
    return mocked
@pytest.fixture()
def connection_factory(default_connection):
    """Mock ConnectionFactory whose `new()` yields `default_connection`.

    Tests may repoint `new.return_value` to simulate a reconnect.
    """
    mock_factory = MagicMock(spec=ConnectionFactory[PartitionAssignmentRequest, PartitionAssignment])
    mock_factory.new.return_value = default_connection
    return mock_factory
@pytest.fixture()
def initial_request():
    """The initial request the tests expect as the first write on each connection."""
    first_message = InitialPartitionAssignmentRequest(subscription="mysub")
    return PartitionAssignmentRequest(initial=first_message)
class QueuePair:
    """A pair of fresh asyncio queues used to script one mocked awaitable.

    Presumably consumed via `make_queue_waiter(called, results)`: the
    mocked call is signalled on `called`, and its outcome is taken from
    `results` (confirm against make_queue_waiter).
    """
    called: asyncio.Queue
    results: asyncio.Queue

    def __init__(self):
        # Two distinct, empty queues, created in the order (called, results).
        self.called, self.results = asyncio.Queue(), asyncio.Queue()
@pytest.fixture
def sleep_queues() -> Dict[float, QueuePair]:
    """Map from sleep delay to its QueuePair; entries are created on first lookup."""
    pairs_by_delay: Dict[float, QueuePair] = defaultdict(QueuePair)
    return pairs_by_delay
@pytest.fixture
def asyncio_sleep(monkeypatch, sleep_queues):
    """Replace `asyncio.sleep` with a mock driven by `sleep_queues`.

    Every `asyncio.sleep(delay)` call is routed through the QueuePair stored
    under `delay` in `sleep_queues`. Presumably (per make_queue_waiter's
    contract) the call is signalled on `.called` and completes only once
    the test puts a value on `.results`. This lets tests observe and release
    retry backoff deterministically instead of really waiting.

    Args:
        monkeypatch: pytest's monkeypatch fixture; undoes the patch after the test.
        sleep_queues: per-delay QueuePairs shared with the test body.

    Returns:
        The CoroutineMock installed as `asyncio.sleep`, so tests can also
        inspect its call history.
    """
    mock = CoroutineMock()
    # Patch the attribute on the asyncio module itself: only code that looks
    # up `asyncio.sleep` at call time sees the mock, not a name bound earlier
    # via `from asyncio import sleep`.
    monkeypatch.setattr(asyncio, "sleep", mock)

    async def sleeper(delay: float):
        # Indexing the defaultdict creates the delay's QueuePair on demand,
        # so the test may await `.called` before or after the sleep happens.
        await make_queue_waiter(sleep_queues[delay].called, sleep_queues[delay].results)(delay)

    mock.side_effect = sleeper
    return mock
@pytest.fixture()
def assigner(connection_factory, initial_request):
    """AssignerImpl under test, wired to the mocked connection factory."""
    initial_payload = initial_request.initial
    return AssignerImpl(initial_payload, connection_factory)
def as_response(partitions: Set[Partition]):
    """Build a PartitionAssignment message listing `partitions` by their values."""
    response = PartitionAssignment()
    response.partitions = [p.value for p in partitions]
    return response
def ack_request():
    """Build the ack request the tests expect after an assignment is consumed."""
    acknowledgement = PartitionAssignmentAck()
    return PartitionAssignmentRequest(ack=acknowledgement)
async def test_basic_assign(
        assigner: Assigner, default_connection, initial_request):
    """Assignments read from the stream are returned by get_assignment().

    Asking for a second assignment first writes an ack for the previous one
    on the same connection, before the next assignment is delivered.
    """
    write_called_queue = asyncio.Queue()
    write_result_queue = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue)
    read_called_queue = asyncio.Queue()
    read_result_queue = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue)
    # Pre-load the outcome of the initial-request write; the test never
    # answers that write explicitly below.
    write_result_queue.put_nowait(None)
    async with assigner:
        # Set up connection
        await write_called_queue.get()
        await read_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request)])
        # Wait for the first assignment
        assign_fut1 = asyncio.ensure_future(assigner.get_assignment())
        assert not assign_fut1.done()
        partitions = {Partition(2), Partition(7)}
        # Send the first assignment.
        await read_result_queue.put(as_response(partitions=partitions))
        assert (await assign_fut1) == partitions
        # Get the next assignment: should send an ack on the stream
        assign_fut2 = asyncio.ensure_future(assigner.get_assignment())
        await write_called_queue.get()
        await write_result_queue.put(None)
        default_connection.write.assert_has_calls([call(initial_request), call(ack_request())])
        partitions = {Partition(5)}
        # Send the second assignment.
        await read_called_queue.get()
        await read_result_queue.put(as_response(partitions=partitions))
        assert (await assign_fut2) == partitions
async def test_restart(
        assigner: Assigner, default_connection, connection_factory, initial_request, asyncio_sleep, sleep_queues):
    """A failed ack write triggers a reconnect, and the next assignment arrives on the new stream.

    The first connection is failed by erroring the ack write. After the
    minimum backoff sleep the assigner opens a second connection, sends the
    initial request on it, and delivers the next assignment read from it.
    The ack is never re-sent on the second connection.

    `asyncio_sleep` is requested only for its side effect of patching
    `asyncio.sleep`, so the retry backoff can be released via `sleep_queues`.
    """
    write_called_queue = asyncio.Queue()
    write_result_queue = asyncio.Queue()
    default_connection.write.side_effect = make_queue_waiter(write_called_queue, write_result_queue)
    read_called_queue = asyncio.Queue()
    read_result_queue = asyncio.Queue()
    default_connection.read.side_effect = make_queue_waiter(read_called_queue, read_result_queue)
    # Pre-load the outcome of the initial-request write on the first connection.
    write_result_queue.put_nowait(None)
    async with assigner:
        # Set up connection
        await write_called_queue.get()
        await read_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request)])
        # Wait for the first assignment
        assign_fut1 = asyncio.ensure_future(assigner.get_assignment())
        assert not assign_fut1.done()
        partitions = {Partition(2), Partition(7)}
        # Send the first assignment.
        await read_result_queue.put(as_response(partitions=partitions))
        await read_called_queue.get()
        assert (await assign_fut1) == partitions
        # Get the next assignment: should attempt to send an ack on the stream
        assign_fut2 = asyncio.ensure_future(assigner.get_assignment())
        await write_called_queue.get()
        default_connection.write.assert_has_calls([call(initial_request), call(ack_request())])
        # Set up the next connection
        conn2 = MagicMock(spec=Connection[PartitionAssignmentRequest, PartitionAssignment])
        conn2.__aenter__.return_value = conn2
        connection_factory.new.return_value = conn2
        write_called_queue_2 = asyncio.Queue()
        write_result_queue_2 = asyncio.Queue()
        conn2.write.side_effect = make_queue_waiter(write_called_queue_2, write_result_queue_2)
        read_called_queue_2 = asyncio.Queue()
        read_result_queue_2 = asyncio.Queue()
        conn2.read.side_effect = make_queue_waiter(read_called_queue_2, read_result_queue_2)
        # Fail the connection by failing the write call.
        await write_result_queue.put(InternalServerError("failed"))
        # Release the retry backoff sleep.
        await sleep_queues[_MIN_BACKOFF_SECS].called.get()
        await sleep_queues[_MIN_BACKOFF_SECS].results.put(None)
        # Reinitialize
        await write_called_queue_2.get()
        write_result_queue_2.put_nowait(None)
        conn2.write.assert_has_calls([call(initial_request)])
        partitions = {Partition(5)}
        # Send the second assignment on the new connection.
        await read_called_queue_2.get()
        await read_result_queue_2.put(as_response(partitions=partitions))
        assert (await assign_fut2) == partitions
        # No ack call ever made on the new connection: the initial request
        # must be its only write. `assert_has_calls` only checks for a
        # subsequence, so it would not catch a stray ack here.
        conn2.write.assert_called_once_with(initial_request)