diff --git a/amqpstorm/channel.py b/amqpstorm/channel.py index 72ada6ab..74cdafbb 100644 --- a/amqpstorm/channel.py +++ b/amqpstorm/channel.py @@ -300,9 +300,10 @@ def start_consuming(self, to_tuple=False, auto_decode=True): to_tuple=to_tuple, auto_decode=auto_decode ) - if not self.consumer_tags: - break - sleep(IDLE_WAIT) + if self.consumer_tags: + sleep(IDLE_WAIT) + continue + break def stop_consuming(self): """Stop consuming messages. diff --git a/amqpstorm/tests/unit/channel/channel_build_message_tests.py b/amqpstorm/tests/unit/channel/channel_message_handling_tests.py similarity index 50% rename from amqpstorm/tests/unit/channel/channel_build_message_tests.py rename to amqpstorm/tests/unit/channel/channel_message_handling_tests.py index f1e0b2ae..afaf3f05 100644 --- a/amqpstorm/tests/unit/channel/channel_build_message_tests.py +++ b/amqpstorm/tests/unit/channel/channel_message_handling_tests.py @@ -209,8 +209,12 @@ def test_channel_build_inbound_messages(self): channel._inbound = [deliver, header, body] + messages_consumed = 0 for message in channel.build_inbound_messages(break_on_empty=True): self.assertIsInstance(message, Message) + messages_consumed += 1 + + self.assertEqual(messages_consumed, 1) def test_channel_build_multiple_inbound_messages(self): channel = Channel(0, FakeConnection(), 360) @@ -226,12 +230,12 @@ def test_channel_build_multiple_inbound_messages(self): channel._inbound = [deliver, header, body, deliver, header, body, deliver, header, body, deliver, header, body] - index = 0 + messages_consumed = 0 for message in channel.build_inbound_messages(break_on_empty=True): self.assertIsInstance(message, Message) - index += 1 + messages_consumed += 1 - self.assertEqual(index, 4) + self.assertEqual(messages_consumed, 4) def test_channel_build_large_number_inbound_messages(self): channel = Channel(0, FakeConnection(), 360) @@ -249,9 +253,230 @@ def test_channel_build_large_number_inbound_messages(self): channel._inbound.append(header) channel._inbound.append(body) - index = 0 + messages_consumed = 0 for message in channel.build_inbound_messages(break_on_empty=True): self.assertIsInstance(message, Message) - index += 1 + messages_consumed += 1 + + self.assertEqual(messages_consumed, 10000) + + def test_channel_build_inbound_messages_without_break_on_empty(self): + channel = Channel(0, FakeConnection(), 360) + channel.set_state(channel.OPEN) + + message = self.message.encode('utf-8') + message_len = len(message) + + deliver = specification.Basic.Deliver() + header = ContentHeader(body_size=message_len) + body = ContentBody(value=message) + + for _ in range(25): + channel._inbound.append(deliver) + channel._inbound.append(header) + channel._inbound.append(body) + + messages_consumed = 0 + for msg in channel.build_inbound_messages(break_on_empty=False): + messages_consumed += 1 + self.assertIsInstance(msg.body, str) + self.assertEqual(msg.body.encode('utf-8'), message) + if messages_consumed >= 10: + channel.set_state(channel.CLOSED) + self.assertEqual(messages_consumed, 10) + + def test_channel_build_inbound_messages_as_tuple(self): + channel = Channel(0, FakeConnection(), 360) + channel.set_state(channel.OPEN) + + message = self.message.encode('utf-8') + message_len = len(message) + + deliver = specification.Basic.Deliver() + header = ContentHeader(body_size=message_len) + body = ContentBody(value=message) + + channel._inbound = [deliver, header, body] + + messages_consumed = 0 + for msg in channel.build_inbound_messages(break_on_empty=True, + to_tuple=True): + self.assertIsInstance(msg, tuple) + self.assertEqual(msg[0], message) + messages_consumed += 1 + + self.assertEqual(messages_consumed, 1) + + +class ChannelProcessDataEventTests(TestFramework): + def test_channel_process_data_events(self): + self.msg = None + + channel = Channel(0, FakeConnection(), 360) + channel.set_state(channel.OPEN) + + message = self.message.encode('utf-8') + message_len = len(message) + + deliver = specification.Basic.Deliver(consumer_tag='travis-ci') + header = ContentHeader(body_size=message_len) + body = ContentBody(value=message) + + channel._inbound = [deliver, header, body] + + def callback(msg): + self.msg = msg + + channel._consumer_callbacks['travis-ci'] = callback + channel.process_data_events() + + self.assertIsNotNone(self.msg, 'No message consumed') + self.assertIsInstance(self.msg.body, str) + self.assertEqual(self.msg.body.encode('utf-8'), message) + + def test_channel_process_data_events_as_tuple(self): + self.msg = None + + channel = Channel(0, FakeConnection(), 360) + channel.set_state(channel.OPEN) + + message = self.message.encode('utf-8') + message_len = len(message) + + deliver = specification.Basic.Deliver(consumer_tag='travis-ci') + header = ContentHeader(body_size=message_len) + body = ContentBody(value=message) + + channel._inbound = [deliver, header, body] + + def callback(body, channel, method, properties): + self.msg = (body, channel, method, properties) + + channel._consumer_callbacks['travis-ci'] = callback + channel.process_data_events(to_tuple=True) + + self.assertIsNotNone(self.msg, 'No message consumed') + + body, channel, method, properties = self.msg + + self.assertIsInstance(body, bytes) + self.assertIsInstance(channel, Channel) + self.assertIsInstance(method, dict) + self.assertIsInstance(properties, dict) + self.assertEqual(body, message) + + +class ChannelStartConsumingTests(TestFramework): + def test_channel_start_consuming(self): + self.msg = None + consumer_tag = 'travis-ci' + + channel = Channel(0, FakeConnection(), 360) + channel.set_state(channel.OPEN) + + message = self.message.encode('utf-8') + message_len = len(message) + + deliver = specification.Basic.Deliver(consumer_tag='travis-ci') + header = ContentHeader(body_size=message_len) + body = ContentBody(value=message) + + channel._inbound = [deliver, header, body] + + def callback(msg): + self.msg = msg + channel.set_state(channel.CLOSED) + + channel.add_consumer_tag(consumer_tag) + channel._consumer_callbacks['travis-ci'] = callback + channel.start_consuming() + + self.assertIsNotNone(self.msg, 'No message consumed') + self.assertIsInstance(self.msg.body, str) + self.assertEqual(self.msg.body.encode('utf-8'), message) + + def test_channel_start_consuming_idle_wait(self): + self.msg = None + consumer_tag = 'travis-ci' + + channel = Channel(0, FakeConnection(), 360) + channel.set_state(channel.OPEN) + + message = self.message.encode('utf-8') + message_len = len(message) + + def add_inbound(): + deliver = specification.Basic.Deliver(consumer_tag='travis-ci') + header = ContentHeader(body_size=message_len) + body = ContentBody(value=message) + + channel._inbound = [deliver, header, body] + + def callback(msg): + self.msg = msg + channel.set_state(channel.CLOSED) + + channel.add_consumer_tag(consumer_tag) + channel._consumer_callbacks[consumer_tag] = callback + + threading.Timer(function=add_inbound, interval=1).start() + channel.start_consuming() + + self.assertIsNotNone(self.msg, 'No message consumed') + self.assertIsInstance(self.msg.body, str) + self.assertEqual(self.msg.body.encode('utf-8'), message) + + def test_channel_start_consuming_no_consumer_tags(self): + channel = Channel(0, FakeConnection(), 360) + channel.set_state(channel.OPEN) + + channel._consumer_callbacks = ['fake'] + + self.assertIsNone(channel.start_consuming()) + + def test_channel_start_consuming_multiple_callbacks(self): + channel = Channel(0, FakeConnection(), 360) + channel.set_state(channel.OPEN) + + message = self.message.encode('utf-8') + message_len = len(message) + + deliver_one = specification.Basic.Deliver( + consumer_tag='travis-ci-1') + deliver_two = specification.Basic.Deliver( + consumer_tag='travis-ci-2') + deliver_three = specification.Basic.Deliver( + consumer_tag='travis-ci-3') + header = ContentHeader(body_size=message_len) + body = ContentBody(value=message) - self.assertEqual(index, 10000) + channel._inbound = [ + deliver_one, header, body, + deliver_two, header, body, + deliver_three, header, body + ] + + def callback_one(msg): + self.assertEqual(msg.method.get('consumer_tag'), 'travis-ci-1') + self.assertIsInstance(msg.body, str) + self.assertEqual(msg.body.encode('utf-8'), message) + + def callback_two(msg): + self.assertEqual(msg.method.get('consumer_tag'), 'travis-ci-2') + self.assertIsInstance(msg.body, str) + self.assertEqual(msg.body.encode('utf-8'), message) + + def callback_three(msg): + self.assertEqual(msg.method.get('consumer_tag'), 'travis-ci-3') + self.assertIsInstance(msg.body, str) + self.assertEqual(msg.body.encode('utf-8'), message) + channel.set_state(channel.CLOSED) + + channel.add_consumer_tag('travis-ci-1') + channel.add_consumer_tag('travis-ci-2') + channel.add_consumer_tag('travis-ci-3') + channel._consumer_callbacks['travis-ci-1'] = callback_one + channel._consumer_callbacks['travis-ci-2'] = callback_two + channel._consumer_callbacks['travis-ci-3'] = callback_three + + channel.start_consuming() diff --git a/amqpstorm/tests/unit/channel/channel_tests.py b/amqpstorm/tests/unit/channel/channel_tests.py index f23f0a5d..c3ad22ef 100644 --- a/amqpstorm/tests/unit/channel/channel_tests.py +++ b/amqpstorm/tests/unit/channel/channel_tests.py @@ -1,6 +1,4 @@ -from pamqp import ContentHeader from pamqp import specification -from pamqp.body import ContentBody from amqpstorm import Channel from amqpstorm.basic import Basic @@ -45,163 +43,6 @@ def test_channel_id(self): self.assertEqual(int(channel), 1557) - def test_channel_build_inbound_messages(self): - channel = Channel(0, FakeConnection(), 360) - channel.set_state(channel.OPEN) - - message = self.message.encode('utf-8') - message_len = len(message) - - deliver = specification.Basic.Deliver() - header = ContentHeader(body_size=message_len) - body = ContentBody(value=message) - - channel._inbound = [deliver, header, body] - - for msg in channel.build_inbound_messages(break_on_empty=True): - self.assertIsInstance(msg.body, str) - self.assertEqual(msg.body.encode('utf-8'), message) - - def test_channel_build_inbound_messages_without_break_on_empty(self): - channel = Channel(0, FakeConnection(), 360) - channel.set_state(channel.OPEN) - - message = self.message.encode('utf-8') - message_len = len(message) - - deliver = specification.Basic.Deliver() - header = ContentHeader(body_size=message_len) - body = ContentBody(value=message) - - for _ in range(25): - channel._inbound.append(deliver) - channel._inbound.append(header) - channel._inbound.append(body) - - messages_consumed = 0 - for msg in channel.build_inbound_messages(break_on_empty=False): - messages_consumed += 1 - self.assertIsInstance(msg.body, str) - self.assertEqual(msg.body.encode('utf-8'), message) - if messages_consumed >= 10: - channel.set_state(channel.CLOSED) - self.assertEqual(messages_consumed, 10) - - def test_channel_build_inbound_messages_as_tuple(self): - channel = Channel(0, FakeConnection(), 360) - channel.set_state(channel.OPEN) - - message = self.message.encode('utf-8') - message_len = len(message) - - deliver = specification.Basic.Deliver() - header = ContentHeader(body_size=message_len) - body = ContentBody(value=message) - - channel._inbound = [deliver, header, body] - - for msg in channel.build_inbound_messages(break_on_empty=True, - to_tuple=True): - self.assertIsInstance(msg, tuple) - self.assertEqual(msg[0], message) - - def test_channel_process_data_events(self): - channel = Channel(0, FakeConnection(), 360) - channel.set_state(channel.OPEN) - - message = self.message.encode('utf-8') - message_len = len(message) - - deliver = specification.Basic.Deliver(consumer_tag='travis-ci') - header = ContentHeader(body_size=message_len) - body = ContentBody(value=message) - - channel._inbound = [deliver, header, body] - - def callback(msg): - self.assertIsInstance(msg.body, str) - self.assertEqual(msg.body.encode('utf-8'), message) - - channel._consumer_callbacks['travis-ci'] = callback - channel.process_data_events() - - def test_channel_process_data_events_as_tuple(self): - channel = Channel(0, FakeConnection(), 360) - channel.set_state(channel.OPEN) - - message = self.message.encode('utf-8') - message_len = len(message) - - deliver = specification.Basic.Deliver(consumer_tag='travis-ci') - header = ContentHeader(body_size=message_len) - body = ContentBody(value=message) - - channel._inbound = [deliver, header, body] - - def callback(body, channel, method, properties): - self.assertIsInstance(body, bytes) - self.assertIsInstance(channel, Channel) - self.assertIsInstance(method, dict) - self.assertIsInstance(properties, dict) - self.assertEqual(body, message) - - channel._consumer_callbacks['travis-ci'] = callback - channel.process_data_events(to_tuple=True) - - def test_channel_start_consuming(self): - channel = Channel(0, FakeConnection(), 360) - channel.set_state(channel.OPEN) - - message = self.message.encode('utf-8') - message_len = len(message) - - deliver = specification.Basic.Deliver(consumer_tag='travis-ci') - header = ContentHeader(body_size=message_len) - body = ContentBody(value=message) - - channel._inbound = [deliver, header, body] - - def callback(msg): - self.assertIsInstance(msg.body, str) - self.assertEqual(msg.body.encode('utf-8'), message) - channel.set_state(channel.CLOSED) - - channel._consumer_callbacks['travis-ci'] = callback - channel.start_consuming() - - def test_channel_start_consuming_multiple_callbacks(self): - channel = Channel(0, FakeConnection(), 360) - channel.set_state(channel.OPEN) - - message = self.message.encode('utf-8') - message_len = len(message) - - deliver_one = specification.Basic.Deliver(consumer_tag='travis-ci-1') - deliver_two = specification.Basic.Deliver(consumer_tag='travis-ci-2') - header = ContentHeader(body_size=message_len) - body = ContentBody(value=message) - - channel._inbound = [ - deliver_one, header, body, - deliver_two, header, body - ] - - def callback_one(msg): - self.assertEqual(msg.method.get('consumer_tag'), 'travis-ci-1') - self.assertIsInstance(msg.body, str) - self.assertEqual(msg.body.encode('utf-8'), message) - - def callback_two(msg): - self.assertEqual(msg.method.get('consumer_tag'), 'travis-ci-2') - self.assertIsInstance(msg.body, str) - self.assertEqual(msg.body.encode('utf-8'), message) - channel.set_state(channel.CLOSED) - - channel._consumer_callbacks['travis-ci-1'] = callback_one - channel._consumer_callbacks['travis-ci-2'] = callback_two - - channel.start_consuming() - def test_channel_open(self): def on_open_ok(_, frame_out): self.assertIsInstance(frame_out, specification.Channel.Open)