diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py
index f05c1a88a6..0df67f1a06 100644
--- a/google/cloud/aiplatform/tensorboard/uploader.py
+++ b/google/cloud/aiplatform/tensorboard/uploader.py
@@ -29,6 +29,7 @@
     Iterable,
     Optional,
     ContextManager,
+    Tuple,
 )
 import uuid
@@ -195,6 +196,7 @@ def __init__(
         self._logdir = logdir
         self._allowed_plugins = frozenset(allowed_plugins)
         self._run_name_prefix = run_name_prefix
+        self._is_brand_new_experiment = False
         self._upload_limits = upload_limits
         if not self._upload_limits:
@@ -265,6 +267,9 @@ def active_filter(secs):
         self._logdir_loader = logdir_loader.LogdirLoader(
             self._logdir, directory_loader_factory
         )
+        self._logdir_loader_pre_create = logdir_loader.LogdirLoader(
+            self._logdir, directory_loader_factory
+        )
         self._tracker = upload_tracker.UploadTracker(verbosity=self._verbosity)
         self._create_additional_senders()
@@ -290,6 +295,7 @@ def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperim
                 tensorboard_experiment=tb_experiment,
                 tensorboard_experiment_id=self._experiment_name,
             )
+            self._is_brand_new_experiment = True
         except exceptions.AlreadyExists:
             logger.info("Creating experiment failed. Retrieving experiment.")
             experiment_name = os.path.join(
@@ -303,7 +309,11 @@ def create_experiment(self):
         experiment = self._create_or_get_experiment()
         self._experiment = experiment
 
-        request_sender = _BatchedRequestSender(
+        self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
+            self._experiment.name, self._api
+        )
+
+        self._request_sender = _BatchedRequestSender(
             self._experiment.name,
             self._api,
             allowed_plugins=self._allowed_plugins,
@@ -313,6 +323,7 @@ def create_experiment(self):
             blob_rpc_rate_limiter=self._blob_rpc_rate_limiter,
             blob_storage_bucket=self._blob_storage_bucket,
             blob_storage_folder=self._blob_storage_folder,
+            one_platform_resource_manager=self._one_platform_resource_manager,
             tracker=self._tracker,
         )
@@ -323,7 +334,8 @@ def create_experiment(self):
         )
 
         self._dispatcher = _Dispatcher(
-            request_sender=request_sender, additional_senders=self._additional_senders,
+            request_sender=self._request_sender,
+            additional_senders=self._additional_senders,
         )
 
     def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
@@ -366,6 +378,17 @@ def start_uploading(self):
         """
         if self._dispatcher is None:
             raise RuntimeError("Must call create_experiment() before start_uploading()")
+
+        if self._one_shot:
+            if self._is_brand_new_experiment:
+                self._pre_create_runs_and_time_series()
+            else:
+                logger.warning(
+                    "Please consider uploading to a new experiment instead of "
+                    "an existing one, as the former allows for better upload "
+                    "performance."
+                )
+
         while True:
             self._logdir_poll_rate_limiter.tick()
             self._upload_once()
@@ -377,6 +400,61 @@ def start_uploading(self):
                 "without any uploadable data" % self._logdir
             )
 
+    def _pre_create_runs_and_time_series(self):
+        """
+        Iterates through the log dir to collect TensorboardRuns and
+        TensorboardTimeSeries that need to be created, and creates them in batch
+        to speed up uploading later on.
+ """ + self._logdir_loader_pre_create.synchronize_runs() + run_to_events = self._logdir_loader_pre_create.get_run_events() + if self._run_name_prefix: + run_to_events = { + self._run_name_prefix + k: v for k, v in run_to_events.items() + } + + run_names = [] + run_tag_name_to_time_series_proto = {} + for (run_name, events) in run_to_events.items(): + run_names.append(run_name) + for event in events: + _filter_graph_defs(event) + for value in event.summary.value: + metadata, is_valid = self._request_sender.get_metadata_and_validate( + run_name, value + ) + if not is_valid: + continue + if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR: + value_type = ( + tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR + ) + elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR: + value_type = ( + tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR + ) + elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE: + value_type = ( + tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE + ) + + run_tag_name_to_time_series_proto[ + (run_name, value.tag) + ] = tensorboard_time_series.TensorboardTimeSeries( + display_name=value.tag, + value_type=value_type, + plugin_name=metadata.plugin_data.plugin_name, + plugin_data=metadata.plugin_data.content, + ) + + self._one_platform_resource_manager.batch_create_runs(run_names) + self._one_platform_resource_manager.batch_create_time_series( + run_tag_name_to_time_series_proto + ) + def _upload_once(self): """Runs one upload cycle, sending zero or more RPCs.""" logger.info("Starting an upload cycle") @@ -439,6 +514,7 @@ def __init__( blob_rpc_rate_limiter: util.RateLimiter, blob_storage_bucket: storage.Bucket, blob_storage_folder: str, + one_platform_resource_manager: uploader_utils.OnePlatformResourceManager, tracker: upload_tracker.UploadTracker, ): """Constructs _BatchedRequestSender for the given experiment resource. @@ -456,6 +532,8 @@ def __init__( Note the chunk stream is internally rate-limited by backpressure from the server, so it is not a concern that we do not explicitly rate-limit within the stream here. + one_platform_resource_manager: An instance of the One Platform + resource management class. tracker: Upload tracker to track information about uploads. """ self._experiment_resource_name = experiment_resource_name @@ -463,9 +541,7 @@ def __init__( self._tag_metadata = {} self._allowed_plugins = frozenset(allowed_plugins) self._tracker = tracker - self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager( - self._experiment_resource_name, self._api - ) + self._one_platform_resource_manager = one_platform_resource_manager self._scalar_request_sender = _ScalarBatchedRequestSender( experiment_resource_id=experiment_resource_name, api=api, @@ -516,6 +592,37 @@ def send_request( RuntimeError: If no progress can be made because even a single point is too large (say, due to a gigabyte-long tag name). 
""" + metadata, is_valid = self.get_metadata_and_validate(run_name, value) + if not is_valid: + return + plugin_name = metadata.plugin_data.plugin_name + self._tracker.add_plugin_name(plugin_name) + + if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR: + self._scalar_request_sender.add_event(run_name, event, value, metadata) + elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR: + self._tensor_request_sender.add_event(run_name, event, value, metadata) + elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE: + self._blob_request_sender.add_event(run_name, event, value, metadata) + + def flush(self): + """Flushes any events that have been stored.""" + self._scalar_request_sender.flush() + self._tensor_request_sender.flush() + self._blob_request_sender.flush() + + def get_metadata_and_validate( + self, run_name: str, value: tf.compat.v1.Summary.Value + ) -> Tuple[tf.compat.v1.SummaryMetadata, bool]: + """ + + :param run_name: Name of the run retrieved by + `LogdirLoader.get_run_events` + :param value: A single `tf.compat.v1.Summary.Value` from the event, + where there can be multiple values per event. + :return: (metadata, is_valid): a metadata derived from the value, and + whether the value itself is valid. + """ time_series_key = (run_name, value.tag) @@ -539,7 +646,7 @@ def send_request( metadata.plugin_data.plugin_name, value.metadata.plugin_data.plugin_name, ) - return + return metadata, False if plugin_name not in self._allowed_plugins: if first_in_time_series: logger.info( @@ -547,21 +654,8 @@ def send_request( time_series_key, plugin_name, ) - return - self._tracker.add_plugin_name(plugin_name) - - if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR: - self._scalar_request_sender.add_event(run_name, event, value, metadata) - elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR: - self._tensor_request_sender.add_event(run_name, event, value, metadata) - elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE: - self._blob_request_sender.add_event(run_name, event, value, metadata) - - def flush(self): - """Flushes any events that have been stored.""" - self._scalar_request_sender.flush() - self._tensor_request_sender.flush() - self._blob_request_sender.flush() + return metadata, False + return metadata, True class _Dispatcher(object): diff --git a/google/cloud/aiplatform/tensorboard/uploader_utils.py b/google/cloud/aiplatform/tensorboard/uploader_utils.py index 55f9c03156..679eb02ef4 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_utils.py +++ b/google/cloud/aiplatform/tensorboard/uploader_utils.py @@ -22,7 +22,7 @@ import logging import re import time -from typing import Callable, Dict, Generator, Optional +from typing import Callable, Dict, Generator, Optional, List, Tuple import uuid from tensorboard.util import tb_logging @@ -39,7 +39,6 @@ tensorboard_time_series_v1beta1 as tensorboard_time_series, ) from google.cloud.aiplatform.compat.services import tensorboard_service_client_v1beta1 -from google.cloud.aiplatform_v1beta1.types import TensorboardRun TensorboardServiceClient = tensorboard_service_client_v1beta1.TensorboardServiceClient @@ -66,6 +65,9 @@ def send_requests(run_name: str): class OnePlatformResourceManager(object): """Helper class managing One Platform resources.""" + CREATE_RUN_BATCH_SIZE = 1000 + CREATE_TIME_SERIES_BATCH_SIZE = 1000 + def __init__(self, experiment_resource_name: str, api: TensorboardServiceClient): """Constructor for OnePlatformResourceManager. 
@@ -81,6 +83,96 @@ def __init__(self, experiment_resource_name: str, api: TensorboardServiceClient)
         self._run_name_to_run_resource_name: Dict[str, str] = {}
         self._run_tag_name_to_time_series_name: Dict[(str, str), str] = {}
 
+    def batch_create_runs(
+        self, run_names: List[str]
+    ) -> List[tensorboard_run.TensorboardRun]:
+        """Batch creates TensorboardRuns.
+
+        Args:
+            run_names: a list of run_names for creating the TensorboardRuns.
+        Returns:
+            the created TensorboardRuns
+        """
+        batch_size = OnePlatformResourceManager.CREATE_RUN_BATCH_SIZE
+        created_runs = []
+        for i in range(0, len(run_names), batch_size):
+            one_batch_run_names = run_names[i : i + batch_size]
+            tb_run_requests = [
+                tensorboard_service.CreateTensorboardRunRequest(
+                    parent=self._experiment_resource_name,
+                    tensorboard_run=tensorboard_run.TensorboardRun(
+                        display_name=run_name
+                    ),
+                    tensorboard_run_id=str(uuid.uuid4()),
+                )
+                for run_name in one_batch_run_names
+            ]
+
+            tb_runs = self._api.batch_create_tensorboard_runs(
+                parent=self._experiment_resource_name, requests=tb_run_requests,
+            ).tensorboard_runs
+
+            self._run_name_to_run_resource_name.update(
+                {run.display_name: run.name for run in tb_runs}
+            )
+
+            created_runs.extend(tb_runs)
+
+        return created_runs
+
+    def batch_create_time_series(
+        self,
+        run_tag_name_to_time_series: Dict[
+            Tuple[str, str], tensorboard_time_series.TensorboardTimeSeries
+        ],
+    ) -> List[tensorboard_time_series.TensorboardTimeSeries]:
+        """Batch creates TensorboardTimeSeries.
+
+        Args:
+            run_tag_name_to_time_series: a dictionary of
+                (run_name, tag_name) to TensorboardTimeSeries proto, containing
+                the TensorboardTimeSeries to create.
+        Returns:
+            the created TensorboardTimeSeries
+        """
+        batch_size = OnePlatformResourceManager.CREATE_TIME_SERIES_BATCH_SIZE
+        run_tag_name_to_time_series_entries = list(run_tag_name_to_time_series.items())
+        run_resource_name_to_run_name = {
+            v: k for k, v in self._run_name_to_run_resource_name.items()
+        }
+        created_time_series = []
+        for i in range(0, len(run_tag_name_to_time_series_entries), batch_size):
+            requests = [
+                tensorboard_service.CreateTensorboardTimeSeriesRequest(
+                    parent=self._run_name_to_run_resource_name[run_name],
+                    tensorboard_time_series=time_series,
+                )
+                for (
+                    (run_name, tag_name),
+                    time_series,
+                ) in run_tag_name_to_time_series_entries[i : i + batch_size]
+            ]
+
+            time_series = self._api.batch_create_tensorboard_time_series(
+                parent=self._experiment_resource_name, requests=requests,
+            ).tensorboard_time_series
+
+            self._run_tag_name_to_time_series_name.update(
+                {
+                    (
+                        run_resource_name_to_run_name[
+                            ts.name[: ts.name.index("/timeSeries")]
+                        ],
+                        ts.display_name,
+                    ): ts.name
+                    for ts in time_series
+                }
+            )
+
+            created_time_series.extend(time_series)
+
+        return created_time_series
+
     def get_run_resource_name(self, run_name: str) -> str:
         """
         Get the resource name of the run if it exists, otherwise creates the run
@@ -99,7 +191,9 @@ def get_run_resource_name(self, run_name: str) -> str:
             self._run_name_to_run_resource_name[run_name] = tb_run.name
         return self._run_name_to_run_resource_name[run_name]
 
-    def _create_or_get_run_resource(self, run_name: str) -> TensorboardRun:
+    def _create_or_get_run_resource(
+        self, run_name: str
+    ) -> tensorboard_run.TensorboardRun:
         """Creates a new run resource in current tensorboard experiment resource.
 
         Args:
diff --git a/tests/unit/aiplatform/test_uploader.py b/tests/unit/aiplatform/test_uploader.py
index cdd3ba2c51..fd071d0e2b 100644
--- a/tests/unit/aiplatform/test_uploader.py
+++ b/tests/unit/aiplatform/test_uploader.py
@@ -52,6 +52,7 @@
 )
 from google.cloud.aiplatform.compat.types import (
     tensorboard_data_v1beta1 as tensorboard_data,
+    tensorboard_service_v1beta1 as tensorboard_service,
 )
 from google.cloud.aiplatform.compat.types import (
     tensorboard_experiment_v1beta1 as tensorboard_experiment_type,
 )
@@ -260,6 +261,10 @@ def _create_dispatcher(
     tensor_rpc_rate_limiter = util.RateLimiter(0)
     blob_rpc_rate_limiter = util.RateLimiter(0)
 
+    one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
+        experiment_resource_name, api
+    )
+
     request_sender = uploader_lib._BatchedRequestSender(
         experiment_resource_name=experiment_resource_name,
         api=api,
@@ -270,6 +275,7 @@ def _create_dispatcher(
         blob_rpc_rate_limiter=blob_rpc_rate_limiter,
         blob_storage_bucket=None,
         blob_storage_folder=None,
+        one_platform_resource_manager=one_platform_resource_manager,
         tracker=upload_tracker.UploadTracker(verbosity=0),
     )
@@ -593,7 +599,41 @@ def test_start_uploading_scalars(self):
     def test_start_uploading_scalars_one_shot(self):
         """Check that one-shot uploading stops without AbortUploadError."""
+
+        def batch_create_runs(parent, requests):
+            # pylint: disable=unused-argument
+            tb_runs = []
+            for request in requests:
+                tb_run = tensorboard_run_type.TensorboardRun(request.tensorboard_run)
+                tb_run.name = "{}/runs/{}".format(
+                    request.parent, request.tensorboard_run_id
+                )
+                tb_runs.append(tb_run)
+            return tensorboard_service.BatchCreateTensorboardRunsResponse(
+                tensorboard_runs=tb_runs
+            )
+
+        def batch_create_time_series(parent, requests):
+            # pylint: disable=unused-argument
+            tb_time_series = []
+            for request in requests:
+                ts = tensorboard_time_series_type.TensorboardTimeSeries(
+                    request.tensorboard_time_series
+                )
+                ts.name = "{}/timeSeries/{}".format(
+                    request.parent, request.tensorboard_time_series.display_name
+                )
+                tb_time_series.append(ts)
+            return tensorboard_service.BatchCreateTensorboardTimeSeriesResponse(
+                tensorboard_time_series=tb_time_series
+            )
+
         mock_client = _create_mock_client()
+        mock_client.batch_create_tensorboard_runs.side_effect = batch_create_runs
+        mock_client.batch_create_tensorboard_time_series.side_effect = (
+            batch_create_time_series
+        )
+
         mock_rate_limiter = mock.create_autospec(util.RateLimiter)
         mock_tracker = mock.MagicMock()
         with mock.patch.object(
@@ -614,17 +654,32 @@ def test_start_uploading_scalars_one_shot(self):
         mock_logdir_loader.get_run_events.side_effect = [
             {
                 "run 1": _apply_compat(
-                    [_scalar_event("1.1", 5.0), _scalar_event("1.2", 5.0)]
+                    [_scalar_event("tag_1.1", 5.0), _scalar_event("tag_1.2", 5.0)]
                 ),
                 "run 2": _apply_compat(
-                    [_scalar_event("2.1", 5.0), _scalar_event("2.2", 5.0)]
+                    [_scalar_event("tag_2.1", 5.0), _scalar_event("tag_2.2", 5.0)]
                 ),
             },
             # Note the lack of AbortUploadError here.
+        ]
+        mock_logdir_loader_pre_create = mock.create_autospec(logdir_loader.LogdirLoader)
+        mock_logdir_loader_pre_create.get_run_events.side_effect = [
+            {
+                "run 1": _apply_compat(
+                    [_scalar_event("tag_1.1", 5.0), _scalar_event("tag_1.2", 5.0)]
+                ),
+                "run 2": _apply_compat(
+                    [_scalar_event("tag_2.1", 5.0), _scalar_event("tag_2.2", 5.0)]
+                ),
+            },
+            # Note the lack of AbortUploadError here.
         ]
 
         with mock.patch.object(uploader, "_logdir_loader", mock_logdir_loader):
-            uploader.start_uploading()
+            with mock.patch.object(
+                uploader, "_logdir_loader_pre_create", mock_logdir_loader_pre_create
+            ):
+                uploader.start_uploading()
 
         self.assertEqual(2, mock_client.write_tensorboard_experiment_data.call_count)
         self.assertEqual(2, mock_rate_limiter.tick.call_count)
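
Note on the batching pattern: both batch_create_runs and batch_create_time_series chunk their inputs and issue one BatchCreate RPC per chunk of at most 1000 items (CREATE_RUN_BATCH_SIZE / CREATE_TIME_SERIES_BATCH_SIZE), which is where the one-shot upload speedup comes from. Below is a minimal, self-contained sketch of the same chunking idea; `chunked` and `BATCH_SIZE` are illustrative names, not part of this change.

from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")

BATCH_SIZE = 1000  # mirrors CREATE_RUN_BATCH_SIZE / CREATE_TIME_SERIES_BATCH_SIZE

def chunked(items: Sequence[T], batch_size: int = BATCH_SIZE) -> Iterator[Sequence[T]]:
    # Yield successive slices of at most `batch_size` items, the same
    # range(0, len(...), batch_size) loop used in OnePlatformResourceManager.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

# E.g. 2500 run names would become three batch_create_tensorboard_runs calls:
assert [len(c) for c in chunked(list(range(2500)))] == [1000, 1000, 500]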