forked from googleapis/python-aiplatform
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tf_profiler.py
350 lines (264 loc) · 10.9 KB
/
tf_profiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A plugin to handle remote tensoflow profiler sessions for Vertex AI."""
from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils
try:
import tensorflow as tf
from tensorboard_plugin_profile.profile_plugin import ProfilePlugin
except ImportError as err:
raise ImportError(cloud_profiler_utils.import_error_msg) from err
import argparse
from collections import namedtuple
import importlib.util
import json
import logging
from typing import Callable, Dict, Optional
from urllib import parse
import tensorboard.plugins.base_plugin as tensorboard_base_plugin
from werkzeug import Response
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
from google.cloud.aiplatform.training_utils import environment_variables
from google.cloud.aiplatform.training_utils.cloud_profiler import wsgi_types
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import (
tensorboard_api,
)
# TF verison information.
Version = namedtuple("Version", ["major", "minor", "patch"])
logger = logging.Logger("tf-profiler")
_BASE_TB_ENV_WARNING = (
"To set this environment variable, run your training with the 'tensorboard' "
"option. For more information on how to run with training with tensorboard, visit "
"https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training"
)
def _get_tf_versioning() -> Optional[Version]:
    """Parse ``tf.__version__`` into a ``Version`` namedtuple.

    Returns:
        A ``Version`` with integer major, minor, and patch fields, or None
        when the version string does not have exactly three dot-separated
        components.
    """
    parts = tf.__version__.split(".")
    if len(parts) == 3:
        major, minor, patch = parts
        return Version(int(major), int(minor), int(patch))
    return None
def _is_compatible_version(version: Version) -> bool:
"""Check if version is compatible with tf profiling.
Profiling plugin is available to be used for version >= 2.4.0.
While the profiler is available in 2.2.0 >=, some additional dependencies
that are included in 2.4.0 >= are also needed for the tensorboard-plugin-profile.
Profiler:
https://www.tensorflow.org/guide/profiler
Required commit for tensorboard-plugin-profile:
https://github.com/tensorflow/tensorflow/commit/8b9c207242db515daef033e74d69ea5d8e023dc6
Args:
version (Version):
Required. `Verison` of tensorflow.
Returns:
Bool indicating wheter version is compatible with profiler.
"""
return version.major >= 2 and version.minor >= 4
def _check_tf() -> bool:
    """Check whether all the tensorflow prereqs are met.

    Returns:
        True if all requirements met, False otherwise.
    """
    # Check tf is installed.
    if importlib.util.find_spec("tensorflow") is None:
        logger.warning("Tensorflow not installed, cannot initialize profiling plugin")
        return False

    # Check tensorflow version parses as major.minor.patch.
    version = _get_tf_versioning()
    if version is None:
        # Log the raw version string: `version` is None on this path, so
        # logging it would always print "None".
        logger.warning(
            "Could not find major, minor, and patch versions of tensorflow. "
            "Version found: %s",
            tf.__version__,
        )
        return False

    # Check compatibility; the profiler integration requires >= 2.4.0
    # (see `_is_compatible_version`), so the warning states 2.4.0 rather
    # than the profiler's own 2.2.0 minimum.
    if not _is_compatible_version(version):
        logger.warning(
            "Version %s is incompatible with tf profiler."
            "To use the profiler, choose a version >= 2.4.0",
            "%s.%s.%s" % (version.major, version.minor, version.patch),
        )
        return False

    # Check for the tf profiler plugin.
    if importlib.util.find_spec("tensorboard_plugin_profile") is None:
        logger.warning(
            "Could not import tensorboard_plugin_profile, will not run tf profiling service"
        )
        return False
    return True
def _create_profiling_context() -> tensorboard_base_plugin.TBContext:
    """Creates the base context needed for TB Profiler.

    Returns:
        An initialized `TBContext`.
    """
    # Only the single flag inspected by the profile plugin is populated;
    # no multiplexer is needed for profiling.
    flags = argparse.Namespace(master_tpu_unsecure_channel=None)
    return tensorboard_base_plugin.TBContext(
        logdir=environment_variables.tensorboard_log_dir,
        multiplexer=None,
        flags=flags,
    )
def _host_to_grpc(hostname: str) -> str:
    """Format a hostname to a grpc address.

    Args:
        hostname (str):
            Required. Address in form: `{hostname}:{port}`

    Returns:
        Address in form of: 'grpc://{hostname}:{port}'
    """
    # Drop the trailing ':port' component, then substitute the profiler port.
    host_only = "".join(hostname.split(":")[:-1])
    return "grpc://" + host_only + ":" + environment_variables.tf_profiler_port
def _get_hostnames() -> Optional[str]:
    """Get the hostnames for all servers running.

    Returns:
        A comma-joined string of hosts formatted by `_host_to_grpc` if
        obtaining the cluster spec is successful, None otherwise.
    """
    cluster_spec = environment_variables.cluster_spec
    if cluster_spec is None:
        return None

    cluster = cluster_spec.get("cluster", "")
    if not cluster:
        return None

    # Flatten every worker pool's host list into one sequence.
    all_hosts = [host for pool in cluster.values() for host in pool]
    return ",".join(_host_to_grpc(host) for host in all_hosts)
def _update_environ(environ: wsgi_types.Environment) -> bool:
    """Add parameters to the query that are retrieved from training side.

    Args:
        environ (wsgi_types.Environment):
            Required. The WSGI Environment.

    Returns:
        Whether the environment was successfully updated.
    """
    hosts = _get_hostnames()
    if hosts is None:
        return False

    # Merge the service address into any pre-existing query parameters:
    # parse_qsl -> dict -> urlencode round-trips the original query string
    # while letting us overwrite/insert individual keys.
    query = dict(parse.parse_qsl(environ["QUERY_STRING"]))
    query["service_addr"] = hosts
    environ["QUERY_STRING"] = parse.urlencode(query)
    return True
def warn_tensorboard_env_var(var_name: str) -> None:
    """Warns if a tensorboard related environment variable is missing.

    Args:
        var_name (str):
            Required. The name of the missing environment variable.
    """
    # Use the module-level `logger` (not the root logger via logging.warning)
    # for consistency with the rest of this module, and lazy %-formatting per
    # logging best practice. Message text is unchanged.
    logger.warning(
        "Environment variable `%s` must be set. " + _BASE_TB_ENV_WARNING, var_name
    )
def _check_env_vars() -> bool:
    """Determine whether the correct environment variables are set.

    Returns:
        bool indicating all necessary variables are set.
    """
    # Tensorboard-specific environment variables: warn about the first one
    # that is missing and bail out immediately.
    tb_requirements = (
        (environment_variables.tf_profiler_port, "AIP_TF_PROFILER_PORT"),
        (environment_variables.tensorboard_log_dir, "AIP_TENSORBOARD_LOG_DIR"),
        (environment_variables.tensorboard_api_uri, "AIP_TENSORBOARD_API_URI"),
        (
            environment_variables.tensorboard_resource_name,
            "AIP_TENSORBOARD_RESOURCE_NAME",
        ),
    )
    for value, env_name in tb_requirements:
        if value is None:
            warn_tensorboard_env_var(env_name)
            return False

    # These environment variables are not tensorboard related; they are
    # set for any Vertex training run.
    if environment_variables.cluster_spec is None:
        logger.warning("Environment variable `CLUSTER_SPEC` is not set.")
        return False

    if environment_variables.cloud_ml_job_id is None:
        logger.warning("Environment variable `CLOUD_ML_JOB_ID` is not set")
        return False

    return True
class TFProfiler(base_plugin.BasePlugin):
    """Handler for Tensorflow Profiling."""

    PLUGIN_NAME = "profile"

    def __init__(self):
        """Build a TFProfiler object."""
        context = _create_profiling_context()
        self._profile_request_sender: profile_uploader.ProfileRequestSender = (
            tensorboard_api.create_profile_request_sender()
        )
        self._profile_plugin: ProfilePlugin = ProfilePlugin(context)

    def get_routes(
        self,
    ) -> Dict[str, Callable[[Dict[str, str], Callable[..., None]], Response]]:
        """List of routes to serve.

        Returns:
            A callable that takes an werkzeug env and start response and
            returns a response.
        """
        routes = {"/capture_profile": self.capture_profile_wrapper}
        return routes

    # Define routes below
    def capture_profile_wrapper(
        self, environ: wsgi_types.Environment, start_response: wsgi_types.StartResponse
    ) -> Response:
        """Take a request from tensorboard.gcp and run the profiling for the available servers.

        Args:
            environ (wsgi_types.Environment):
                Required. The WSGI environment.
            start_response (wsgi_types.StartResponse):
                Required. The response callable provided by the WSGI server.

        Returns:
            A response iterable.
        """
        # The service address (localhost) and worker list are populated
        # locally; bail out with a 500 if they cannot be determined.
        if not _update_environ(environ):
            failure = {"error": "Could not parse the environ: %s"}
            return Response(
                json.dumps(failure), content_type="application/json", status=500
            )

        # Run the profiling capture, then notify the uploader so the new
        # profile data is pushed to the Vertex tensorboard.
        response = self._profile_plugin.capture_route(environ, start_response)
        self._profile_request_sender.send_request("")
        return response

    # End routes

    @staticmethod
    def setup() -> None:
        """Sets up the plugin."""
        profiler_port = int(environment_variables.tf_profiler_port)
        tf.profiler.experimental.server.start(profiler_port)

    @staticmethod
    def post_setup_check() -> bool:
        """Only chief and task 0 should run the webserver."""
        task = environment_variables.cluster_spec.get("task", {})
        is_chief_pool = task.get("type", "") in {"workerpool0", "chief"}
        return is_chief_pool and task.get("index", -1) == 0

    @staticmethod
    def can_initialize() -> bool:
        """Check that we can use the TF Profiler plugin.

        This function checks a number of dependencies for the plugin to ensure
        we have the right packages installed, the necessary versions, and the
        correct environment variables set.

        Returns:
            True if can initialize, False otherwise.
        """
        return _check_env_vars() and _check_tf()