# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Uploads a TensorBoard logdir to TensorBoard.gcp."""
import abc
from collections import defaultdict
import functools
import logging
import os
import time
import re
from typing import (
Dict,
FrozenSet,
Generator,
Iterable,
Optional,
ContextManager,
Tuple,
)
import uuid
import grpc
from tensorboard.backend import process_graph
from tensorboard.backend.event_processing.plugin_event_accumulator import (
directory_loader,
)
from tensorboard.backend.event_processing.plugin_event_accumulator import (
event_file_loader,
)
from tensorboard.backend.event_processing.plugin_event_accumulator import io_wrapper
from tensorboard.compat.proto import graph_pb2
from tensorboard.compat.proto import summary_pb2
from tensorboard.compat.proto import types_pb2
from tensorboard.plugins.graph import metadata as graph_metadata
from tensorboard.uploader import logdir_loader
from tensorboard.uploader import upload_tracker
from tensorboard.uploader import util
from tensorboard.uploader.proto import server_info_pb2
from tensorboard.util import tb_logging
from tensorboard.util import tensor_util
import tensorflow as tf
from google.api_core import exceptions
from google.cloud import storage
from google.cloud.aiplatform.compat.services import tensorboard_service_client_v1beta1
from google.cloud.aiplatform.compat.types import (
tensorboard_data_v1beta1 as tensorboard_data,
)
from google.cloud.aiplatform.compat.types import (
tensorboard_experiment_v1beta1 as tensorboard_experiment,
)
from google.cloud.aiplatform.compat.types import (
tensorboard_service_v1beta1 as tensorboard_service,
)
from google.cloud.aiplatform.compat.types import (
tensorboard_time_series_v1beta1 as tensorboard_time_series,
)
from google.cloud.aiplatform.tensorboard import uploader_utils
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
from google.protobuf import message
from google.protobuf import timestamp_pb2 as timestamp
TensorboardServiceClient = tensorboard_service_client_v1beta1.TensorboardServiceClient
# Minimum length of a logdir polling cycle in seconds. Shorter cycles will
# sleep to avoid spinning over the logdir, which isn't great for disks and can
# be expensive for network file systems.
_MIN_LOGDIR_POLL_INTERVAL_SECS = 1
# Maximum length of a base-128 varint as used to encode a 64-bit value
# (without the "msb of last byte is bit 63" optimization, to be
# compatible with protobuf and golang varints).
_MAX_VARINT64_LENGTH_BYTES = 10
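# For intuition, an illustrative sketch (nothing in this module calls it):
# a base-128 varint carries 7 payload bits per byte, so a 64-bit value needs
# at most ceil(64 / 7) == 10 bytes.
#
#     def varint64_length(value: int) -> int:
#         """Returns the byte length of the base-128 varint encoding."""
#         length = 1
#         while value >= 0x80:  # 7 bits are consumed per byte
#             value >>= 7
#             length += 1
#         return length
#
#     assert varint64_length((1 << 64) - 1) == 10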
# Default minimum interval between initiating WriteTensorboardRunData RPCs
# for scalars, in milliseconds.
_DEFAULT_MIN_SCALAR_REQUEST_INTERVAL = 10
# Default maximum WriteTensorboardRunData request size for scalars, in bytes.
_DEFAULT_MAX_SCALAR_REQUEST_SIZE = 128 * (2 ** 10)  # 128KiB
# Default minimum interval between initiating WriteTensorboardRunData RPCs
# for tensors, in milliseconds.
_DEFAULT_MIN_TENSOR_REQUEST_INTERVAL = 10
# Default minimum interval between initiating WriteTensorboardRunData RPCs
# for blobs, in milliseconds.
_DEFAULT_MIN_BLOB_REQUEST_INTERVAL = 10
# Default maximum WriteTensorboardRunData request size for tensors, in bytes.
_DEFAULT_MAX_TENSOR_REQUEST_SIZE = 512 * (2 ** 10)  # 512KiB
# Default maximum WriteTensorboardRunData request size for blobs, in bytes.
_DEFAULT_MAX_BLOB_REQUEST_SIZE = 128 * (2 ** 10)  # 128KiB
# Default maximum tensor point size in bytes.
_DEFAULT_MAX_TENSOR_POINT_SIZE = 16 * (2 ** 10) # 16KiB
_DEFAULT_MAX_BLOB_SIZE = 10 * (2 ** 30) # 10GiB
logger = tb_logging.get_logger()
logger.setLevel(logging.WARNING)
class RequestSender(object):
"""A base class for additional request sender objects.
Currently just used for typing.
"""
pass
class TensorBoardUploader(object):
"""Uploads a TensorBoard logdir to TensorBoard.gcp."""
def __init__(
self,
experiment_name: str,
tensorboard_resource_name: str,
blob_storage_bucket: storage.Bucket,
blob_storage_folder: str,
writer_client: TensorboardServiceClient,
logdir: str,
allowed_plugins: FrozenSet[str],
experiment_display_name: Optional[str] = None,
upload_limits: Optional[server_info_pb2.UploadLimits] = None,
logdir_poll_rate_limiter: Optional[util.RateLimiter] = None,
rpc_rate_limiter: Optional[util.RateLimiter] = None,
tensor_rpc_rate_limiter: Optional[util.RateLimiter] = None,
blob_rpc_rate_limiter: Optional[util.RateLimiter] = None,
description: Optional[str] = None,
verbosity: int = 1,
one_shot: bool = False,
event_file_inactive_secs: Optional[int] = None,
run_name_prefix=None,
):
"""Constructs a TensorBoardUploader.
Args:
experiment_name: Name of this experiment. Unique to the given
tensorboard_resource_name.
tensorboard_resource_name: Name of the Tensorboard resource with this
format
projects/{project}/locations/{location}/tensorboards/{tensorboard}
            writer_client: a TensorBoardWriterService stub instance
            logdir: path of the log directory to upload
            allowed_plugins: collection of string plugin names; events will
                only be uploaded if their time series's metadata specifies one
                of these plugin names
            experiment_display_name: The display name of the experiment.
            upload_limits: instance of tensorboard.service.UploadLimits proto.
logdir_poll_rate_limiter: a `RateLimiter` to use to limit logdir polling
frequency, to avoid thrashing disks, especially on networked file
systems
rpc_rate_limiter: a `RateLimiter` to use to limit write RPC frequency.
Note this limit applies at the level of single RPCs in the Scalar and
Tensor case, but at the level of an entire blob upload in the Blob
case-- which may require a few preparatory RPCs and a stream of chunks.
Note the chunk stream is internally rate-limited by backpressure from
the server, so it is not a concern that we do not explicitly rate-limit
within the stream here.
description: String description to assign to the experiment.
            verbosity: Level of verbosity, an integer. Supported values:
                0 - no upload statistics are printed; 1 - print upload
                statistics while uploading data (default).
            one_shot: Once uploading starts, upload only the existing data in
                the logdir and then return immediately, instead of the default
                behavior of continuing to listen for new data in the logdir
                and uploading it as it appears.
            event_file_inactive_secs: Age in seconds of the last write after
                which an event file is considered inactive. If None, event
                files are never considered inactive.
run_name_prefix: If present, all runs created by this invocation will have
their name prefixed by this value.
"""
self._experiment_name = experiment_name
self._experiment_display_name = experiment_display_name
self._tensorboard_resource_name = tensorboard_resource_name
self._blob_storage_bucket = blob_storage_bucket
self._blob_storage_folder = blob_storage_folder
self._api = writer_client
self._logdir = logdir
self._allowed_plugins = frozenset(allowed_plugins)
self._run_name_prefix = run_name_prefix
self._is_brand_new_experiment = False
self._upload_limits = upload_limits
if not self._upload_limits:
self._upload_limits = server_info_pb2.UploadLimits()
self._upload_limits.max_scalar_request_size = (
_DEFAULT_MAX_SCALAR_REQUEST_SIZE
)
self._upload_limits.min_scalar_request_interval = (
_DEFAULT_MIN_SCALAR_REQUEST_INTERVAL
)
self._upload_limits.min_tensor_request_interval = (
_DEFAULT_MIN_TENSOR_REQUEST_INTERVAL
)
self._upload_limits.max_tensor_request_size = (
_DEFAULT_MAX_TENSOR_REQUEST_SIZE
)
self._upload_limits.max_tensor_point_size = _DEFAULT_MAX_TENSOR_POINT_SIZE
self._upload_limits.min_blob_request_interval = (
_DEFAULT_MIN_BLOB_REQUEST_INTERVAL
)
self._upload_limits.max_blob_request_size = _DEFAULT_MAX_BLOB_REQUEST_SIZE
self._upload_limits.max_blob_size = _DEFAULT_MAX_BLOB_SIZE
self._description = description
self._verbosity = verbosity
self._one_shot = one_shot
self._dispatcher = None
self._additional_senders: Dict[str, uploader_utils.RequestSender] = {}
if logdir_poll_rate_limiter is None:
self._logdir_poll_rate_limiter = util.RateLimiter(
_MIN_LOGDIR_POLL_INTERVAL_SECS
)
else:
self._logdir_poll_rate_limiter = logdir_poll_rate_limiter
if rpc_rate_limiter is None:
self._rpc_rate_limiter = util.RateLimiter(
self._upload_limits.min_scalar_request_interval / 1000
)
else:
self._rpc_rate_limiter = rpc_rate_limiter
if tensor_rpc_rate_limiter is None:
self._tensor_rpc_rate_limiter = util.RateLimiter(
self._upload_limits.min_tensor_request_interval / 1000
)
else:
self._tensor_rpc_rate_limiter = tensor_rpc_rate_limiter
if blob_rpc_rate_limiter is None:
self._blob_rpc_rate_limiter = util.RateLimiter(
self._upload_limits.min_blob_request_interval / 1000
)
else:
self._blob_rpc_rate_limiter = blob_rpc_rate_limiter
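        # An event file counts as "active" if it was last written within
        # `event_file_inactive_secs` of the current time (or always, when no
        # inactivity limit is configured).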
def active_filter(secs):
return (
not bool(event_file_inactive_secs)
or secs + event_file_inactive_secs >= time.time()
)
directory_loader_factory = functools.partial(
directory_loader.DirectoryLoader,
loader_factory=event_file_loader.TimestampedEventFileLoader,
path_filter=io_wrapper.IsTensorFlowEventsFile,
active_filter=active_filter,
)
self._logdir_loader = logdir_loader.LogdirLoader(
self._logdir, directory_loader_factory
)
self._logdir_loader_pre_create = logdir_loader.LogdirLoader(
self._logdir, directory_loader_factory
)
self._tracker = upload_tracker.UploadTracker(verbosity=self._verbosity)
self._create_additional_senders()
def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperiment:
"""Create an experiment or get an experiment.
Attempts to create an experiment. If the experiment already exists and
creation fails then the experiment will be retrieved.
Returns:
The created or retrieved experiment.
"""
logger.info("Creating experiment")
tb_experiment = tensorboard_experiment.TensorboardExperiment(
description=self._description, display_name=self._experiment_display_name
)
try:
experiment = self._api.create_tensorboard_experiment(
parent=self._tensorboard_resource_name,
tensorboard_experiment=tb_experiment,
tensorboard_experiment_id=self._experiment_name,
)
self._is_brand_new_experiment = True
except exceptions.AlreadyExists:
logger.info("Creating experiment failed. Retrieving experiment.")
experiment_name = os.path.join(
self._tensorboard_resource_name, "experiments", self._experiment_name
)
experiment = self._api.get_tensorboard_experiment(name=experiment_name)
return experiment
def create_experiment(self):
"""Creates an Experiment for this upload session and returns the ID."""
experiment = self._create_or_get_experiment()
self._experiment = experiment
self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
self._experiment.name, self._api
)
self._request_sender = _BatchedRequestSender(
self._experiment.name,
self._api,
allowed_plugins=self._allowed_plugins,
upload_limits=self._upload_limits,
rpc_rate_limiter=self._rpc_rate_limiter,
tensor_rpc_rate_limiter=self._tensor_rpc_rate_limiter,
blob_rpc_rate_limiter=self._blob_rpc_rate_limiter,
blob_storage_bucket=self._blob_storage_bucket,
blob_storage_folder=self._blob_storage_folder,
one_platform_resource_manager=self._one_platform_resource_manager,
tracker=self._tracker,
)
# Update partials with experiment name
for sender in self._additional_senders.keys():
self._additional_senders[sender] = self._additional_senders[sender](
experiment_resource_name=self._experiment.name,
)
self._dispatcher = _Dispatcher(
request_sender=self._request_sender,
additional_senders=self._additional_senders,
)
def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
"""Create any additional senders for non traditional event files.
Some items that are used for plugins do not process typical event files,
but need to be searched for and stored so that they can be used by the
plugin. If there are any items that cannot be searched for via the
`_BatchedRequestSender`, add them here.
"""
if "profile" in self._allowed_plugins:
if not self._one_shot:
                raise ValueError(
                    "The profile plugin is currently only supported in "
                    "one-shot mode."
                )
source_bucket = uploader_utils.get_source_bucket(self._logdir)
self._additional_senders["profile"] = functools.partial(
profile_uploader.ProfileRequestSender,
api=self._api,
upload_limits=self._upload_limits,
blob_rpc_rate_limiter=self._blob_rpc_rate_limiter,
blob_storage_bucket=self._blob_storage_bucket,
blob_storage_folder=self._blob_storage_folder,
source_bucket=source_bucket,
tracker=self._tracker,
logdir=self._logdir,
)
    def get_experiment_resource_name(self):
        """Returns the resource name of the experiment for this upload session."""
        return self._experiment.name
def start_uploading(self):
"""Blocks forever to continuously upload data from the logdir.
Raises:
RuntimeError: If `create_experiment` has not yet been called.
ExperimentNotFoundError: If the experiment is deleted during the
course of the upload.
"""
if self._dispatcher is None:
raise RuntimeError("Must call create_experiment() before start_uploading()")
if self._one_shot:
if self._is_brand_new_experiment:
self._pre_create_runs_and_time_series()
else:
logger.warning(
"Please consider uploading to a new experiment instead of "
"an existing one, as the former allows for better upload "
"performance."
)
while True:
self._logdir_poll_rate_limiter.tick()
self._upload_once()
if self._one_shot:
break
        if self._one_shot and not self._tracker.has_data():
            logger.warning(
                "One-shot mode was used on a logdir (%s) "
                "without any uploadable data",
                self._logdir,
            )
def _pre_create_runs_and_time_series(self):
"""
        Iterates through the log dir to collect TensorboardRuns and
TensorboardTimeSeries that need to be created, and creates them in batch
to speed up uploading later on.
"""
self._logdir_loader_pre_create.synchronize_runs()
run_to_events = self._logdir_loader_pre_create.get_run_events()
if self._run_name_prefix:
run_to_events = {
self._run_name_prefix + k: v for k, v in run_to_events.items()
}
run_names = []
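        # Maps (run_name, tag) to the TensorboardTimeSeries to create, e.g.
        # ("train", "loss") -> TensorboardTimeSeries(display_name="loss", ...)
        # (keys shown here are illustrative).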
run_tag_name_to_time_series_proto = {}
for (run_name, events) in run_to_events.items():
run_names.append(run_name)
for event in events:
_filter_graph_defs(event)
for value in event.summary.value:
metadata, is_valid = self._request_sender.get_metadata_and_validate(
run_name, value
)
if not is_valid:
continue
if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
value_type = (
tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
)
elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
value_type = (
tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR
)
                    elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
                        value_type = (
                            tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE
                        )
                    else:
                        # Guard against an unrecognized data class, which would
                        # otherwise leave `value_type` unbound below.
                        continue
run_tag_name_to_time_series_proto[
(run_name, value.tag)
] = tensorboard_time_series.TensorboardTimeSeries(
display_name=value.tag,
value_type=value_type,
plugin_name=metadata.plugin_data.plugin_name,
plugin_data=metadata.plugin_data.content,
)
self._one_platform_resource_manager.batch_create_runs(run_names)
self._one_platform_resource_manager.batch_create_time_series(
run_tag_name_to_time_series_proto
)
def _upload_once(self):
"""Runs one upload cycle, sending zero or more RPCs."""
logger.info("Starting an upload cycle")
sync_start_time = time.time()
self._logdir_loader.synchronize_runs()
sync_duration_secs = time.time() - sync_start_time
logger.info("Logdir sync took %.3f seconds", sync_duration_secs)
run_to_events = self._logdir_loader.get_run_events()
if self._run_name_prefix:
run_to_events = {
self._run_name_prefix + k: v for k, v in run_to_events.items()
}
with self._tracker.send_tracker():
self._dispatcher.dispatch_requests(run_to_events)
class PermissionDeniedError(RuntimeError):
pass
class ExperimentNotFoundError(RuntimeError):
pass
class _OutOfSpaceError(Exception):
"""Action could not proceed without overflowing request budget.
This is a signaling exception (like `StopIteration`) used internally
by `_*RequestSender`; it does not mean that anything has gone wrong.
"""
pass
class _BatchedRequestSender(object):
"""Helper class for building requests that fit under a size limit.
This class maintains stateful request builders for each of the possible
request types (scalars, tensors, and blobs). These accumulate batches
independently, each maintaining its own byte budget and emitting a request
when the batch becomes full. As a consequence, events of different types
will likely be sent to the backend out of order. E.g., in the extreme case,
a single tensor-flavored request may be sent only when the event stream is
exhausted, even though many more recent scalar events were sent earlier.
This class is not threadsafe. Use external synchronization if
calling its methods concurrently.
"""
def __init__(
self,
experiment_resource_name: str,
api: TensorboardServiceClient,
allowed_plugins: Iterable[str],
upload_limits: server_info_pb2.UploadLimits,
rpc_rate_limiter: util.RateLimiter,
tensor_rpc_rate_limiter: util.RateLimiter,
blob_rpc_rate_limiter: util.RateLimiter,
blob_storage_bucket: storage.Bucket,
blob_storage_folder: str,
one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
tracker: upload_tracker.UploadTracker,
):
"""Constructs _BatchedRequestSender for the given experiment resource.
Args:
experiment_resource_name: Name of the experiment resource of the form
projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}
api: Tensorboard service stub used to interact with experiment resource.
allowed_plugins: The plugins supported by the Tensorboard.gcp resource.
            upload_limits: Upload limits for API calls.
rpc_rate_limiter: a `RateLimiter` to use to limit write RPC frequency.
Note this limit applies at the level of single RPCs in the Scalar and
Tensor case, but at the level of an entire blob upload in the Blob
case-- which may require a few preparatory RPCs and a stream of chunks.
Note the chunk stream is internally rate-limited by backpressure from
the server, so it is not a concern that we do not explicitly rate-limit
within the stream here.
one_platform_resource_manager: An instance of the One Platform
resource management class.
tracker: Upload tracker to track information about uploads.
"""
self._experiment_resource_name = experiment_resource_name
self._api = api
self._tag_metadata = {}
self._allowed_plugins = frozenset(allowed_plugins)
self._tracker = tracker
self._one_platform_resource_manager = one_platform_resource_manager
self._scalar_request_sender = _ScalarBatchedRequestSender(
experiment_resource_id=experiment_resource_name,
api=api,
rpc_rate_limiter=rpc_rate_limiter,
max_request_size=upload_limits.max_scalar_request_size,
tracker=self._tracker,
one_platform_resource_manager=self._one_platform_resource_manager,
)
self._tensor_request_sender = _TensorBatchedRequestSender(
experiment_resource_id=experiment_resource_name,
api=api,
rpc_rate_limiter=tensor_rpc_rate_limiter,
max_request_size=upload_limits.max_tensor_request_size,
max_tensor_point_size=upload_limits.max_tensor_point_size,
tracker=self._tracker,
one_platform_resource_manager=self._one_platform_resource_manager,
)
self._blob_request_sender = _BlobRequestSender(
experiment_resource_id=experiment_resource_name,
api=api,
rpc_rate_limiter=blob_rpc_rate_limiter,
max_blob_request_size=upload_limits.max_blob_request_size,
max_blob_size=upload_limits.max_blob_size,
blob_storage_bucket=blob_storage_bucket,
blob_storage_folder=blob_storage_folder,
tracker=self._tracker,
one_platform_resource_manager=self._one_platform_resource_manager,
)
def send_request(
self,
run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
):
"""Accepts a stream of TF events and sends batched write RPCs.
Each sent request will be batched, the size of each batch depending on
the type of data (Scalar vs Tensor vs Blob) being sent.
Args:
run_name: Name of the run retrieved by `LogdirLoader.get_run_events`
event: The `tf.compat.v1.Event` for the run
value: A single `tf.compat.v1.Summary.Value` from the event, where
there can be multiple values per event.
Raises:
RuntimeError: If no progress can be made because even a single
point is too large (say, due to a gigabyte-long tag name).
"""
metadata, is_valid = self.get_metadata_and_validate(run_name, value)
if not is_valid:
return
plugin_name = metadata.plugin_data.plugin_name
self._tracker.add_plugin_name(plugin_name)
if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
self._scalar_request_sender.add_event(run_name, event, value, metadata)
elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
self._tensor_request_sender.add_event(run_name, event, value, metadata)
elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
self._blob_request_sender.add_event(run_name, event, value, metadata)
def flush(self):
"""Flushes any events that have been stored."""
self._scalar_request_sender.flush()
self._tensor_request_sender.flush()
self._blob_request_sender.flush()
def get_metadata_and_validate(
self, run_name: str, value: tf.compat.v1.Summary.Value
) -> Tuple[tf.compat.v1.SummaryMetadata, bool]:
"""
:param run_name: Name of the run retrieved by
`LogdirLoader.get_run_events`
:param value: A single `tf.compat.v1.Summary.Value` from the event,
where there can be multiple values per event.
:return: (metadata, is_valid): a metadata derived from the value, and
whether the value itself is valid.
"""
time_series_key = (run_name, value.tag)
        # The metadata for a time series is memoized from the first event.
# If later events arrive with a mismatching plugin_name, they are
# ignored with a warning.
metadata = self._tag_metadata.get(time_series_key)
first_in_time_series = False
if metadata is None:
first_in_time_series = True
metadata = value.metadata
self._tag_metadata[time_series_key] = metadata
plugin_name = metadata.plugin_data.plugin_name
if value.HasField("metadata") and (
plugin_name != value.metadata.plugin_data.plugin_name
):
logger.warning(
"Mismatching plugin names for %s. Expected %s, found %s.",
time_series_key,
metadata.plugin_data.plugin_name,
value.metadata.plugin_data.plugin_name,
)
return metadata, False
if plugin_name not in self._allowed_plugins:
if first_in_time_series:
logger.info(
"Skipping time series %r with unsupported plugin name %r",
time_series_key,
plugin_name,
)
return metadata, False
return metadata, True
class _Dispatcher(object):
"""Dispatch the requests to the correct request senders."""
def __init__(
self,
request_sender: _BatchedRequestSender,
additional_senders: Optional[Dict[str, uploader_utils.RequestSender]] = None,
):
"""Construct a _Dispatcher object for the TensorboardUploader.
Args:
request_sender: A `_BatchedRequestSender` for handling events.
additional_senders: A dictionary mapping a plugin name to additional
Senders.
"""
self._request_sender = request_sender
if not additional_senders:
additional_senders = {}
self._additional_senders = additional_senders
def _dispatch_additional_senders(
self, run_name: str,
):
"""Dispatch events to any additional senders.
        These senders process non-traditional event files for a specific plugin
and use a send_request function to process events.
Args:
run_name: String of current training run
"""
        for sender in self._additional_senders.values():
            sender.send_request(run_name)
def dispatch_requests(
self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]]
):
"""Routes events to the appropriate sender.
Takes a mapping from strings to an event generator. The function routes
any events that should be handled by the `_BatchedRequestSender` and
non-traditional events that need to be handled differently, which are
stored as "_additional_senders". The `_request_sender` is then flushed
after all events are added.
Note that `dataclass_compat` may emit multiple variants of
the same event, for backwards compatibility. Thus this stream should
be filtered to obtain the desired version of each event. Here, we
ignore any event that does not have a `summary` field.
Furthermore, the events emitted here could contain values that do not
have `metadata.data_class` set; these too should be ignored. In
        `send_request(...)` above, we switch on `metadata.data_class`
and drop any values with an unknown (i.e., absent or unrecognized)
`data_class`.
Args:
run_to_events: Mapping from run name to generator of `tf.compat.v1.Event`
values, as returned by `LogdirLoader.get_run_events`.
"""
for (run_name, events) in run_to_events.items():
self._dispatch_additional_senders(run_name)
for event in events:
_filter_graph_defs(event)
for value in event.summary.value:
self._request_sender.send_request(run_name, event, value)
self._request_sender.flush()
class _BaseBatchedRequestSender(object):
"""Helper class for building requests that fit under a size limit.
This class accumulates a current request. `add_event(...)` may or may not
send the request (and start a new one). After all `add_event(...)` calls
are complete, a final call to `flush()` is needed to send the final request.
This class is not threadsafe. Use external synchronization if calling its
methods concurrently.
"""
def __init__(
self,
experiment_resource_id: str,
api: TensorboardServiceClient,
rpc_rate_limiter: util.RateLimiter,
max_request_size: int,
tracker: upload_tracker.UploadTracker,
one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
):
"""Constructor for _BaseBatchedRequestSender.
Args:
experiment_resource_id: The resource id for the experiment with the following format
projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}
api: TensorboardServiceStub
            rpc_rate_limiter: util.RateLimiter to limit the rate of this
                request sender.
            max_request_size: Maximum number of bytes to send per request.
            tracker: upload_tracker.UploadTracker used to report upload
                progress.
            one_platform_resource_manager: An instance of the One Platform
                resource management class.
        """
self._experiment_resource_id = experiment_resource_id
self._api = api
self._rpc_rate_limiter = rpc_rate_limiter
self._byte_budget_manager = _ByteBudgetManager(max_request_size)
self._tracker = tracker
self._one_platform_resource_manager = one_platform_resource_manager
        # Cache: maps run name -> tag -> TimeSeriesData for the current
        # request; cleared whenever a new request is created.
self._run_to_tag_to_time_series_data: Dict[
str, Dict[str, tensorboard_data.TimeSeriesData]
] = defaultdict(defaultdict)
self._new_request()
def _new_request(self):
"""Allocates a new request and refreshes the budget."""
self._request = tensorboard_service.WriteTensorboardExperimentDataRequest(
tensorboard_experiment=self._experiment_resource_id
)
self._run_to_tag_to_time_series_data.clear()
self._num_values = 0
self._byte_budget_manager.reset(self._request)
def add_event(
self,
run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
metadata: tf.compat.v1.SummaryMetadata,
):
"""Attempts to add the given event to the current request.
If the event cannot be added to the current request because the byte
budget is exhausted, the request is flushed, and the event is added
to the next request.
        Args:
            run_name: Name of the run containing the event.
            event: The tf.compat.v1.Event event containing the value.
            value: A tf.compat.v1.Summary.Value to add.
            metadata: SummaryMetadata of the event.
"""
try:
self._add_event_internal(run_name, event, value, metadata)
except _OutOfSpaceError:
self.flush()
# Try again. This attempt should never produce OutOfSpaceError
# because we just flushed.
try:
self._add_event_internal(run_name, event, value, metadata)
except _OutOfSpaceError:
raise RuntimeError("add_event failed despite flush")
def _add_event_internal(
self,
run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
metadata: tf.compat.v1.SummaryMetadata,
):
self._num_values += 1
time_series_data_proto = self._run_to_tag_to_time_series_data[run_name].get(
value.tag
)
if time_series_data_proto is None:
time_series_data_proto = self._create_time_series_data(
run_name, value.tag, metadata
)
self._create_point(run_name, time_series_data_proto, event, value, metadata)
def flush(self):
"""Sends the active request after removing empty runs and tags.
Starts a new, empty active request.
"""
request = self._request
has_data = False
for (
run_name,
tag_to_time_series_data,
) in self._run_to_tag_to_time_series_data.items():
r = tensorboard_service.WriteTensorboardRunDataRequest(
tensorboard_run=self._one_platform_resource_manager.get_run_resource_name(
run_name
)
)
r.time_series_data = list(tag_to_time_series_data.values())
_prune_empty_time_series(r)
if not r.time_series_data:
continue
request.write_run_data_requests.extend([r])
has_data = True
if not has_data:
return
self._rpc_rate_limiter.tick()
with uploader_utils.request_logger(request):
with self._get_tracker():
try:
self._api.write_tensorboard_experiment_data(
tensorboard_experiment=request.tensorboard_experiment,
write_run_data_requests=request.write_run_data_requests,
)
except grpc.RpcError as e:
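                    # A NOT_FOUND status means the experiment was deleted
                    # while uploading; surface it distinctly so callers can
                    # stop gracefully.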
if (
hasattr(e, "code")
and getattr(e, "code")() == grpc.StatusCode.NOT_FOUND
):
raise ExperimentNotFoundError()
logger.error("Upload call failed with error %s", e)
self._new_request()
def _create_time_series_data(
self, run_name: str, tag_name: str, metadata: tf.compat.v1.SummaryMetadata
) -> tensorboard_data.TimeSeriesData:
"""Adds a time_series for the tag_name, if there's space.
        Args:
            run_name: Name of the run to which the time series belongs.
            tag_name: String name of the tag to add (as `value.tag`).
            metadata: SummaryMetadata used to populate the time series.
Returns:
The TimeSeriesData in _request proto with the given tag name.
Raises:
_OutOfSpaceError: If adding the tag would exceed the remaining
request budget.
"""
time_series_resource_name = self._one_platform_resource_manager.get_time_series_resource_name(
run_name,
tag_name,
lambda: tensorboard_time_series.TensorboardTimeSeries(
display_name=tag_name,
value_type=self._value_type,
plugin_name=metadata.plugin_data.plugin_name,
plugin_data=metadata.plugin_data.content,
),
)
time_series_data_proto = tensorboard_data.TimeSeriesData(
tensorboard_time_series_id=time_series_resource_name.split("/")[-1],
value_type=self._value_type,
)
self._byte_budget_manager.add_time_series(time_series_data_proto)
self._run_to_tag_to_time_series_data[run_name][
tag_name
] = time_series_data_proto
return time_series_data_proto
def _create_point(
self,
run_name: str,
time_series_proto: tensorboard_data.TimeSeriesData,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
metadata: tf.compat.v1.SummaryMetadata,
):
"""Adds a scalar point to the given tag, if there's space.
Args:
time_series_proto: TimeSeriesData proto to which to add a point.
event: Enclosing `Event` proto with the step and wall time data.
value: `Summary.Value` proto.
metadata: SummaryMetadata of the event.
Raises:
_OutOfSpaceError: If adding the point would exceed the remaining
request budget.
"""
point = self._create_data_point(run_name, event, value, metadata)
if not self._validate(point, event, value):
return
time_series_proto.values.extend([point])
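        # The point was appended speculatively; if it overflows the byte
        # budget, roll it back and re-raise so add_event() can flush the
        # request and retry.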
try:
self._byte_budget_manager.add_point(point)
except _OutOfSpaceError:
time_series_proto.values.pop()
raise
@abc.abstractmethod
def _get_tracker(self) -> ContextManager:
"""
:return: tracker function from upload_tracker.UploadTracker
"""
pass
@property
@classmethod
@abc.abstractmethod
    def _value_type(cls) -> tensorboard_time_series.TensorboardTimeSeries.ValueType:
"""
:return: Value type of the time series.
"""
pass
@abc.abstractmethod
def _create_data_point(
self,
run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
metadata: tf.compat.v1.SummaryMetadata,
) -> tensorboard_data.TimeSeriesDataPoint:
"""
Creates data point protos for sending to the OnePlatform API.
"""
pass
def _validate(
self,
point: tensorboard_data.TimeSeriesDataPoint,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
):
"""
Validations performed before including the data point to be sent to the
OnePlatform API.
"""
return True
class _ScalarBatchedRequestSender(_BaseBatchedRequestSender):
"""Helper class for building requests that fit under a size limit.
This class accumulates a current request. `add_event(...)` may or may not
send the request (and start a new one). After all `add_event(...)` calls
are complete, a final call to `flush()` is needed to send the final request.
This class is not threadsafe. Use external synchronization if calling its
methods concurrently.
"""
_value_type = tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
def __init__(
self,