diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py
index e74d119f4a..d3d9ef7db1 100644
--- a/google/cloud/aiplatform/tensorboard/uploader.py
+++ b/google/cloud/aiplatform/tensorboard/uploader.py
@@ -30,7 +30,6 @@
     Generator,
     Iterable,
     Optional,
-    Tuple,
     ContextManager,
 )
 import uuid
@@ -118,6 +117,15 @@
 logger.setLevel(logging.WARNING)
 
 
+class RequestSender(object):
+    """A base class for additional request sender objects.
+
+    Currently just used for typing.
+    """
+
+    pass
+
+
 class TensorBoardUploader(object):
     """Uploads a TensorBoard logdir to TensorBoard.gcp."""
 
@@ -216,7 +224,7 @@ def __init__(
         self._description = description
         self._verbosity = verbosity
         self._one_shot = one_shot
-        self._request_sender = None
+        self._dispatcher = None
         if logdir_poll_rate_limiter is None:
             self._logdir_poll_rate_limiter = util.RateLimiter(
                 _MIN_LOGDIR_POLL_INTERVAL_SECS
@@ -296,7 +304,7 @@ def create_experiment(self):
 
         experiment = self._create_or_get_experiment()
         self._experiment = experiment
-        self._request_sender = _BatchedRequestSender(
+        request_sender = _BatchedRequestSender(
             self._experiment.name,
             self._api,
             allowed_plugins=self._allowed_plugins,
@@ -309,6 +317,26 @@ def create_experiment(self):
             tracker=self._tracker,
         )
 
+        additional_senders = self._create_additional_senders()
+
+        self._dispatcher = _Dispatcher(
+            request_sender=request_sender, additional_senders=additional_senders,
+        )
+
+    def _create_additional_senders(self) -> Dict[str, RequestSender]:
+        """Create any additional senders for non-traditional event files.
+
+        Some plugins do not write their data into typical event files; those
+        items still need to be searched for and stored so that they can be
+        used by the plugin. If there are any items that cannot be found via
+        the `_BatchedRequestSender`, add a sender for them here.
+
+        Returns:
+            Mapping from plugin name to `RequestSender`.
+        """
+        additional_senders = {}
+        return additional_senders
+
     def get_experiment_resource_name(self):
         return self._experiment.name
 
@@ -320,7 +348,7 @@ def start_uploading(self):
             ExperimentNotFoundError: If the experiment is deleted during the course
             of the upload.
         """
-        if self._request_sender is None:
+        if self._dispatcher is None:
            raise RuntimeError("Must call create_experiment() before start_uploading()")
         while True:
             self._logdir_poll_rate_limiter.tick()
@@ -348,7 +376,7 @@ def _upload_once(self):
                 self._run_name_prefix + k: v for k, v in run_to_events.items()
             }
         with self._tracker.send_tracker():
-            self._request_sender.send_requests(run_to_events)
+            self._dispatcher.dispatch_requests(run_to_events)
 
 
 class ExperimentNotFoundError(RuntimeError):
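To make the new extension point concrete, below is a minimal sketch of what an additional sender and its registration could look like. `ProfileRequestSender`, `ProfilingTensorBoardUploader`, and the `"profile"` plugin name are hypothetical; only the `RequestSender` base class, the `send_request(run_name)` contract, and the `_create_additional_senders` hook come from this change.

```python
# Hypothetical sketch only -- none of these classes are part of this change.
from google.cloud.aiplatform.tensorboard import uploader as uploader_lib


class ProfileRequestSender(uploader_lib.RequestSender):
    """Illustrative sender for plugin data that is not stored in event files."""

    def __init__(self, experiment_resource_name, api):
        self._experiment_resource_name = experiment_resource_name
        self._api = api

    def send_request(self, run_name):
        # Additional senders receive only the run name; they are expected to
        # discover and upload their plugin-specific files themselves.
        print("Would scan run %r for profile data." % run_name)


class ProfilingTensorBoardUploader(uploader_lib.TensorBoardUploader):
    """Illustrative uploader that registers the extra sender."""

    def _create_additional_senders(self):
        # Called from create_experiment(), after self._experiment is set.
        return {"profile": ProfileRequestSender(self._experiment.name, self._api)}
```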
@@ -453,8 +481,11 @@ def __init__(
             tracker=self._tracker,
         )
 
-    def send_requests(
-        self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]]
+    def send_request(
+        self,
+        run_name: str,
+        event: tf.compat.v1.Event,
+        value: tf.compat.v1.Summary.Value,
     ):
         """Accepts a stream of TF events and sends batched write RPCs.
 
@@ -462,78 +493,77 @@
         the type of data (Scalar vs Tensor vs Blob) being sent.
 
         Args:
-            run_to_events: Mapping from run name to generator of `tf.compat.v1.Event`
-              values, as returned by `LogdirLoader.get_run_events`.
+            run_name: Name of the run, as returned by `LogdirLoader.get_run_events`.
+            event: The `tf.compat.v1.Event` for the run.
+            value: A single `tf.compat.v1.Summary.Value` from the event; one
+              event may carry multiple values.
 
         Raises:
             RuntimeError: If no progress can be made because even a single
             point is too large (say, due to a gigabyte-long tag name).
         """
-        for (run_name, event, value) in self._run_values(run_to_events):
-            time_series_key = (run_name, value.tag)
-
-            # The metadata for a time series is memorized on the first event.
-            # If later events arrive with a mismatching plugin_name, they are
-            # ignored with a warning.
-            metadata = self._tag_metadata.get(time_series_key)
-            first_in_time_series = False
-            if metadata is None:
-                first_in_time_series = True
-                metadata = value.metadata
-                self._tag_metadata[time_series_key] = metadata
-
-            plugin_name = metadata.plugin_data.plugin_name
-            if value.HasField("metadata") and (
-                plugin_name != value.metadata.plugin_data.plugin_name
-            ):
-                logger.warning(
-                    "Mismatching plugin names for %s. Expected %s, found %s.",
+        time_series_key = (run_name, value.tag)
+
+        # The metadata for a time series is memorized on the first event.
+        # If later events arrive with a mismatching plugin_name, they are
+        # ignored with a warning.
+        metadata = self._tag_metadata.get(time_series_key)
+        first_in_time_series = False
+        if metadata is None:
+            first_in_time_series = True
+            metadata = value.metadata
+            self._tag_metadata[time_series_key] = metadata
+
+        plugin_name = metadata.plugin_data.plugin_name
+        if value.HasField("metadata") and (
+            plugin_name != value.metadata.plugin_data.plugin_name
+        ):
+            logger.warning(
+                "Mismatching plugin names for %s. Expected %s, found %s.",
+                time_series_key,
+                metadata.plugin_data.plugin_name,
+                value.metadata.plugin_data.plugin_name,
+            )
+            return
+        if plugin_name not in self._allowed_plugins:
+            if first_in_time_series:
+                logger.info(
+                    "Skipping time series %r with unsupported plugin name %r",
                     time_series_key,
-                    metadata.plugin_data.plugin_name,
-                    value.metadata.plugin_data.plugin_name,
-                )
-                continue
-            if plugin_name not in self._allowed_plugins:
-                if first_in_time_series:
-                    logger.info(
-                        "Skipping time series %r with unsupported plugin name %r",
-                        time_series_key,
-                        plugin_name,
-                    )
-                continue
-            self._tracker.add_plugin_name(plugin_name)
-            # If this is the first time we've seen this run create a new run resource
-            # and an associated request sender.
-            if run_name not in self._run_to_run_resource:
-                self._create_or_get_run_resource(run_name)
-                self._run_to_request_sender[
-                    run_name
-                ] = self._scalar_request_sender_factory(
-                    self._run_to_run_resource[run_name].name
-                )
-                self._run_to_tensor_request_sender[
-                    run_name
-                ] = self._tensor_request_sender_factory(
-                    self._run_to_run_resource[run_name].name
-                )
-                self._run_to_blob_request_sender[
-                    run_name
-                ] = self._blob_request_sender_factory(
-                    self._run_to_run_resource[run_name].name
+                    plugin_name,
                 )
+            return
+        self._tracker.add_plugin_name(plugin_name)
+        # If this is the first time we've seen this run, create a new run resource
+        # and an associated request sender.
+        if run_name not in self._run_to_run_resource:
+            self._create_or_get_run_resource(run_name)
+            self._run_to_request_sender[run_name] = self._scalar_request_sender_factory(
+                self._run_to_run_resource[run_name].name
+            )
+            self._run_to_tensor_request_sender[
+                run_name
+            ] = self._tensor_request_sender_factory(
+                self._run_to_run_resource[run_name].name
+            )
+            self._run_to_blob_request_sender[
+                run_name
+            ] = self._blob_request_sender_factory(
+                self._run_to_run_resource[run_name].name
+            )
 
-            if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
-                self._run_to_request_sender[run_name].add_event(event, value, metadata)
-            elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
-                self._run_to_tensor_request_sender[run_name].add_event(
-                    event, value, metadata
-                )
-            elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
-                self._run_to_blob_request_sender[run_name].add_event(
-                    event, value, metadata
-                )
+        if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
+            self._run_to_request_sender[run_name].add_event(event, value, metadata)
+        elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
+            self._run_to_tensor_request_sender[run_name].add_event(
+                event, value, metadata
+            )
+        elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
+            self._run_to_blob_request_sender[run_name].add_event(event, value, metadata)
 
+    def flush(self):
+        """Flushes any events that have been stored."""
         for scalar_request_sender in self._run_to_request_sender.values():
             scalar_request_sender.flush()
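Because `send_request` now takes a single summary value instead of the whole run-to-events mapping, callers must fan events out value by value and call `flush()` explicitly; the `_Dispatcher` added below does exactly this (plus graph-def filtering). A rough sketch of the contract, where `sender` and `run_to_events` are placeholders:

```python
# Sketch of the new calling contract, assuming `sender` is a
# _BatchedRequestSender and `run_to_events` maps run names to generators of
# tf.compat.v1.Event values (as produced by LogdirLoader.get_run_events).
def drive_sender(sender, run_to_events):
    for run_name, events in run_to_events.items():
        for event in events:
            # An event may carry several summary values; each one is handed
            # to the sender individually.
            for value in event.summary.value:
                sender.send_request(run_name, event, value)
    # Batched RPCs are only guaranteed to go out once flush() is called.
    sender.flush()
```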
@@ -577,12 +607,52 @@ def _create_or_get_run_resource(self, run_name: str):
 
         self._run_to_run_resource[run_name] = tb_run
 
-    def _run_values(
+
+class _Dispatcher(object):
+    """Dispatch the requests to the correct request senders."""
+
+    def __init__(
+        self,
+        request_sender: _BatchedRequestSender,
+        additional_senders: Optional[Dict[str, RequestSender]] = None,
+    ):
+        """Construct a _Dispatcher object for the TensorBoardUploader.
+
+        Args:
+            request_sender: A `_BatchedRequestSender` for handling events.
+            additional_senders: A dictionary mapping a plugin name to an
+                additional sender.
+        """
+        self._request_sender = request_sender
+
+        if not additional_senders:
+            additional_senders = {}
+        self._additional_senders = additional_senders
+
+    def _dispatch_additional_senders(
+        self, run_name: str,
+    ):
+        """Dispatch events to any additional senders.
+
+        These senders handle non-traditional event files for a specific plugin
+        and expose a `send_request` method to process them.
+
+        Args:
+            run_name: Name of the current training run.
+        """
+        for sender in self._additional_senders.values():
+            sender.send_request(run_name)
+
+    def dispatch_requests(
         self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]]
-    ) -> Generator[
-        Tuple[str, tf.compat.v1.Event, tf.compat.v1.Summary.Value], None, None
-    ]:
-        """Helper generator to create a single stream of work items.
+    ):
+        """Routes events to the appropriate sender.
+
+        Takes a mapping from run names to event generators. Events that the
+        `_BatchedRequestSender` can handle are routed to it, while each run is
+        also announced to any non-traditional senders stored in
+        `_additional_senders`. The `_request_sender` is flushed after all
+        events have been added.
 
         Note that `dataclass_compat` may emit multiple variants of the same
         event, for backwards compatibility. Thus this stream should
@@ -598,20 +668,14 @@ def _run_values(
         Args:
             run_to_events: Mapping from run name to generator of `tf.compat.v1.Event`
               values, as returned by `LogdirLoader.get_run_events`.
-
-        Yields:
-            Tuple of run name, tf.compat.v1.Event, tf.compat.v1.Summary.Value per
-            value.
         """
-        # Note that this join in principle has deletion anomalies: if the input
-        # stream contains runs with no events, or events with no values, we'll
-        # lose that information. This is not a problem: we would need to prune
-        # such data from the request anyway.
         for (run_name, events) in run_to_events.items():
+            self._dispatch_additional_senders(run_name)
             for event in events:
                 _filter_graph_defs(event)
                 for value in event.summary.value:
-                    yield (run_name, event, value)
+                    self._request_sender.send_request(run_name, event, value)
+        self._request_sender.flush()
 
 
 class _TimeSeriesResourceManager(object):
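For reference, a short usage sketch of the new `_Dispatcher`. The names `batched_sender`, `my_sender`, and `events_for_run` are placeholders; only `_Dispatcher` and its behaviour come from the diff above, and this mirrors what `TensorBoardUploader.create_experiment()` and `_upload_once()` now do internally.

```python
# Hypothetical wiring of the dispatcher.
dispatcher = uploader_lib._Dispatcher(
    request_sender=batched_sender,                # a _BatchedRequestSender
    additional_senders={"my_plugin": my_sender},  # optional RequestSenders
)

# Each additional sender is notified once per run with just the run name,
# while every summary value in the event stream is routed through the
# batched sender and flushed at the end.
dispatcher.dispatch_requests({"train": events_for_run})
```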
diff --git a/tests/unit/aiplatform/test_uploader.py b/tests/unit/aiplatform/test_uploader.py
index 56385274c3..fe198d8cde 100644
--- a/tests/unit/aiplatform/test_uploader.py
+++ b/tests/unit/aiplatform/test_uploader.py
@@ -213,8 +213,11 @@ def _create_uploader(
     )
 
 
-def _create_request_sender(
-    experiment_resource_name, api=None, allowed_plugins=_USE_DEFAULT
+def _create_dispatcher(
+    experiment_resource_name,
+    api=None,
+    allowed_plugins=_USE_DEFAULT,
+    additional_senders=None,
 ):
     if api is _USE_DEFAULT:
         api = _create_mock_client()
@@ -232,7 +235,7 @@ def _create_request_sender(
     tensor_rpc_rate_limiter = util.RateLimiter(0)
     blob_rpc_rate_limiter = util.RateLimiter(0)
 
-    return uploader_lib._BatchedRequestSender(
+    request_sender = uploader_lib._BatchedRequestSender(
         experiment_resource_name=experiment_resource_name,
         api=api,
         allowed_plugins=allowed_plugins,
@@ -245,6 +248,10 @@ def _create_request_sender(
         tracker=upload_tracker.UploadTracker(verbosity=0),
     )
 
+    return uploader_lib._Dispatcher(
+        request_sender=request_sender, additional_senders=additional_senders,
+    )
+
 
 def _create_scalar_request_sender(
     run_resource_id, api=_USE_DEFAULT, max_request_size=_USE_DEFAULT
@@ -914,12 +921,12 @@ def _populate_run_from_events(
         self, n_scalar_events, events, allowed_plugins=_USE_DEFAULT
     ):
         mock_client = _create_mock_client()
-        builder = _create_request_sender(
+        builder = _create_dispatcher(
            experiment_resource_name="123",
            api=mock_client,
            allowed_plugins=allowed_plugins,
         )
-        builder.send_requests({"": _apply_compat(events)})
+        builder.dispatch_requests({"": _apply_compat(events)})
         scalar_requests = mock_client.write_tensorboard_run_data.call_args_list
         if scalar_requests:
             self.assertLen(scalar_requests, 1)
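A hypothetical test, not included in this change, showing how the new `additional_senders` parameter of `_create_dispatcher` could be exercised from one of the existing test classes; it relies only on the `_create_dispatcher` and `_create_mock_client` helpers in this module.

```python
from unittest import mock


def test_dispatch_calls_additional_sender(self):
    mock_sender = mock.MagicMock()
    dispatcher = _create_dispatcher(
        experiment_resource_name="123",
        api=_create_mock_client(),
        additional_senders={"fake_plugin": mock_sender},
    )
    # No ordinary events are needed: additional senders receive only the
    # run name, once per run in the mapping.
    dispatcher.dispatch_requests({"run_a": iter([])})
    mock_sender.send_request.assert_called_once_with("run_a")
```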