From d20b520ea936a6554a24099beb0e044f237ff741 Mon Sep 17 00:00:00 2001
From: mkovalski
Date: Thu, 2 Sep 2021 13:38:47 -0400
Subject: [PATCH] feat: Update tensorboard uploader to use Dispatcher for
 handling different event types (#651)

Refactors the TensorBoard uploader so that, in the future, it can use
additional senders that do not process typical event files. This is initial
work before adding the TensorBoard profiler event sender, and it should have
no impact on current functionality.

Fixes #519
---
 .../cloud/aiplatform/tensorboard/uploader.py | 230 +++++++++++-------
 tests/unit/aiplatform/test_uploader.py       |  17 +-
 2 files changed, 159 insertions(+), 88 deletions(-)

diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py
index e74d119f4a..d3d9ef7db1 100644
--- a/google/cloud/aiplatform/tensorboard/uploader.py
+++ b/google/cloud/aiplatform/tensorboard/uploader.py
@@ -30,7 +30,6 @@
     Generator,
     Iterable,
     Optional,
-    Tuple,
     ContextManager,
 )
 import uuid
@@ -118,6 +117,15 @@
 logger.setLevel(logging.WARNING)
 
 
+class RequestSender(object):
+    """A base class for additional request sender objects.
+
+    Currently just used for typing.
+    """
+
+    pass
+
+
 class TensorBoardUploader(object):
     """Uploads a TensorBoard logdir to TensorBoard.gcp."""
 
@@ -216,7 +224,7 @@ def __init__(
         self._description = description
         self._verbosity = verbosity
         self._one_shot = one_shot
-        self._request_sender = None
+        self._dispatcher = None
         if logdir_poll_rate_limiter is None:
             self._logdir_poll_rate_limiter = util.RateLimiter(
                 _MIN_LOGDIR_POLL_INTERVAL_SECS
@@ -296,7 +304,7 @@ def create_experiment(self):
 
         experiment = self._create_or_get_experiment()
         self._experiment = experiment
-        self._request_sender = _BatchedRequestSender(
+        request_sender = _BatchedRequestSender(
             self._experiment.name,
             self._api,
             allowed_plugins=self._allowed_plugins,
@@ -309,6 +317,26 @@ def create_experiment(self):
             tracker=self._tracker,
         )
 
+        additional_senders = self._create_additional_senders()
+
+        self._dispatcher = _Dispatcher(
+            request_sender=request_sender, additional_senders=additional_senders,
+        )
+
+    def _create_additional_senders(self) -> Dict[str, RequestSender]:
+        """Create any additional senders for non-traditional event files.
+
+        Some plugins rely on data that is not stored in typical event files
+        but must be searched for and stored separately so that the plugin can
+        use it. If any such data cannot be handled via the
+        `_BatchedRequestSender`, add a sender for it here.
+
+        Returns:
+            Mapping from plugin name to Sender.
+        """
+        additional_senders = {}
+        return additional_senders
+
     def get_experiment_resource_name(self):
         return self._experiment.name
 
@@ -320,7 +348,7 @@ def start_uploading(self):
             ExperimentNotFoundError: If the experiment is deleted during
               the course of the upload.
""" - if self._request_sender is None: + if self._dispatcher is None: raise RuntimeError("Must call create_experiment() before start_uploading()") while True: self._logdir_poll_rate_limiter.tick() @@ -348,7 +376,7 @@ def _upload_once(self): self._run_name_prefix + k: v for k, v in run_to_events.items() } with self._tracker.send_tracker(): - self._request_sender.send_requests(run_to_events) + self._dispatcher.dispatch_requests(run_to_events) class ExperimentNotFoundError(RuntimeError): @@ -453,8 +481,11 @@ def __init__( tracker=self._tracker, ) - def send_requests( - self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]] + def send_request( + self, + run_name: str, + event: tf.compat.v1.Event, + value: tf.compat.v1.Summary.Value, ): """Accepts a stream of TF events and sends batched write RPCs. @@ -462,78 +493,77 @@ def send_requests( the type of data (Scalar vs Tensor vs Blob) being sent. Args: - run_to_events: Mapping from run name to generator of `tf.compat.v1.Event` - values, as returned by `LogdirLoader.get_run_events`. + run_name: Name of the run retrieved by `LogdirLoader.get_run_events` + event: The `tf.compat.v1.Event` for the run + value: A single `tf.compat.v1.Summary.Value` from the event, where + there can be multiple values per event. Raises: RuntimeError: If no progress can be made because even a single point is too large (say, due to a gigabyte-long tag name). """ - for (run_name, event, value) in self._run_values(run_to_events): - time_series_key = (run_name, value.tag) - - # The metadata for a time series is memorized on the first event. - # If later events arrive with a mismatching plugin_name, they are - # ignored with a warning. - metadata = self._tag_metadata.get(time_series_key) - first_in_time_series = False - if metadata is None: - first_in_time_series = True - metadata = value.metadata - self._tag_metadata[time_series_key] = metadata - - plugin_name = metadata.plugin_data.plugin_name - if value.HasField("metadata") and ( - plugin_name != value.metadata.plugin_data.plugin_name - ): - logger.warning( - "Mismatching plugin names for %s. Expected %s, found %s.", + time_series_key = (run_name, value.tag) + + # The metadata for a time series is memorized on the first event. + # If later events arrive with a mismatching plugin_name, they are + # ignored with a warning. + metadata = self._tag_metadata.get(time_series_key) + first_in_time_series = False + if metadata is None: + first_in_time_series = True + metadata = value.metadata + self._tag_metadata[time_series_key] = metadata + + plugin_name = metadata.plugin_data.plugin_name + if value.HasField("metadata") and ( + plugin_name != value.metadata.plugin_data.plugin_name + ): + logger.warning( + "Mismatching plugin names for %s. Expected %s, found %s.", + time_series_key, + metadata.plugin_data.plugin_name, + value.metadata.plugin_data.plugin_name, + ) + return + if plugin_name not in self._allowed_plugins: + if first_in_time_series: + logger.info( + "Skipping time series %r with unsupported plugin name %r", time_series_key, - metadata.plugin_data.plugin_name, - value.metadata.plugin_data.plugin_name, - ) - continue - if plugin_name not in self._allowed_plugins: - if first_in_time_series: - logger.info( - "Skipping time series %r with unsupported plugin name %r", - time_series_key, - plugin_name, - ) - continue - self._tracker.add_plugin_name(plugin_name) - # If this is the first time we've seen this run create a new run resource - # and an associated request sender. 
-            if run_name not in self._run_to_run_resource:
-                self._create_or_get_run_resource(run_name)
-                self._run_to_request_sender[
-                    run_name
-                ] = self._scalar_request_sender_factory(
-                    self._run_to_run_resource[run_name].name
-                )
-                self._run_to_tensor_request_sender[
-                    run_name
-                ] = self._tensor_request_sender_factory(
-                    self._run_to_run_resource[run_name].name
-                )
-                self._run_to_blob_request_sender[
-                    run_name
-                ] = self._blob_request_sender_factory(
-                    self._run_to_run_resource[run_name].name
+                    plugin_name,
                 )
+            return
+        self._tracker.add_plugin_name(plugin_name)
+        # If this is the first time we've seen this run, create a new run
+        # resource and an associated request sender.
+        if run_name not in self._run_to_run_resource:
+            self._create_or_get_run_resource(run_name)
+            self._run_to_request_sender[run_name] = self._scalar_request_sender_factory(
+                self._run_to_run_resource[run_name].name
+            )
+            self._run_to_tensor_request_sender[
+                run_name
+            ] = self._tensor_request_sender_factory(
+                self._run_to_run_resource[run_name].name
+            )
+            self._run_to_blob_request_sender[
+                run_name
+            ] = self._blob_request_sender_factory(
+                self._run_to_run_resource[run_name].name
+            )
 
-            if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
-                self._run_to_request_sender[run_name].add_event(event, value, metadata)
-            elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
-                self._run_to_tensor_request_sender[run_name].add_event(
-                    event, value, metadata
-                )
-            elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
-                self._run_to_blob_request_sender[run_name].add_event(
-                    event, value, metadata
-                )
+        if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
+            self._run_to_request_sender[run_name].add_event(event, value, metadata)
+        elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
+            self._run_to_tensor_request_sender[run_name].add_event(
+                event, value, metadata
+            )
+        elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
+            self._run_to_blob_request_sender[run_name].add_event(event, value, metadata)
 
+    def flush(self):
+        """Flushes any events that have been stored."""
         for scalar_request_sender in self._run_to_request_sender.values():
             scalar_request_sender.flush()
@@ -577,12 +607,52 @@ def _create_or_get_run_resource(self, run_name: str):
 
         self._run_to_run_resource[run_name] = tb_run
 
-    def _run_values(
+
+class _Dispatcher(object):
+    """Dispatches requests to the correct request senders."""
+
+    def __init__(
+        self,
+        request_sender: _BatchedRequestSender,
+        additional_senders: Optional[Dict[str, RequestSender]] = None,
+    ):
+        """Constructs a _Dispatcher object for the TensorBoardUploader.
+
+        Args:
+            request_sender: A `_BatchedRequestSender` for handling events.
+            additional_senders: A dictionary mapping plugin names to
+              additional senders.
+        """
+        self._request_sender = request_sender
+
+        if not additional_senders:
+            additional_senders = {}
+        self._additional_senders = additional_senders
+
+    def _dispatch_additional_senders(
+        self, run_name: str,
+    ):
+        """Dispatches events to any additional senders.
+
+        These senders process non-traditional event files for a specific
+        plugin and use a `send_request` function to process events.
+
+        Args:
+            run_name: Name of the current training run.
+        """
+        for key, sender in self._additional_senders.items():
+            sender.send_request(run_name)
+
+    def dispatch_requests(
         self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]]
-    ) -> Generator[
-        Tuple[str, tf.compat.v1.Event, tf.compat.v1.Summary.Value], None, None
-    ]:
-        """Helper generator to create a single stream of work items.
+    ):
+        """Routes events to the appropriate sender.
+
+        Takes a mapping from run names to event generators. Ordinary events
+        are routed to the `_BatchedRequestSender`, while non-traditional
+        events that need different handling are routed to the senders stored
+        in `_additional_senders`. The `_request_sender` is then flushed after
+        all events are added.
 
         Note that `dataclass_compat` may emit multiple variants of the
         same event, for backwards compatibility. Thus this stream should
@@ -598,20 +668,14 @@ def _run_values(
         Args:
             run_to_events: Mapping from run name to generator of `tf.compat.v1.Event`
               values, as returned by `LogdirLoader.get_run_events`.
-
-        Yields:
-            Tuple of run name, tf.compat.v1.Event, tf.compat.v1.Summary.Value per
-            value.
         """
-        # Note that this join in principle has deletion anomalies: if the input
-        # stream contains runs with no events, or events with no values, we'll
-        # lose that information. This is not a problem: we would need to prune
-        # such data from the request anyway.
         for (run_name, events) in run_to_events.items():
+            self._dispatch_additional_senders(run_name)
             for event in events:
                 _filter_graph_defs(event)
                 for value in event.summary.value:
-                    yield (run_name, event, value)
+                    self._request_sender.send_request(run_name, event, value)
+        self._request_sender.flush()
 
 
 class _TimeSeriesResourceManager(object):
diff --git a/tests/unit/aiplatform/test_uploader.py b/tests/unit/aiplatform/test_uploader.py
index 56385274c3..fe198d8cde 100644
--- a/tests/unit/aiplatform/test_uploader.py
+++ b/tests/unit/aiplatform/test_uploader.py
@@ -213,8 +213,11 @@ def _create_uploader(
     )
 
 
-def _create_request_sender(
-    experiment_resource_name, api=None, allowed_plugins=_USE_DEFAULT
+def _create_dispatcher(
+    experiment_resource_name,
+    api=None,
+    allowed_plugins=_USE_DEFAULT,
+    additional_senders={},
 ):
     if api is _USE_DEFAULT:
         api = _create_mock_client()
@@ -232,7 +235,7 @@ def _create_request_sender(
     tensor_rpc_rate_limiter = util.RateLimiter(0)
     blob_rpc_rate_limiter = util.RateLimiter(0)
 
-    return uploader_lib._BatchedRequestSender(
+    request_sender = uploader_lib._BatchedRequestSender(
         experiment_resource_name=experiment_resource_name,
         api=api,
         allowed_plugins=allowed_plugins,
@@ -245,6 +248,10 @@ def _create_request_sender(
         tracker=upload_tracker.UploadTracker(verbosity=0),
     )
 
+    return uploader_lib._Dispatcher(
+        request_sender=request_sender, additional_senders=additional_senders,
+    )
+
 
 def _create_scalar_request_sender(
     run_resource_id, api=_USE_DEFAULT, max_request_size=_USE_DEFAULT
@@ -914,12 +921,12 @@ def _populate_run_from_events(
         self, n_scalar_events, events, allowed_plugins=_USE_DEFAULT
     ):
         mock_client = _create_mock_client()
-        builder = _create_request_sender(
+        builder = _create_dispatcher(
             experiment_resource_name="123",
             api=mock_client,
             allowed_plugins=allowed_plugins,
         )
-        builder.send_requests({"": _apply_compat(events)})
+        builder.dispatch_requests({"": _apply_compat(events)})
         scalar_requests = mock_client.write_tensorboard_run_data.call_args_list
         if scalar_requests:
            self.assertLen(scalar_requests, 1)
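
For illustration, a minimal sketch of how a future additional sender could
plug into the new `_Dispatcher`. The class name, plugin name, and body below
are hypothetical; the only interfaces assumed from this patch are the
`RequestSender` base class, the `send_request(run_name)` call that the
dispatcher makes once per run, and the `_create_additional_senders` hook,
which returns senders keyed by plugin name.

    class _ProfileRequestSender(RequestSender):
        """Hypothetical sender for plugin data stored outside event files."""

        def __init__(self, experiment_resource_name: str, logdir: str):
            self._experiment_resource_name = experiment_resource_name
            self._logdir = logdir

        def send_request(self, run_name: str):
            # Hypothetical body: search the run's directory for
            # plugin-specific artifacts and upload them. The dispatcher calls
            # this once per run, before the run's regular events are routed.
            ...

The uploader would then register the sender by filling in the
`_create_additional_senders` hook added in this patch (assuming the uploader
keeps a `_logdir` attribute):

    def _create_additional_senders(self) -> Dict[str, RequestSender]:
        return {
            "profile": _ProfileRequestSender(self._experiment.name, self._logdir),
        }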