feat: Update tensorboard uploader to use Dispatcher for handling different event types (#651)

Refactors the tensorboard uploader so that it can later use additional senders that do not process typical event files. This is initial work ahead of adding the tensorboard profiler event sender, and it should have no impact on current functionality.

Fixes #519
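
For reviewers, a minimal sketch of the extension point this commit introduces. Everything in it is illustrative: `ProfileRequestSender` and its behavior are hypothetical (in this commit `_create_additional_senders` still returns an empty mapping, and the profiler sender is only planned), but the shape matches the new `RequestSender` base class and the `Dict[str, RequestSender]` registry that the dispatcher consumes:

from typing import Dict


class RequestSender(object):
    """Base class for additional request senders; currently used only for typing."""


class ProfileRequestSender(RequestSender):
    """Hypothetical future sender for plugin data not stored in typical event files."""

    def send_request(self, run_name: str) -> None:
        # Additional senders receive only the run name; each sender is
        # responsible for finding and uploading its own plugin's data.
        print(f"scanning run {run_name!r} for profile data")


def _create_additional_senders() -> Dict[str, RequestSender]:
    # The real method returns {} in this commit; a follow-up would register
    # senders keyed by plugin name, for example:
    return {"profile": ProfileRequestSender()}

During upload, the `_Dispatcher` calls `send_request(run_name)` on every registered additional sender once per run, before handing that run's events to the `_BatchedRequestSender`.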
mkovalski committed Sep 2, 2021
1 parent 4f0c18e commit d20b520
Showing 2 changed files with 159 additions and 88 deletions.
230 changes: 147 additions & 83 deletions google/cloud/aiplatform/tensorboard/uploader.py
@@ -30,7 +30,6 @@
Generator,
Iterable,
Optional,
Tuple,
ContextManager,
)
import uuid
@@ -118,6 +117,15 @@
logger.setLevel(logging.WARNING)


class RequestSender(object):
"""A base class for additional request sender objects.
Currently just used for typing.
"""

pass


class TensorBoardUploader(object):
"""Uploads a TensorBoard logdir to TensorBoard.gcp."""

@@ -216,7 +224,7 @@ def __init__(
self._description = description
self._verbosity = verbosity
self._one_shot = one_shot
self._request_sender = None
self._dispatcher = None
if logdir_poll_rate_limiter is None:
self._logdir_poll_rate_limiter = util.RateLimiter(
_MIN_LOGDIR_POLL_INTERVAL_SECS
@@ -296,7 +304,7 @@ def create_experiment(self):

experiment = self._create_or_get_experiment()
self._experiment = experiment
self._request_sender = _BatchedRequestSender(
request_sender = _BatchedRequestSender(
self._experiment.name,
self._api,
allowed_plugins=self._allowed_plugins,
@@ -309,6 +317,26 @@
tracker=self._tracker,
)

additional_senders = self._create_additional_senders()

self._dispatcher = _Dispatcher(
request_sender=request_sender, additional_senders=additional_senders,
)

def _create_additional_senders(self) -> Dict[str, RequestSender]:
"""Create any additional senders for non traditional event files.
Some items that are used for plugins do not process typical event files,
but need to be searched for and stored so that they can be used by the
plugin. If there are any items that cannot be searched for via the
`_BatchedRequestSender`, add them here.
Returns:
Mapping from plugin name to Sender.
"""
additional_senders = {}
return additional_senders

def get_experiment_resource_name(self):
return self._experiment.name

@@ -320,7 +348,7 @@ def start_uploading(self):
ExperimentNotFoundError: If the experiment is deleted during the
course of the upload.
"""
if self._request_sender is None:
if self._dispatcher is None:
raise RuntimeError("Must call create_experiment() before start_uploading()")
while True:
self._logdir_poll_rate_limiter.tick()
@@ -348,7 +376,7 @@ def _upload_once(self):
self._run_name_prefix + k: v for k, v in run_to_events.items()
}
with self._tracker.send_tracker():
self._request_sender.send_requests(run_to_events)
self._dispatcher.dispatch_requests(run_to_events)


class ExperimentNotFoundError(RuntimeError):
@@ -453,87 +481,89 @@ def __init__(
tracker=self._tracker,
)

def send_requests(
self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]]
def send_request(
self,
run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
):
"""Accepts a stream of TF events and sends batched write RPCs.
Each sent request will be batched, the size of each batch depending on
the type of data (Scalar vs Tensor vs Blob) being sent.
Args:
run_to_events: Mapping from run name to generator of `tf.compat.v1.Event`
values, as returned by `LogdirLoader.get_run_events`.
run_name: Name of the run retrieved by `LogdirLoader.get_run_events`
event: The `tf.compat.v1.Event` for the run
value: A single `tf.compat.v1.Summary.Value` from the event, where
there can be multiple values per event.
Raises:
RuntimeError: If no progress can be made because even a single
point is too large (say, due to a gigabyte-long tag name).
"""

for (run_name, event, value) in self._run_values(run_to_events):
time_series_key = (run_name, value.tag)

# The metadata for a time series is memorized on the first event.
# If later events arrive with a mismatching plugin_name, they are
# ignored with a warning.
metadata = self._tag_metadata.get(time_series_key)
first_in_time_series = False
if metadata is None:
first_in_time_series = True
metadata = value.metadata
self._tag_metadata[time_series_key] = metadata

plugin_name = metadata.plugin_data.plugin_name
if value.HasField("metadata") and (
plugin_name != value.metadata.plugin_data.plugin_name
):
logger.warning(
"Mismatching plugin names for %s. Expected %s, found %s.",
time_series_key = (run_name, value.tag)

# The metadata for a time series is memorized on the first event.
# If later events arrive with a mismatching plugin_name, they are
# ignored with a warning.
metadata = self._tag_metadata.get(time_series_key)
first_in_time_series = False
if metadata is None:
first_in_time_series = True
metadata = value.metadata
self._tag_metadata[time_series_key] = metadata

plugin_name = metadata.plugin_data.plugin_name
if value.HasField("metadata") and (
plugin_name != value.metadata.plugin_data.plugin_name
):
logger.warning(
"Mismatching plugin names for %s. Expected %s, found %s.",
time_series_key,
metadata.plugin_data.plugin_name,
value.metadata.plugin_data.plugin_name,
)
return
if plugin_name not in self._allowed_plugins:
if first_in_time_series:
logger.info(
"Skipping time series %r with unsupported plugin name %r",
time_series_key,
metadata.plugin_data.plugin_name,
value.metadata.plugin_data.plugin_name,
)
continue
if plugin_name not in self._allowed_plugins:
if first_in_time_series:
logger.info(
"Skipping time series %r with unsupported plugin name %r",
time_series_key,
plugin_name,
)
continue
self._tracker.add_plugin_name(plugin_name)
# If this is the first time we've seen this run create a new run resource
# and an associated request sender.
if run_name not in self._run_to_run_resource:
self._create_or_get_run_resource(run_name)
self._run_to_request_sender[
run_name
] = self._scalar_request_sender_factory(
self._run_to_run_resource[run_name].name
)
self._run_to_tensor_request_sender[
run_name
] = self._tensor_request_sender_factory(
self._run_to_run_resource[run_name].name
)
self._run_to_blob_request_sender[
run_name
] = self._blob_request_sender_factory(
self._run_to_run_resource[run_name].name
plugin_name,
)
return
self._tracker.add_plugin_name(plugin_name)
# If this is the first time we've seen this run create a new run resource
# and an associated request sender.
if run_name not in self._run_to_run_resource:
self._create_or_get_run_resource(run_name)
self._run_to_request_sender[run_name] = self._scalar_request_sender_factory(
self._run_to_run_resource[run_name].name
)
self._run_to_tensor_request_sender[
run_name
] = self._tensor_request_sender_factory(
self._run_to_run_resource[run_name].name
)
self._run_to_blob_request_sender[
run_name
] = self._blob_request_sender_factory(
self._run_to_run_resource[run_name].name
)

if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
self._run_to_request_sender[run_name].add_event(event, value, metadata)
elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
self._run_to_tensor_request_sender[run_name].add_event(
event, value, metadata
)
elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
self._run_to_blob_request_sender[run_name].add_event(
event, value, metadata
)
if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
self._run_to_request_sender[run_name].add_event(event, value, metadata)
elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
self._run_to_tensor_request_sender[run_name].add_event(
event, value, metadata
)
elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
self._run_to_blob_request_sender[run_name].add_event(event, value, metadata)

def flush(self):
"""Flushes any events that have been stored."""
for scalar_request_sender in self._run_to_request_sender.values():
scalar_request_sender.flush()

@@ -577,12 +607,52 @@ def _create_or_get_run_resource(self, run_name: str):

self._run_to_run_resource[run_name] = tb_run

def _run_values(

class _Dispatcher(object):
"""Dispatch the requests to the correct request senders."""

def __init__(
self,
request_sender: _BatchedRequestSender,
additional_senders: Optional[Dict[str, RequestSender]] = None,
):
"""Construct a _Dispatcher object for the TensorboardUploader.
Args:
request_sender: A `_BatchedRequestSender` for handling events.
additional_senders: A dictionary mapping plugin names to additional
senders.
"""
self._request_sender = request_sender

if not additional_senders:
additional_senders = {}
self._additional_senders = additional_senders

def _dispatch_additional_senders(
self, run_name: str,
):
"""Dispatch events to any additional senders.
These senders process non-traditional event files for a specific plugin
and use a `send_request` function to process events.
Args:
run_name: Name of the current training run.
"""
for sender in self._additional_senders.values():
sender.send_request(run_name)

def dispatch_requests(
self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]]
) -> Generator[
Tuple[str, tf.compat.v1.Event, tf.compat.v1.Summary.Value], None, None
]:
"""Helper generator to create a single stream of work items.
):
"""Routes events to the appropriate sender.
Takes a mapping from run names to event generators. Events that the
`_BatchedRequestSender` can handle are routed to it; non-traditional
events are routed to the senders stored in `_additional_senders`. The
`_request_sender` is flushed after all events are added.
Note that `dataclass_compat` may emit multiple variants of
the same event, for backwards compatibility. Thus this stream should
@@ -598,20 +668,14 @@ def _run_values(
Args:
run_to_events: Mapping from run name to generator of `tf.compat.v1.Event`
values, as returned by `LogdirLoader.get_run_events`.
Yields:
Tuple of run name, tf.compat.v1.Event, tf.compat.v1.Summary.Value per
value.
"""
# Note that this join in principle has deletion anomalies: if the input
# stream contains runs with no events, or events with no values, we'll
# lose that information. This is not a problem: we would need to prune
# such data from the request anyway.
for (run_name, events) in run_to_events.items():
self._dispatch_additional_senders(run_name)
for event in events:
_filter_graph_defs(event)
for value in event.summary.value:
yield (run_name, event, value)
self._request_sender.send_request(run_name, event, value)
self._request_sender.flush()


class _TimeSeriesResourceManager(object):
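
To make the new control flow concrete, here is a runnable sketch of the routing that `dispatch_requests` performs. The stub classes and sample event are stand-ins, not the real implementations; the real dispatcher also runs `_filter_graph_defs` on each event and wraps a fully configured `_BatchedRequestSender`:

from types import SimpleNamespace
from typing import Dict, Iterable


class StubBatchedSender:
    def send_request(self, run_name, event, value):
        print("batched:", run_name, value.tag)

    def flush(self):
        print("flushed")


class StubProfileSender:
    def send_request(self, run_name):
        print("additional:", run_name)


def dispatch_requests(run_to_events: Dict[str, Iterable], sender, additional_senders):
    for run_name, events in run_to_events.items():
        # Non-event-file senders get exactly one callback per run.
        for extra in additional_senders.values():
            extra.send_request(run_name)
        for event in events:
            for value in event.summary.value:
                sender.send_request(run_name, event, value)
    # The batched senders buffer internally, so flush once at the end.
    sender.flush()


event = SimpleNamespace(summary=SimpleNamespace(value=[SimpleNamespace(tag="loss")]))
dispatch_requests({"train": [event]}, StubBatchedSender(), {"profile": StubProfileSender()})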
17 changes: 12 additions & 5 deletions tests/unit/aiplatform/test_uploader.py
@@ -213,8 +213,11 @@ def _create_uploader(
)


def _create_request_sender(
experiment_resource_name, api=None, allowed_plugins=_USE_DEFAULT
def _create_dispatcher(
experiment_resource_name,
api=_USE_DEFAULT,
allowed_plugins=_USE_DEFAULT,
additional_senders=None,
):
if api is _USE_DEFAULT:
api = _create_mock_client()
@@ -232,7 +235,7 @@ def _create_request_sender(
tensor_rpc_rate_limiter = util.RateLimiter(0)
blob_rpc_rate_limiter = util.RateLimiter(0)

return uploader_lib._BatchedRequestSender(
request_sender = uploader_lib._BatchedRequestSender(
experiment_resource_name=experiment_resource_name,
api=api,
allowed_plugins=allowed_plugins,
@@ -245,6 +248,10 @@
tracker=upload_tracker.UploadTracker(verbosity=0),
)

return uploader_lib._Dispatcher(
request_sender=request_sender, additional_senders=additional_senders,
)


def _create_scalar_request_sender(
run_resource_id, api=_USE_DEFAULT, max_request_size=_USE_DEFAULT
@@ -914,12 +921,12 @@ def _populate_run_from_events(
self, n_scalar_events, events, allowed_plugins=_USE_DEFAULT
):
mock_client = _create_mock_client()
builder = _create_request_sender(
builder = _create_dispatcher(
experiment_resource_name="123",
api=mock_client,
allowed_plugins=allowed_plugins,
)
builder.send_requests({"": _apply_compat(events)})
builder.dispatch_requests({"": _apply_compat(events)})
scalar_requests = mock_client.write_tensorboard_run_data.call_args_list
if scalar_requests:
self.assertLen(scalar_requests, 1)
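
A sketch of how a test could exercise an additional sender through the new `_create_dispatcher` helper. It assumes the test module's existing helpers (`_create_dispatcher`, `_create_mock_client`, `_apply_compat`) and a prepared `events` iterable; the `mock.Mock` sender is illustrative, not part of the commit's tests:

from unittest import mock

profile_sender = mock.Mock(spec=["send_request"])
dispatcher = _create_dispatcher(
    experiment_resource_name="123",
    api=_create_mock_client(),
    additional_senders={"profile": profile_sender},
)
dispatcher.dispatch_requests({"train": _apply_compat(events)})
profile_sender.send_request.assert_called_once_with("train")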
