diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py index d3d9ef7db1..b422192492 100644 --- a/google/cloud/aiplatform/tensorboard/uploader.py +++ b/google/cloud/aiplatform/tensorboard/uploader.py @@ -16,6 +16,7 @@ # """Uploads a TensorBoard logdir to TensorBoard.gcp.""" import abc +from collections import defaultdict import contextlib import functools import json @@ -93,7 +94,7 @@ _DEFAULT_MIN_SCALAR_REQUEST_INTERVAL = 10 # Default maximum WriteTensorbordRunData request size in bytes. -_DEFAULT_MAX_SCALAR_REQUEST_SIZE = 24 * (2 ** 10) # 24KiB +_DEFAULT_MAX_SCALAR_REQUEST_SIZE = 128 * (2 ** 10) # 128KiB # Default minimum interval between initiating WriteTensorbordRunData RPCs in # milliseconds. @@ -106,7 +107,7 @@ # Default maximum WriteTensorbordRunData request size in bytes. _DEFAULT_MAX_TENSOR_REQUEST_SIZE = 512 * (2 ** 10) # 512KiB -_DEFAULT_MAX_BLOB_REQUEST_SIZE = 24 * (2 ** 10) # 24KiB +_DEFAULT_MAX_BLOB_REQUEST_SIZE = 128 * (2 ** 10) # 128KiB # Default maximum tensor point size in bytes. 
_DEFAULT_MAX_TENSOR_POINT_SIZE = 16 * (2 ** 10) # 16KiB @@ -451,27 +452,28 @@ def __init__( self._tag_metadata = {} self._allowed_plugins = frozenset(allowed_plugins) self._tracker = tracker - self._run_to_request_sender: Dict[str, _ScalarBatchedRequestSender] = {} - self._run_to_tensor_request_sender: Dict[str, _TensorBatchedRequestSender] = {} - self._run_to_blob_request_sender: Dict[str, _BlobRequestSender] = {} - self._run_to_run_resource: Dict[str, tensorboard_run.TensorboardRun] = {} - self._scalar_request_sender_factory = functools.partial( - _ScalarBatchedRequestSender, + self._one_platform_resource_manager = _OnePlatformResourceManager( + self._experiment_resource_name, self._api + ) + self._scalar_request_sender = _ScalarBatchedRequestSender( + experiment_resource_id=experiment_resource_name, api=api, rpc_rate_limiter=rpc_rate_limiter, max_request_size=upload_limits.max_scalar_request_size, tracker=self._tracker, + one_platform_resource_manager=self._one_platform_resource_manager, ) - self._tensor_request_sender_factory = functools.partial( - _TensorBatchedRequestSender, + self._tensor_request_sender = _TensorBatchedRequestSender( + experiment_resource_id=experiment_resource_name, api=api, rpc_rate_limiter=tensor_rpc_rate_limiter, max_request_size=upload_limits.max_tensor_request_size, max_tensor_point_size=upload_limits.max_tensor_point_size, tracker=self._tracker, + one_platform_resource_manager=self._one_platform_resource_manager, ) - self._blob_request_sender_factory = functools.partial( - _BlobRequestSender, + self._blob_request_sender = _BlobRequestSender( + experiment_resource_id=experiment_resource_name, api=api, rpc_rate_limiter=blob_rpc_rate_limiter, max_blob_request_size=upload_limits.max_blob_request_size, @@ -479,6 +481,7 @@ def __init__( blob_storage_bucket=blob_storage_bucket, blob_storage_folder=blob_storage_folder, tracker=self._tracker, + one_platform_resource_manager=self._one_platform_resource_manager, ) def send_request( @@ -535,77 
+538,19 @@ def send_request( ) return self._tracker.add_plugin_name(plugin_name) - # If this is the first time we've seen this run create a new run resource - # and an associated request sender. - if run_name not in self._run_to_run_resource: - self._create_or_get_run_resource(run_name) - self._run_to_request_sender[run_name] = self._scalar_request_sender_factory( - self._run_to_run_resource[run_name].name - ) - self._run_to_tensor_request_sender[ - run_name - ] = self._tensor_request_sender_factory( - self._run_to_run_resource[run_name].name - ) - self._run_to_blob_request_sender[ - run_name - ] = self._blob_request_sender_factory( - self._run_to_run_resource[run_name].name - ) if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR: - self._run_to_request_sender[run_name].add_event(event, value, metadata) + self._scalar_request_sender.add_event(run_name, event, value, metadata) elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR: - self._run_to_tensor_request_sender[run_name].add_event( - event, value, metadata - ) + self._tensor_request_sender.add_event(run_name, event, value, metadata) elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE: - self._run_to_blob_request_sender[run_name].add_event(event, value, metadata) + self._blob_request_sender.add_event(run_name, event, value, metadata) def flush(self): """Flushes any events that have been stored.""" - for scalar_request_sender in self._run_to_request_sender.values(): - scalar_request_sender.flush() - - for tensor_request_sender in self._run_to_tensor_request_sender.values(): - tensor_request_sender.flush() - - for blob_request_sender in self._run_to_blob_request_sender.values(): - blob_request_sender.flush() - - def _create_or_get_run_resource(self, run_name: str): - """Creates a new Run Resource in current Tensorboard Experiment resource. - - Args: - run_name: The display name of this run. 
- """ - tb_run = tensorboard_run.TensorboardRun() - tb_run.display_name = run_name - try: - tb_run = self._api.create_tensorboard_run( - parent=self._experiment_resource_name, - tensorboard_run=tb_run, - tensorboard_run_id=str(uuid.uuid4()), - ) - except exceptions.InvalidArgument as e: - # If the run name already exists then retrieve it - if "already exist" in e.message: - runs_pages = self._api.list_tensorboard_runs( - parent=self._experiment_resource_name - ) - for tb_run in runs_pages: - if tb_run.display_name == run_name: - break - - if tb_run.display_name != run_name: - raise ExistingResourceNotFoundError( - "Run with name %s already exists but is not resource list." - % run_name - ) - else: - raise - - self._run_to_run_resource[run_name] = tb_run + self._scalar_request_sender.flush() + self._tensor_request_sender.flush() + self._blob_request_sender.flush() class _Dispatcher(object): @@ -678,25 +623,98 @@ def dispatch_requests( self._request_sender.flush() -class _TimeSeriesResourceManager(object): - """Helper class managing Time Series resources.""" +class _OnePlatformResourceManager(object): + """Helper class managing One Platform resources.""" - def __init__(self, run_resource_id: str, api: TensorboardServiceClient): - """Constructor for _TimeSeriesResourceManager. + def __init__(self, experiment_resource_name: str, api: TensorboardServiceClient): + """Constructor for _OnePlatformResourceManager. 
Args: - run_resource_id: The resource id for the run with the following format - projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + experiment_resource_name: The resource name of the experiment with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment} api: TensorboardServiceStub """ - self._run_resource_id = run_resource_id + self._experiment_resource_name = experiment_resource_name self._api = api - self._tag_to_time_series_proto: Dict[ - str, tensorboard_time_series.TensorboardTimeSeries - ] = {} + self._run_name_to_run_resource_name: Dict[str, str] = {} + self._run_tag_name_to_time_series_name: Dict[(str, str), str] = {} + + def get_run_resource_name(self, run_name: str): + """ + Get the resource name of the run if it exists, otherwise creates the run + on One Platform before returning its resource name. + :param run_name: name of the run + :return: resource name of the run + """ + if run_name not in self._run_name_to_run_resource_name: + tb_run = self._create_or_get_run_resource(run_name) + self._run_name_to_run_resource_name[run_name] = tb_run.name + return self._run_name_to_run_resource_name[run_name] + + def _create_or_get_run_resource(self, run_name: str): + """Creates a new Run Resource in current Tensorboard Experiment resource. + Args: + run_name: The display name of this run. 
+ """ + tb_run = tensorboard_run.TensorboardRun() + tb_run.display_name = run_name + try: + tb_run = self._api.create_tensorboard_run( + parent=self._experiment_resource_name, + tensorboard_run=tb_run, + tensorboard_run_id=str(uuid.uuid4()), + ) + except exceptions.InvalidArgument as e: + # If the run name already exists then retrieve it + if "already exist" in e.message: + runs_pages = self._api.list_tensorboard_runs( + parent=self._experiment_resource_name + ) + for tb_run in runs_pages: + if tb_run.display_name == run_name: + break - def get_or_create( + if tb_run.display_name != run_name: + raise ExistingResourceNotFoundError( + "Run with name %s already exists but is not resource list." + % run_name + ) + else: + raise + return tb_run + + def get_time_series_resource_name( self, + run_name: str, + tag_name: str, + time_series_resource_creator: Callable[ + [], tensorboard_time_series.TensorboardTimeSeries + ], + ): + """ + Get the resource name of the time series corresponding to the tag, if it + exists, otherwise creates the time series on One Platform before + returning its resource name. + :param run_name: name of the run + :param tag_name: name of the tag + :param time_series_resource_creator: a constructor used for creating the + time series on One Platform. 
+ :return: resource name of the time series + """ + if (run_name, tag_name) not in self._run_tag_name_to_time_series_name: + time_series = self._create_or_get_time_series( + self.get_run_resource_name(run_name), + tag_name, + time_series_resource_creator, + ) + self._run_tag_name_to_time_series_name[ + (run_name, tag_name) + ] = time_series.name + return self._run_tag_name_to_time_series_name[(run_name, tag_name)] + + def _create_or_get_time_series( + self, + run_resource_name: str, tag_name: str, time_series_resource_creator: Callable[ [], tensorboard_time_series.TensorboardTimeSeries @@ -711,21 +729,18 @@ def get_or_create( time_series_resource_creator: A callable that produces a TimeSeries for creation. """ - if tag_name in self._tag_to_time_series_proto: - return self._tag_to_time_series_proto[tag_name] - time_series = time_series_resource_creator() time_series.display_name = tag_name try: time_series = self._api.create_tensorboard_time_series( - parent=self._run_resource_id, tensorboard_time_series=time_series + parent=run_resource_name, tensorboard_time_series=time_series ) except exceptions.InvalidArgument as e: # If the time series display name already exists then retrieve it if "already exist" in e.message: list_of_time_series = self._api.list_tensorboard_time_series( request=tensorboard_service.ListTensorboardTimeSeriesRequest( - parent=self._run_resource_id, + parent=run_resource_name, filter="display_name = {}".format(json.dumps(str(tag_name))), ) ) @@ -742,8 +757,6 @@ def get_or_create( ) else: raise - - self._tag_to_time_series_proto[tag_name] = time_series return time_series @@ -760,47 +773,49 @@ class _BaseBatchedRequestSender(object): def __init__( self, - run_resource_id: str, + experiment_resource_id: str, api: TensorboardServiceClient, rpc_rate_limiter: util.RateLimiter, max_request_size: int, tracker: upload_tracker.UploadTracker, + one_platform_resource_manager: _OnePlatformResourceManager, ): """Constructor for _BaseBatchedRequestSender. 
Args: - run_resource_id: The resource id for the run with the following format - projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + experiment_resource_id: The resource id for the experiment with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment} api: TensorboardServiceStub rpc_rate_limiter: until.RateLimiter to limit rate of this request sender max_request_size: max number of bytes to send tracker: """ - self._run_resource_id = run_resource_id + self._experiment_resource_id = experiment_resource_id self._api = api self._rpc_rate_limiter = rpc_rate_limiter self._byte_budget_manager = _ByteBudgetManager(max_request_size) self._tracker = tracker + self._one_platform_resource_manager = one_platform_resource_manager # cache: map from Tensorboard tag to TimeSeriesData # cleared whenever a new request is created - self._tag_to_time_series_data: Dict[str, tensorboard_data.TimeSeriesData] = {} - - self._time_series_resource_manager = _TimeSeriesResourceManager( - self._run_resource_id, self._api - ) + self._run_to_tag_to_time_series_data: Dict[ + str, Dict[str, tensorboard_data.TimeSeriesData] + ] = defaultdict(defaultdict) self._new_request() def _new_request(self): """Allocates a new request and refreshes the budget.""" - self._request = tensorboard_service.WriteTensorboardRunDataRequest() - self._tag_to_time_series_data.clear() + self._request = tensorboard_service.WriteTensorboardExperimentDataRequest( + tensorboard_experiment=self._experiment_resource_id + ) + self._run_to_tag_to_time_series_data.clear() self._num_values = 0 - self._request.tensorboard_run = self._run_resource_id self._byte_budget_manager.reset(self._request) def add_event( self, + run_name: str, event: tf.compat.v1.Event, value: tf.compat.v1.Summary.Value, metadata: tf.compat.v1.SummaryMetadata, @@ -817,27 +832,32 @@ def add_event( metadata: SummaryMetadata of the event. 
""" try: - self._add_event_internal(event, value, metadata) + self._add_event_internal(run_name, event, value, metadata) except _OutOfSpaceError: self.flush() # Try again. This attempt should never produce OutOfSpaceError # because we just flushed. try: - self._add_event_internal(event, value, metadata) + self._add_event_internal(run_name, event, value, metadata) except _OutOfSpaceError: raise RuntimeError("add_event failed despite flush") def _add_event_internal( self, + run_name: str, event: tf.compat.v1.Event, value: tf.compat.v1.Summary.Value, metadata: tf.compat.v1.SummaryMetadata, ): self._num_values += 1 - time_series_data_proto = self._tag_to_time_series_data.get(value.tag) + time_series_data_proto = self._run_to_tag_to_time_series_data[run_name].get( + value.tag + ) if time_series_data_proto is None: - time_series_data_proto = self._create_time_series_data(value.tag, metadata) - self._create_point(time_series_data_proto, event, value, metadata) + time_series_data_proto = self._create_time_series_data( + run_name, value.tag, metadata + ) + self._create_point(run_name, time_series_data_proto, event, value, metadata) def flush(self): """Sends the active request after removing empty runs and tags. @@ -845,9 +865,24 @@ def flush(self): Starts a new, empty active request. 
""" request = self._request - request.time_series_data = list(self._tag_to_time_series_data.values()) - _prune_empty_time_series(request) - if not request.time_series_data: + has_data = False + for ( + run_name, + tag_to_time_series_data, + ) in self._run_to_tag_to_time_series_data.items(): + r = tensorboard_service.WriteTensorboardRunDataRequest( + tensorboard_run=self._one_platform_resource_manager.get_run_resource_name( + run_name + ) + ) + r.time_series_data = list(tag_to_time_series_data.values()) + _prune_empty_time_series(r) + if not r.time_series_data: + continue + request.write_run_data_requests.extend([r]) + has_data = True + + if not has_data: return self._rpc_rate_limiter.tick() @@ -855,9 +890,9 @@ def flush(self): with _request_logger(request): with self._get_tracker(): try: - self._api.write_tensorboard_run_data( - tensorboard_run=self._run_resource_id, - time_series_data=request.time_series_data, + self._api.write_tensorboard_experiment_data( + tensorboard_experiment=request.tensorboard_experiment, + write_run_data_requests=request.write_run_data_requests, ) except grpc.RpcError as e: if ( @@ -870,7 +905,7 @@ def flush(self): self._new_request() def _create_time_series_data( - self, tag_name: str, metadata: tf.compat.v1.SummaryMetadata + self, run_name: str, tag_name: str, metadata: tf.compat.v1.SummaryMetadata ) -> tensorboard_data.TimeSeriesData: """Adds a time_series for the tag_name, if there's space. @@ -884,25 +919,31 @@ def _create_time_series_data( _OutOfSpaceError: If adding the tag would exceed the remaining request budget. 
""" + time_series_resource_name = self._one_platform_resource_manager.get_time_series_resource_name( + run_name, + tag_name, + lambda: tensorboard_time_series.TensorboardTimeSeries( + display_name=tag_name, + value_type=self._value_type, + plugin_name=metadata.plugin_data.plugin_name, + plugin_data=metadata.plugin_data.content, + ), + ) + time_series_data_proto = tensorboard_data.TimeSeriesData( - tensorboard_time_series_id=self._time_series_resource_manager.get_or_create( - tag_name, - lambda: tensorboard_time_series.TensorboardTimeSeries( - display_name=tag_name, - value_type=self._value_type, - plugin_name=metadata.plugin_data.plugin_name, - plugin_data=metadata.plugin_data.content, - ), - ).name.split("/")[-1], + tensorboard_time_series_id=time_series_resource_name.split("/")[-1], value_type=self._value_type, ) self._byte_budget_manager.add_time_series(time_series_data_proto) - self._tag_to_time_series_data[tag_name] = time_series_data_proto + self._run_to_tag_to_time_series_data[run_name][ + tag_name + ] = time_series_data_proto return time_series_data_proto def _create_point( self, + run_name: str, time_series_proto: tensorboard_data.TimeSeriesData, event: tf.compat.v1.Event, value: tf.compat.v1.Summary.Value, @@ -920,7 +961,7 @@ def _create_point( _OutOfSpaceError: If adding the point would exceed the remaining request budget. 
""" - point = self._create_data_point(event, value, metadata) + point = self._create_data_point(run_name, event, value, metadata) if not self._validate(point, event, value): return @@ -951,6 +992,7 @@ def _value_type(cls,) -> tensorboard_time_series.TensorboardTimeSeries.ValueType @abc.abstractmethod def _create_data_point( self, + run_name: str, event: tf.compat.v1.Event, value: tf.compat.v1.Summary.Value, metadata: tf.compat.v1.SummaryMetadata, @@ -988,24 +1030,30 @@ class _ScalarBatchedRequestSender(_BaseBatchedRequestSender): def __init__( self, - run_resource_id: str, + experiment_resource_id: str, api: TensorboardServiceClient, rpc_rate_limiter: util.RateLimiter, max_request_size: int, tracker: upload_tracker.UploadTracker, + one_platform_resource_manager: _OnePlatformResourceManager, ): """Constructor for _ScalarBatchedRequestSender. Args: - run_resource_id: The resource id for the run with the following format - projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + experiment_resource_id: The resource id for the experiment with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment} api: TensorboardServiceStub rpc_rate_limiter: until.RateLimiter to limit rate of this request sender max_request_size: max number of bytes to send tracker: """ super().__init__( - run_resource_id, api, rpc_rate_limiter, max_request_size, tracker + experiment_resource_id, + api, + rpc_rate_limiter, + max_request_size, + tracker, + one_platform_resource_manager, ) def _get_tracker(self) -> ContextManager: @@ -1013,6 +1061,7 @@ def _get_tracker(self) -> ContextManager: def _create_data_point( self, + run_name: str, event: tf.compat.v1.Event, value: tf.compat.v1.Summary.Value, metadata: tf.compat.v1.SummaryMetadata, @@ -1044,25 +1093,31 @@ class _TensorBatchedRequestSender(_BaseBatchedRequestSender): def __init__( self, - run_resource_id: str, + experiment_resource_id: str, 
api: TensorboardServiceClient, rpc_rate_limiter: util.RateLimiter, max_request_size: int, max_tensor_point_size: int, tracker: upload_tracker.UploadTracker, + one_platform_resource_manager: _OnePlatformResourceManager, ): """Constructor for _TensorBatchedRequestSender. Args: - run_resource_id: The resource id for the run with the following format - projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run} + experiment_resource_id: The resource id for the experiment with the following format + projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment} api: TensorboardServiceStub rpc_rate_limiter: until.RateLimiter to limit rate of this request sender max_request_size: max number of bytes to send tracker: """ super().__init__( - run_resource_id, api, rpc_rate_limiter, max_request_size, tracker + experiment_resource_id, + api, + rpc_rate_limiter, + max_request_size, + tracker, + one_platform_resource_manager, ) self._max_tensor_point_size = max_tensor_point_size @@ -1084,6 +1139,7 @@ def _get_tracker(self) -> ContextManager: def _create_data_point( self, + run_name: str, event: tf.compat.v1.Event, value: tf.compat.v1.Summary.Value, metadata: tf.compat.v1.SummaryMetadata, @@ -1152,7 +1208,9 @@ def __init__(self, max_bytes: int): self._byte_budget = None # type: int self._max_bytes = max_bytes - def reset(self, base_request: tensorboard_service.WriteTensorboardRunDataRequest): + def reset( + self, base_request: tensorboard_service.WriteTensorboardExperimentDataRequest + ): """Resets the byte budget and calculates the cost of the base request. 
Args: @@ -1235,7 +1293,7 @@ class _BlobRequestSender(_BaseBatchedRequestSender): def __init__( self, - run_resource_id: str, + experiment_resource_id: str, api: TensorboardServiceClient, rpc_rate_limiter: util.RateLimiter, max_blob_request_size: int, @@ -1243,9 +1301,15 @@ def __init__( blob_storage_bucket: storage.Bucket, blob_storage_folder: str, tracker: upload_tracker.UploadTracker, + one_platform_resource_manager: _OnePlatformResourceManager, ): super().__init__( - run_resource_id, api, rpc_rate_limiter, max_blob_request_size, tracker + experiment_resource_id, + api, + rpc_rate_limiter, + max_blob_request_size, + tracker, + one_platform_resource_manager, ) self._max_blob_size = max_blob_size self._bucket = blob_storage_bucket @@ -1260,6 +1324,7 @@ def _get_tracker(self) -> ContextManager: def _create_data_point( self, + run_name: str, event: tf.compat.v1.Event, value: tf.compat.v1.Summary.Value, metadata: tf.compat.v1.SummaryMetadata, @@ -1270,25 +1335,25 @@ def _create_data_point( "A blob sequence must be represented as a rank-1 Tensor. 
" "Provided data has rank %d, for run %s, tag %s, step %s ('%s' plugin) .", blobs.ndim, - self._run_resource_id, + run_name, value.tag, event.step, metadata.plugin_data.plugin_name, ) return None - time_series_proto = self._time_series_resource_manager.get_or_create( - value.tag, - lambda: tensorboard_time_series.TensorboardTimeSeries( - display_name=value.tag, - value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE, - plugin_name=metadata.plugin_data.plugin_name, - plugin_data=metadata.plugin_data.content, - ), - ) m = re.match( ".*/tensorboards/(.*)/experiments/(.*)/runs/(.*)/timeSeries/(.*)", - time_series_proto.name, + self._one_platform_resource_manager.get_time_series_resource_name( + run_name, + value.tag, + lambda: tensorboard_time_series.TensorboardTimeSeries( + display_name=value.tag, + value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE, + plugin_name=metadata.plugin_data.plugin_name, + plugin_data=metadata.plugin_data.content, + ), + ), ) blob_path_prefix = "tensorboard-{}/{}/{}/{}".format(m[1], m[2], m[3], m[4]) blob_path_prefix = ( @@ -1343,7 +1408,7 @@ def _send_blob(self, blob, blob_path_prefix): @contextlib.contextmanager -def _request_logger(request: tensorboard_service.WriteTensorboardRunDataRequest): +def _request_logger(request: tensorboard_service.WriteTensorboardExperimentDataRequest): """Context manager to log request size and duration.""" upload_start_time = time.time() request_bytes = request._pb.ByteSize() # pylint: disable=protected-access diff --git a/google/cloud/aiplatform/tensorboard/uploader_main.py b/google/cloud/aiplatform/tensorboard/uploader_main.py index ebd4aa5147..7c868fcfb7 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_main.py +++ b/google/cloud/aiplatform/tensorboard/uploader_main.py @@ -146,10 +146,7 @@ def main(argv): tb_uploader.get_experiment_resource_name().replace("/", "+"), ) ) - if FLAGS.one_shot: - tb_uploader._upload_once() # pylint: 
disable=protected-access - else: - tb_uploader.start_uploading() + tb_uploader.start_uploading() def flags_parser(args): diff --git a/tests/unit/aiplatform/test_uploader.py b/tests/unit/aiplatform/test_uploader.py index fe198d8cde..0267833dce 100644 --- a/tests/unit/aiplatform/test_uploader.py +++ b/tests/unit/aiplatform/test_uploader.py @@ -254,14 +254,17 @@ def _create_dispatcher( def _create_scalar_request_sender( - run_resource_id, api=_USE_DEFAULT, max_request_size=_USE_DEFAULT + experiment_resource_id, api=_USE_DEFAULT, max_request_size=_USE_DEFAULT ): if api is _USE_DEFAULT: api = _create_mock_client() if max_request_size is _USE_DEFAULT: max_request_size = 128000 return uploader_lib._ScalarBatchedRequestSender( - run_resource_id=run_resource_id, + experiment_resource_id=experiment_resource_id, + one_platform_resource_manager=uploader_lib._OnePlatformResourceManager( + experiment_resource_id, api + ), api=api, rpc_rate_limiter=util.RateLimiter(0), max_request_size=max_request_size, @@ -505,14 +508,14 @@ def test_start_uploading_scalars(self): uploader, "_logdir_loader", mock_logdir_loader ), self.assertRaises(AbortUploadError): uploader.start_uploading() - self.assertEqual(10, mock_client.write_tensorboard_run_data.call_count) - self.assertEqual(10, mock_rate_limiter.tick.call_count) + self.assertEqual(5, mock_client.write_tensorboard_experiment_data.call_count) + self.assertEqual(5, mock_rate_limiter.tick.call_count) self.assertEqual(0, mock_tensor_rate_limiter.tick.call_count) self.assertEqual(0, mock_blob_rate_limiter.tick.call_count) # Check upload tracker calls. 
self.assertEqual(mock_tracker.send_tracker.call_count, 2) - self.assertEqual(mock_tracker.scalars_tracker.call_count, 10) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 5) self.assertLen(mock_tracker.scalars_tracker.call_args[0], 1) self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) self.assertEqual(mock_tracker.blob_tracker.call_count, 0) @@ -552,12 +555,12 @@ def test_start_uploading_scalars_one_shot(self): with mock.patch.object(uploader, "_logdir_loader", mock_logdir_loader): uploader.start_uploading() - self.assertEqual(4, mock_client.write_tensorboard_run_data.call_count) - self.assertEqual(4, mock_rate_limiter.tick.call_count) + self.assertEqual(2, mock_client.write_tensorboard_experiment_data.call_count) + self.assertEqual(2, mock_rate_limiter.tick.call_count) # Check upload tracker calls. self.assertEqual(mock_tracker.send_tracker.call_count, 1) - self.assertEqual(mock_tracker.scalars_tracker.call_count, 4) + self.assertEqual(mock_tracker.scalars_tracker.call_count, 2) self.assertLen(mock_tracker.scalars_tracker.call_args[0], 1) self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) self.assertEqual(mock_tracker.blob_tracker.call_count, 0) @@ -568,7 +571,7 @@ def test_upload_empty_logdir(self): uploader = _create_uploader(mock_client, logdir) uploader.create_experiment() uploader._upload_once() - mock_client.write_tensorboard_run_data.assert_not_called() + mock_client.write_tensorboard_experiment_data.assert_not_called() def test_upload_polls_slowly_once_done(self): class SuccessError(Exception): @@ -601,9 +604,9 @@ def test_upload_swallows_rpc_failure(self): uploader = _create_uploader(mock_client, logdir) uploader.create_experiment() error = _grpc_error(grpc.StatusCode.INTERNAL, "Failure") - mock_client.write_tensorboard_run_data.side_effect = error + mock_client.write_tensorboard_experiment_data.side_effect = error uploader._upload_once() - mock_client.write_tensorboard_run_data.assert_called_once() + 
mock_client.write_tensorboard_experiment_data.assert_called_once() def test_upload_full_logdir(self): logdir = self.get_temp_dir() @@ -644,11 +647,11 @@ def test_upload_full_logdir(self): self.assertEqual("scalars", request.plugin_name) self.assertEqual(b"12345", request.plugin_data) - self.assertEqual(2, mock_client.write_tensorboard_run_data.call_count) - call_args_list = mock_client.write_tensorboard_run_data.call_args_list + self.assertEqual(1, mock_client.write_tensorboard_experiment_data.call_count) + call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list request1, request2 = ( - call_args_list[0][1]["time_series_data"], - call_args_list[1][1]["time_series_data"], + call_args_list[0][1]["write_run_data_requests"][0].time_series_data, + call_args_list[0][1]["write_run_data_requests"][1].time_series_data, ) _clear_wall_times(request1) _clear_wall_times(request2) @@ -680,7 +683,7 @@ def test_upload_full_logdir(self): self.assertProtoEquals(expected_request1[1], request1[1]) self.assertProtoEquals(expected_request2[0], request2[0]) - mock_client.write_tensorboard_run_data.reset_mock() + mock_client.write_tensorboard_experiment_data.reset_mock() # Second round writer.add_test_summary("foo", simple_value=10.0, step=5) @@ -690,11 +693,11 @@ def test_upload_full_logdir(self): writer_b.add_test_summary("xyz", simple_value=12.0, step=1) writer_b.flush() uploader._upload_once() - self.assertEqual(2, mock_client.write_tensorboard_run_data.call_count) - call_args_list = mock_client.write_tensorboard_run_data.call_args_list + self.assertEqual(1, mock_client.write_tensorboard_experiment_data.call_count) + call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list request3, request4 = ( - call_args_list[0][1]["time_series_data"], - call_args_list[1][1]["time_series_data"], + call_args_list[0][1]["write_run_data_requests"][0].time_series_data, + call_args_list[0][1]["write_run_data_requests"][1].time_series_data, ) 
_clear_wall_times(request3) _clear_wall_times(request4) @@ -720,11 +723,11 @@ def test_upload_full_logdir(self): self.assertProtoEquals(expected_request3[0], request3[0]) self.assertProtoEquals(expected_request3[1], request3[1]) self.assertProtoEquals(expected_request4[0], request4[0]) - mock_client.write_tensorboard_run_data.reset_mock() + mock_client.write_tensorboard_experiment_data.reset_mock() # Empty third round uploader._upload_once() - mock_client.write_tensorboard_run_data.assert_not_called() + mock_client.write_tensorboard_experiment_data.assert_not_called() def test_verbosity_zero_creates_upload_tracker_with_verbosity_zero(self): mock_client = _create_mock_client() @@ -829,9 +832,9 @@ def create_time_series(tensorboard_time_series, parent=None): actual_graph_def = graph_pb2.GraphDef.FromString(request) self.assertProtoEquals(expected_graph_def, actual_graph_def) - for call in mock_client.write_tensorboard_run_data.call_args_list: + for call in mock_client.write_tensorboard_experiment_data.call_args_list: kargs = call[1] - time_series_data = kargs["time_series_data"] + time_series_data = kargs["write_run_data_requests"][0].time_series_data self.assertEqual(len(time_series_data), 1) self.assertEqual( time_series_data[0].tensorboard_time_series_id, _TEST_TIME_SERIES_NAME @@ -845,7 +848,7 @@ def create_time_series(tensorboard_time_series, parent=None): self.assertEqual(mock_tracker.send_tracker.call_count, 2) self.assertEqual(mock_tracker.scalars_tracker.call_count, 0) self.assertEqual(mock_tracker.tensors_tracker.call_count, 0) - self.assertEqual(mock_tracker.blob_tracker.call_count, 15) + self.assertEqual(mock_tracker.blob_tracker.call_count, 12) def test_filter_graphs(self): # Three graphs: one short, one long, one corrupt. 
@@ -927,10 +930,13 @@ def _populate_run_from_events( allowed_plugins=allowed_plugins, ) builder.dispatch_requests({"": _apply_compat(events)}) - scalar_requests = mock_client.write_tensorboard_run_data.call_args_list + scalar_requests = mock_client.write_tensorboard_experiment_data.call_args_list if scalar_requests: self.assertLen(scalar_requests, 1) - self.assertLen(scalar_requests[0][1]["time_series_data"], n_scalar_events) + self.assertLen( + scalar_requests[0][1]["write_run_data_requests"][0].time_series_data, + n_scalar_events, + ) return scalar_requests def test_empty_events(self): @@ -1017,7 +1023,8 @@ def test_expands_multiple_values_in_event(self): ) self.assertProtoEquals( - time_series_data, call_args_list[0][1]["time_series_data"][0] + time_series_data, + call_args_list[0][1]["write_run_data_requests"][0].time_series_data[0], ) @@ -1025,20 +1032,24 @@ class ScalarBatchedRequestSenderTest(tf.test.TestCase): def _add_events(self, sender, events): for event in events: for value in event.summary.value: - sender.add_event(event, value, value.metadata) + sender.add_event(_TEST_RUN_NAME, event, value, value.metadata) def _add_events_and_flush(self, events, expected_n_time_series): mock_client = _create_mock_client() sender = _create_scalar_request_sender( - run_resource_id=_TEST_RUN_NAME, api=mock_client, + experiment_resource_id=_TEST_EXPERIMENT_NAME, api=mock_client, ) self._add_events(sender, events) sender.flush() - requests = mock_client.write_tensorboard_run_data.call_args_list + requests = mock_client.write_tensorboard_experiment_data.call_args_list self.assertLen(requests, 1) - self.assertLen(requests[0][1]["time_series_data"], expected_n_time_series) - return requests[0] + call_args = requests[0] + self.assertLen( + call_args[1]["write_run_data_requests"][0].time_series_data, + expected_n_time_series, + ) + return call_args def test_aggregation_by_tag(self): def make_event(step, wall_time, tag, value): @@ -1055,7 +1066,7 @@ def make_event(step, 
wall_time, tag, value): make_event(1, 6.0, "three", 66.0), ] call_args = self._add_events_and_flush(events, 3) - ts_data = call_args[1]["time_series_data"] + ts_data = call_args[1]["write_run_data_requests"][0].time_series_data tag_data = { ts.tensorboard_time_series_id: [ ( @@ -1081,9 +1092,9 @@ def test_v1_summary(self): event.summary.value.add(tag="foo", simple_value=5.0) call_args = self._add_events_and_flush(_apply_compat([event]), 1) - expected_call_args = mock.call( - tensorboard_run=_TEST_RUN_NAME, - time_series_data=[ + self.assertEqual(_TEST_EXPERIMENT_NAME, call_args[1]["tensorboard_experiment"]) + self.assertEqual( + [ tensorboard_data.TimeSeriesData( tensorboard_time_series_id="foo", value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, @@ -1096,8 +1107,8 @@ def test_v1_summary(self): ], ) ], + call_args[1]["write_run_data_requests"][0].time_series_data, ) - self.assertEqual(expected_call_args, call_args) def test_v1_summary_tb_summary(self): tf_summary = summary_v1.scalar_pb("foo", 5.0) @@ -1105,9 +1116,9 @@ def test_v1_summary_tb_summary(self): event = event_pb2.Event(step=1, wall_time=123.456, summary=tb_summary) call_args = self._add_events_and_flush(_apply_compat([event]), 1) - expected_call_args = mock.call( - tensorboard_run=_TEST_RUN_NAME, - time_series_data=[ + self.assertEqual(_TEST_EXPERIMENT_NAME, call_args[1]["tensorboard_experiment"]) + self.assertEqual( + [ tensorboard_data.TimeSeriesData( tensorboard_time_series_id="scalar_summary", value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, @@ -1120,8 +1131,8 @@ def test_v1_summary_tb_summary(self): ], ) ], + call_args[1]["write_run_data_requests"][0].time_series_data, ) - self.assertEqual(expected_call_args, call_args) def test_v2_summary(self): event = event_pb2.Event( @@ -1129,9 +1140,9 @@ def test_v2_summary(self): ) call_args = self._add_events_and_flush(_apply_compat([event]), 1) - expected_call_args = mock.call( - 
tensorboard_run=_TEST_RUN_NAME, - time_series_data=[ + self.assertEqual(_TEST_EXPERIMENT_NAME, call_args[1]["tensorboard_experiment"]) + self.assertEqual( + [ tensorboard_data.TimeSeriesData( tensorboard_time_series_id="foo", value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, @@ -1144,10 +1155,9 @@ def test_v2_summary(self): ], ) ], + call_args[1]["write_run_data_requests"][0].time_series_data, ) - self.assertEqual(expected_call_args, call_args) - def test_propagates_experiment_deletion(self): event = event_pb2.Event(step=1) event.summary.value.add(tag="foo", simple_value=1.0) @@ -1157,16 +1167,18 @@ def test_propagates_experiment_deletion(self): self._add_events(sender, _apply_compat([event])) error = _grpc_error(grpc.StatusCode.NOT_FOUND, "nope") - mock_client.write_tensorboard_run_data.side_effect = error + mock_client.write_tensorboard_experiment_data.side_effect = error with self.assertRaises(uploader_lib.ExperimentNotFoundError): sender.flush() def test_no_budget_for_base_request(self): mock_client = _create_mock_client() - long_run_id = "A" * 12 + long_experiment_id = "A" * 12 with self.assertRaises(uploader_lib._OutOfSpaceError) as cm: _create_scalar_request_sender( - run_resource_id=long_run_id, api=mock_client, max_request_size=12, + experiment_resource_id=long_experiment_id, + api=mock_client, + max_request_size=12, ) self.assertEqual(str(cm.exception), "Byte budget too small for base request") @@ -1207,46 +1219,48 @@ def test_break_at_run_boundary(self): self._add_events(sender_2, _apply_compat([event_2])) sender_1.flush() sender_2.flush() - call_args_list = mock_client.write_tensorboard_run_data.call_args_list + call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list for call_args in call_args_list: - _clear_wall_times(call_args[1]["time_series_data"]) + _clear_wall_times( + call_args[1]["write_run_data_requests"][0].time_series_data + ) # Expect two calls despite a single explicit call to flush(). 
expected = [ - mock.call( - tensorboard_run=long_run_1, - time_series_data=[ - tensorboard_data.TimeSeriesData( - tensorboard_time_series_id="foo", - value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, - values=[ - tensorboard_data.TimeSeriesDataPoint( - step=1, scalar=tensorboard_data.Scalar(value=1.0) - ) - ], - ) - ], - ), - mock.call( - tensorboard_run=long_run_2, - time_series_data=[ - tensorboard_data.TimeSeriesData( - tensorboard_time_series_id="bar", - value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, - values=[ - tensorboard_data.TimeSeriesDataPoint( - step=2, scalar=tensorboard_data.Scalar(value=-2.0) - ) - ], - ) - ], - ), + [ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="foo", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=1, scalar=tensorboard_data.Scalar(value=1.0) + ) + ], + ) + ], + [ + tensorboard_data.TimeSeriesData( + tensorboard_time_series_id="bar", + value_type=tensorboard_time_series_type.TensorboardTimeSeries.ValueType.SCALAR, + values=[ + tensorboard_data.TimeSeriesDataPoint( + step=2, scalar=tensorboard_data.Scalar(value=-2.0) + ) + ], + ) + ], ] - self.assertEqual(expected[0], call_args_list[0]) - self.assertEqual(expected[1], call_args_list[1]) + self.assertEqual( + expected[0], + call_args_list[0][1]["write_run_data_requests"][0].time_series_data, + ) + self.assertEqual( + expected[1], + call_args_list[1][1]["write_run_data_requests"][0].time_series_data, + ) def test_break_at_tag_boundary(self): mock_client = _create_mock_client() @@ -1267,9 +1281,9 @@ def test_break_at_tag_boundary(self): ) self._add_events(sender, _apply_compat([event])) sender.flush() - call_args_list = mock_client.write_tensorboard_run_data.call_args_list + call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list - request1 = call_args_list[0][1]["time_series_data"] + request1 
= call_args_list[0][1]["write_run_data_requests"][0].time_series_data _clear_wall_times(request1) # Convenience helpers for constructing expected requests. @@ -1310,10 +1324,12 @@ def test_break_at_scalar_point_boundary(self): ) self._add_events(sender, _apply_compat(events)) sender.flush() - call_args_list = mock_client.write_tensorboard_run_data.call_args_list + call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list for call_args in call_args_list: - _clear_wall_times(call_args[1]["time_series_data"]) + _clear_wall_times( + call_args[1]["write_run_data_requests"][0].time_series_data + ) self.assertGreater(len(call_args_list), 1) self.assertLess(len(call_args_list), point_count) @@ -1325,9 +1341,12 @@ def test_break_at_scalar_point_boundary(self): total_points_in_result = 0 for call_args in call_args_list: - self.assertLen(call_args[1]["time_series_data"], 1) - self.assertEqual(call_args[1]["tensorboard_run"], "train") - time_series_data = call_args[1]["time_series_data"][0] + self.assertLen( + call_args[1]["write_run_data_requests"][0].time_series_data, 1 + ) + time_series_data = call_args[1]["write_run_data_requests"][ + 0 + ].time_series_data[0] self.assertEqual(time_series_data.tensorboard_time_series_id, "loss") for point in time_series_data.values: self.assertEqual(point.step, total_points_in_result) @@ -1359,10 +1378,10 @@ def mock_add_point(byte_budget_manager_self, point): self._add_events(sender, _apply_compat([event_2])) sender.flush() - call_args_list = mock_client.write_tensorboard_run_data.call_args_list + call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list request1, request2 = ( - call_args_list[0][1]["time_series_data"], - call_args_list[1][1]["time_series_data"], + call_args_list[0][1]["write_run_data_requests"][0].time_series_data, + call_args_list[1][1]["write_run_data_requests"][0].time_series_data, ) _clear_wall_times(request1) _clear_wall_times(request2) @@ -1404,13 +1423,19 @@ def 
test_wall_time_precision(self): datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb( _timestamp_pb(1567808404765432119) ), - call_args[1]["time_series_data"][0].values[0].wall_time, + call_args[1]["write_run_data_requests"][0] + .time_series_data[0] + .values[0] + .wall_time, ) self.assertEqual( datetime_helpers.DatetimeWithNanoseconds.from_timestamp_pb( _timestamp_pb(1000000002) ), - call_args[1]["time_series_data"][0].values[1].wall_time, + call_args[1]["write_run_data_requests"][0] + .time_series_data[0] + .values[1] + .wall_time, ) @@ -1451,7 +1476,7 @@ def _extract_tag_counts(call_args_list): return { ts_data.tensorboard_time_series_id: len(ts_data.values) for call_args in call_args_list - for ts_data in call_args[1]["time_series_data"] + for ts_data in call_args[1]["write_run_data_requests"][0].time_series_data }