Skip to content

Commit

Permalink
MRG: Merge pull request #625 from octue/improve-questions-asking-and-…
Browse files Browse the repository at this point in the history
…event-handling

Make event handling faster and resilient to missing events
  • Loading branch information
cortadocodes committed Feb 5, 2024
2 parents c4b41f6 + 13c4190 commit 2e6c23f
Show file tree
Hide file tree
Showing 17 changed files with 1,046 additions and 662 deletions.
2 changes: 2 additions & 0 deletions docs/source/asking_questions.rst
Expand Up @@ -89,6 +89,8 @@ Options:
times
- If ``raise_errors=False`` is provided with ``max_retries > 0`` and ``prevent_retries_when`` is set to a list of
exception types, failed questions are retried except for those whose exception types are in the list
- The maximum number of threads that can be used to ask questions in parallel can be set via the ``max_workers``
argument. This has no effect on the total number of questions that can be asked via ``Child.ask_multiple``.


Asking a question within a service
Expand Down
180 changes: 91 additions & 89 deletions docs/source/inter_service_compatibility.rst

Large diffs are not rendered by default.

31 changes: 0 additions & 31 deletions octue/cloud/emulators/_pub_sub.py
Expand Up @@ -388,37 +388,6 @@ def ask(
return response_subscription, question_uuid


class MockMessagePuller:
"""A mock message puller that enqueues messages in the message handler in the order they're provided on
initialisation. This is meant for patching
`octue.cloud.pub_sub.message_handler.OrderedMessageHandler._pull_and_enqueue_message` in tests.
:param iter(octue.cloud.pub_sub.emulators._pub_sub.MockMessage) messages:
:param octue.cloud.pub_sub.message_handler.OrderedMessageHandler message_handler:
:return None:
"""

def __init__(self, messages, message_handler):
self.messages = messages
self.message_handler = message_handler
self.current_message = 0

def pull(self, timeout):
"""Get the next message from the messages given at instantiation and enqueue it for handling in the message
handler.
:return None:
"""
try:
message = self.messages[self.current_message]
except IndexError:
return

message_number = int(message.attributes["message_number"])
self.message_handler.waiting_messages[message_number] = json.loads(message.data.decode())
self.current_message += 1


class MockAnalysis:
"""A mock Analysis object with just the output strands.
Expand Down
1 change: 0 additions & 1 deletion octue/cloud/emulators/child.py
Expand Up @@ -126,7 +126,6 @@ def ask(
subscription,
handle_monitor_message=handle_monitor_message,
record_messages=record_messages,
service_name=self.id,
timeout=timeout,
)

Expand Down
129 changes: 84 additions & 45 deletions octue/cloud/pub_sub/message_handler.py
Expand Up @@ -27,6 +27,10 @@
logger = logging.getLogger(__name__)


MAX_SIMULTANEOUS_MESSAGES_PULL = 50
PARENT_SDK_VERSION = importlib.metadata.version("octue")


class OrderedMessageHandler:
"""A handler for Google Pub/Sub messages received via a pull subscription that ensures messages are handled in the
order they were sent.
Expand All @@ -38,6 +42,7 @@ class OrderedMessageHandler:
:param str service_name: an arbitrary name to refer to the service subscribed to (used for labelling its remote log messages)
:param dict|None message_handlers: a mapping of message type names to callables that handle each type of message. The handlers should not mutate the messages.
:param dict|str schema: the JSON schema (or URI of one) to validate messages against
:param int|float skip_missing_messages_after: the number of seconds after which to skip any messages if they haven't arrived but subsequent messages have
:return None:
"""

Expand All @@ -50,17 +55,17 @@ def __init__(
service_name="REMOTE",
message_handlers=None,
schema=SERVICE_COMMUNICATION_SCHEMA,
skip_missing_messages_after=10,
):
self.subscription = subscription
self.receiving_service = receiving_service
self.handle_monitor_message = handle_monitor_message
self.record_messages = record_messages
self.service_name = service_name
self.schema = schema

if isinstance(schema, str):
self.schema = {"$ref": schema}
else:
self.schema = schema
self.skip_missing_messages_after = skip_missing_messages_after
self._missing_message_detection_time = None

self.question_uuid = self.subscription.path.split(".")[-1]
self.handled_messages = []
Expand All @@ -72,7 +77,7 @@ def __init__(
self._alive = True
self._start_time = None
self._previous_message_number = -1
self._earliest_message_number_received = math.inf
self._earliest_waiting_message_number = math.inf

self._message_handlers = message_handlers or {
"delivery_acknowledgement": self._handle_delivery_acknowledgement,
Expand All @@ -93,10 +98,22 @@ def total_run_time(self):
:return float|None: the amount of time since `self.handle_messages` was called (in seconds)
"""
if self._start_time is None:
return
return None

return time.perf_counter() - self._start_time

@property
def time_since_missing_message(self):
"""Get the amount of time elapsed since the last missing message was detected. If no missing messages have been
detected or they've already been skipped past, `None` is returned.
:return float|None:
"""
if self._missing_message_detection_time is None:
return None

return time.perf_counter() - self._missing_message_detection_time

@property
def _time_since_last_heartbeat(self):
"""Get the time period since the last heartbeat was received.
Expand All @@ -108,13 +125,12 @@ def _time_since_last_heartbeat(self):

return datetime.now() - self._last_heartbeat

def handle_messages(self, timeout=60, maximum_heartbeat_interval=300, skip_first_messages_after=60):
def handle_messages(self, timeout=60, maximum_heartbeat_interval=300):
"""Pull messages and handle them in the order they were sent until a result is returned by a message handler,
then return that result.
:param float|None timeout: how long to wait for an answer before raising a `TimeoutError`
:param int|float maximum_heartbeat_interval: the maximum amount of time (in seconds) allowed between child heartbeats before an error is raised
:param int|float skip_first_messages_after: the number of seconds after which to skip the first n messages if they haven't arrived but subsequent messages have
:raise TimeoutError: if the timeout is exceeded before receiving the final message
:return dict: the first result returned by a message handler
"""
Expand All @@ -134,8 +150,8 @@ def handle_messages(self, timeout=60, maximum_heartbeat_interval=300, skip_first

while self._alive:
pull_timeout = self._check_timeout_and_get_pull_timeout(timeout)
self._pull_and_enqueue_message(timeout=pull_timeout)
result = self._attempt_to_handle_queued_messages(skip_first_messages_after)
self._pull_and_enqueue_available_messages(timeout=pull_timeout)
result = self._attempt_to_handle_waiting_messages()

if result is not None:
return result
Expand Down Expand Up @@ -186,9 +202,9 @@ def _check_timeout_and_get_pull_timeout(self, timeout):

return timeout - total_run_time

def _pull_and_enqueue_message(self, timeout):
"""Pull a message from the subscription and enqueue it in `self.waiting_messages`, raising a `TimeoutError` if
the timeout is exceeded before succeeding.
def _pull_and_enqueue_available_messages(self, timeout):
"""Pull as many messages from the subscription as are available and enqueue them in `self.waiting_messages`,
raising a `TimeoutError` if the timeout is exceeded before succeeding.
:param float|None timeout: how long to wait in seconds for messages before raising a `TimeoutError`
:raise TimeoutError|concurrent.futures.TimeoutError: if the timeout is exceeded
Expand All @@ -197,19 +213,17 @@ def _pull_and_enqueue_message(self, timeout):
pull_start_time = time.perf_counter()
attempt = 1

while True:
while self._alive:
logger.debug("Pulling messages from Google Pub/Sub: attempt %d.", attempt)

pull_response = self._subscriber.pull(
request={"subscription": self.subscription.path, "max_messages": 1},
request={"subscription": self.subscription.path, "max_messages": MAX_SIMULTANEOUS_MESSAGES_PULL},
retry=retry.Retry(),
)

try:
answer = pull_response.received_messages[0]
if len(pull_response.received_messages) > 0:
break

except IndexError:
else:
logger.debug("Google Pub/Sub pull response timed out early.")
attempt += 1

Expand All @@ -220,16 +234,35 @@ def _pull_and_enqueue_message(self, timeout):
f"No message received from topic {self.subscription.topic.path!r} after {timeout} seconds.",
)

self._subscriber.acknowledge(request={"subscription": self.subscription.path, "ack_ids": [answer.ack_id]})
logger.debug("%r received a message related to question %r.", self.receiving_service, self.question_uuid)
if not pull_response.received_messages:
return

self._subscriber.acknowledge(
request={
"subscription": self.subscription.path,
"ack_ids": [message.ack_id for message in pull_response.received_messages],
}
)

event, attributes = extract_event_and_attributes_from_pub_sub(answer.message)
for message in pull_response.received_messages:
self._extract_and_enqueue_event(message)

self._earliest_waiting_message_number = min(self.waiting_messages.keys())

def _extract_and_enqueue_event(self, message):
"""Extract an event from the Pub/Sub message and add it to `self.waiting_messages`.
:param dict message:
:return None:
"""
logger.debug("%r received a message related to question %r.", self.receiving_service, self.question_uuid)
event, attributes = extract_event_and_attributes_from_pub_sub(message.message)

if not is_event_valid(
event=event,
attributes=attributes,
receiving_service=self.receiving_service,
parent_sdk_version=importlib.metadata.version("octue"),
parent_sdk_version=PARENT_SDK_VERSION,
child_sdk_version=attributes.get("version"),
schema=self.schema,
):
Expand All @@ -241,25 +274,31 @@ def _pull_and_enqueue_message(self, timeout):

message_number = attributes["message_number"]
self.waiting_messages[message_number] = event
self._earliest_message_number_received = min(self._earliest_message_number_received, message_number)

def _attempt_to_handle_queued_messages(self, skip_first_messages_after=60):
"""Attempt to handle messages in the pulled message queue. If these messages aren't consecutive with the last
handled message (i.e. if messages have been received out of order and the next in-order message hasn't been
received yet), just return. After the given amount of time, if the first n messages haven't arrived but
subsequent ones have, skip to the earliest received message and continue from there.
def _attempt_to_handle_waiting_messages(self):
"""Attempt to handle messages waiting in `self.waiting_messages`. If these messages aren't consecutive to the
last handled message (i.e. if messages have been received out of order and the next in-order message hasn't been
received yet), just return. After the missing message wait time has passed, if this set of missing messages
hasn't arrived but subsequent ones have, skip to the earliest waiting message and continue from there.
:param int|float skip_first_messages_after: the number of seconds after which to skip the first n messages if they haven't arrived but subsequent messages have
:return any|None: either a non-`None` result from a message handler or `None` if nothing was returned by the message handlers or if the next in-order message hasn't been received yet
"""
while self.waiting_messages:
try:
# If the next consecutive message has been received:
message = self.waiting_messages.pop(self._previous_message_number + 1)

# If the next consecutive message hasn't been received:
except KeyError:
# Start the missing message timer if it isn't already running.
if self._missing_message_detection_time is None:
self._missing_message_detection_time = time.perf_counter()

if self.time_since_missing_message > self.skip_missing_messages_after:
message = self._skip_to_earliest_waiting_message()

if self.total_run_time > skip_first_messages_after and self._previous_message_number == -1:
message = self._get_and_start_from_earliest_received_message(skip_first_messages_after)
# Declare there are no more missing messages.
self._missing_message_detection_time = None

if not message:
return
Expand All @@ -272,28 +311,29 @@ def _attempt_to_handle_queued_messages(self, skip_first_messages_after=60):
if result is not None:
return result

def _get_and_start_from_earliest_received_message(self, skip_first_messages_after):
"""Get the earliest received message from the waiting message queue and set the message handler up to start from
it instead of the first message sent by the child.
def _skip_to_earliest_waiting_message(self):
"""Get the earliest waiting message and set the message handler up to continue from it.
:param int|float skip_first_messages_after: the number of seconds after which to skip the first n messages if they haven't arrived but subsequent messages have
:return dict|None:
"""
try:
message = self.waiting_messages.pop(self._earliest_message_number_received)
message = self.waiting_messages.pop(self._earliest_waiting_message_number)
except KeyError:
return

self._previous_message_number = self._earliest_message_number_received - 1
number_of_missing_messages = self._earliest_waiting_message_number - self._previous_message_number - 1

# Let the message handler know it can handle the next earliest message.
self._previous_message_number = self._earliest_waiting_message_number - 1

logger.warning(
"%r: The first %d messages for question %r weren't received after %ds - skipping to the "
"earliest received message (message number %d).",
"%r: %d consecutive messages missing for question %r after %ds - skipping to next earliest waiting message "
"(message %d).",
self.receiving_service,
self._earliest_message_number_received,
number_of_missing_messages,
self.question_uuid,
skip_first_messages_after,
self._earliest_message_number_received,
self.skip_missing_messages_after,
self._earliest_waiting_message_number,
)

return message
Expand Down Expand Up @@ -356,8 +396,7 @@ def _handle_log_message(self, message):
text_colour=self._log_message_colours[0],
)

# Colour any analysis sections from children of the immediate child with the rest of the colour palette and
# colour the message from the furthest child white.
# Colour any analysis sections from children of the immediate child with the rest of the colour palette.
subchild_analysis_sections = [section.strip("[") for section in re.split("] ", record.msg)]
final_message = subchild_analysis_sections.pop(-1)

Expand Down
12 changes: 9 additions & 3 deletions octue/cloud/pub_sub/service.py
Expand Up @@ -22,6 +22,7 @@
convert_service_id_to_pub_sub_form,
create_sruid,
get_default_sruid,
get_sruid_from_pub_sub_resource_name,
raise_if_revision_not_registered,
split_service_id,
validate_sruid,
Expand All @@ -34,6 +35,10 @@


logger = logging.getLogger(__name__)

# A lock to ensure only one message can be sent at a time so that the message number is incremented correctly when
# messages are being sent on multiple threads (e.g. via the main thread and a periodic monitor message thread). This
# avoids 1) messages overwriting each other in the parent's message handler and 2) messages losing their order.
send_message_lock = threading.Lock()

DEFAULT_NAMESPACE = "default"
Expand Down Expand Up @@ -371,7 +376,6 @@ def wait_for_answer(
subscription,
handle_monitor_message=None,
record_messages=True,
service_name="REMOTE",
timeout=60,
maximum_heartbeat_interval=300,
):
Expand All @@ -381,7 +385,6 @@ def wait_for_answer(
:param octue.cloud.pub_sub.subscription.Subscription subscription: the subscription for the question's answer
:param callable|None handle_monitor_message: a function to handle monitor messages (e.g. send them to an endpoint for plotting or displaying) - this function should take a single JSON-compatible python primitive as an argument (note that this could be an array or object)
:param bool record_messages: if `True`, record messages received from the child in the `received_messages` attribute
:param str service_name: a name by which to refer to the child subscribed to (used for labelling its log messages if subscribed to)
:param float|None timeout: how long in seconds to wait for an answer before raising a `TimeoutError`
:param float|int delivery_acknowledgement_timeout: how long in seconds to wait for a delivery acknowledgement before aborting
:param float|int maximum_heartbeat_interval: the maximum amount of time (in seconds) allowed between child heartbeats before an error is raised
Expand All @@ -394,6 +397,8 @@ def wait_for_answer(
f"its push endpoint at {subscription.push_endpoint!r}."
)

service_name = get_sruid_from_pub_sub_resource_name(subscription.name)

self._message_handler = OrderedMessageHandler(
subscription=subscription,
receiving_service=self,
Expand Down Expand Up @@ -435,7 +440,8 @@ def send_exception(self, topic, question_uuid, timeout=30):
)

def _send_message(self, message, topic, attributes=None, timeout=30):
"""Send a JSON-serialised message to the given topic with optional message attributes.
"""Send a JSON-serialised message to the given topic with optional message attributes and increment the
`messages_published` attribute of the topic by one. This method is thread-safe.
:param dict message: JSON-serialisable data to send as a message
:param octue.cloud.pub_sub.topic.Topic topic: the Pub/Sub topic to send the message to
Expand Down

0 comments on commit 2e6c23f

Please sign in to comment.