Skip to content

Commit

Permalink
fix: Import error for cloud_profiler (#869)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkovalski committed Dec 1, 2021
1 parent da747b5 commit 0f124e9
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 40 deletions.
Expand Up @@ -15,13 +15,7 @@
# limitations under the License.
#

try:
import google.cloud.aiplatform.training_utils.cloud_profiler.initializer as initializer
except ImportError as err:
raise ImportError(
"Could not load the cloud profiler. To use the profiler, "
'install the SDK using "pip install google-cloud-aiplatform[cloud-profiler]"'
) from err
from google.cloud.aiplatform.training_utils.cloud_profiler import initializer

"""
Initialize the cloud profiler for tensorflow.
Expand Down
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import_error_msg = (
"Could not load the cloud profiler. To use the profiler, "
"install the SDK using 'pip install google-cloud-aiplatform[cloud-profiler]'"
)
Expand Up @@ -18,7 +18,14 @@
import logging
import threading
from typing import Optional, Type
from werkzeug import serving

from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils

try:
from werkzeug import serving
except ImportError as err:
raise ImportError(cloud_profiler_utils.import_error_msg) from err


from google.cloud.aiplatform.training_utils import environment_variables
from google.cloud.aiplatform.training_utils.cloud_profiler import webserver
Expand All @@ -27,6 +34,7 @@
tf_profiler,
)


# Mapping of available plugins to use
_AVAILABLE_PLUGINS = {"tensorflow": tf_profiler.TFProfiler}

Expand Down
Expand Up @@ -17,14 +17,23 @@

"""A plugin to handle remote tensoflow profiler sessions for Vertex AI."""

from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils

try:
import tensorflow as tf
from tensorboard_plugin_profile.profile_plugin import ProfilePlugin
except ImportError as err:
raise ImportError(cloud_profiler_utils.import_error_msg) from err

import argparse
from collections import namedtuple
import importlib.util
import json
import logging
import tensorboard.plugins.base_plugin as tensorboard_base_plugin
from typing import Callable, Dict, Optional
from urllib import parse

import tensorboard.plugins.base_plugin as tensorboard_base_plugin
from werkzeug import Response

from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
Expand Down Expand Up @@ -54,8 +63,6 @@ def _get_tf_versioning() -> Optional[Version]:
Returns:
A version object if finding the version was successful, None otherwise.
"""
import tensorflow as tf

version = tf.__version__

versioning = version.split(".")
Expand Down Expand Up @@ -269,8 +276,6 @@ class TFProfiler(base_plugin.BasePlugin):

def __init__(self):
"""Build a TFProfiler object."""
from tensorboard_plugin_profile.profile_plugin import ProfilePlugin

context = _create_profiling_context()
self._profile_request_sender: profile_uploader.ProfileRequestSender = tensorboard_api.create_profile_request_sender()
self._profile_plugin: ProfilePlugin = ProfilePlugin(context)
Expand Down Expand Up @@ -317,20 +322,7 @@ def capture_profile_wrapper(

@staticmethod
def setup() -> None:
"""Sets up the plugin.
Raises:
ImportError: Tensorflow could not be imported.
"""
try:
import tensorflow as tf
except ImportError as err:
raise ImportError(
"Could not import tensorflow for profile usage. "
"To use profiler, install the SDK using "
'"pip install google-cloud-aiplatform[cloud_profiler]"'
) from err

"""Sets up the plugin."""
tf.profiler.experimental.server.start(
int(environment_variables.tf_profiler_port)
)
Expand Down
8 changes: 6 additions & 2 deletions setup.py
Expand Up @@ -36,7 +36,11 @@
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
metadata_extra_require = ["pandas >= 1.0.0"]
xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
profiler_extra_require = ["tensorboard-plugin-profile", "tensorflow >=2.4.0"]
profiler_extra_require = [
"tensorboard-plugin-profile >= 2.4.0",
"werkzeug >= 2.0.0",
"tensorflow >=2.4.0",
]

full_extra_require = list(
set(tensorboard_extra_require + metadata_extra_require + xai_extra_require)
Expand Down Expand Up @@ -84,7 +88,7 @@
"tensorboard": tensorboard_extra_require,
"testing": testing_extra_require,
"xai": xai_extra_require,
"cloud_profiler": profiler_extra_require,
"cloud-profiler": profiler_extra_require,
},
python_requires=">=3.6",
scripts=[],
Expand Down
35 changes: 24 additions & 11 deletions tests/unit/aiplatform/test_cloud_profiler.py
Expand Up @@ -15,9 +15,9 @@
# limitations under the License.
#

from importlib import reload
import importlib.util
import json
import sys
import threading
from typing import List, Optional

Expand Down Expand Up @@ -75,6 +75,10 @@ def _create_mock_plugin(
return mock_plugin


def _find_child_modules(root_module):
return [module for module in sys.modules.keys() if module.startswith(root_module)]


@pytest.fixture
def tf_profile_plugin_mock():
"""Mock the tensorboard profile plugin"""
Expand Down Expand Up @@ -203,10 +207,6 @@ def testSetup(self):

assert server_mock.call_count == 1

def testSetupRaiseImportError(self):
with mock.patch.dict("sys.modules", {"tensorflow": None}):
self.assertRaises(ImportError, TFProfiler.setup)

def testPostSetupChecksFail(self):
tf_profiler.environment_variables.cluster_spec = {}
assert not TFProfiler.post_setup_check()
Expand Down Expand Up @@ -359,13 +359,26 @@ def start_response(status, headers):

# Initializer tests
class TestInitializer(unittest.TestCase):
# Tests for building the plugin
def test_init_failed_import(self):
with mock.patch.dict(
"sys.modules",
{"google.cloud.aiplatform.training_utils.cloud_profiler.initializer": None},
def testImportError(self):
# Unloads any of the cloud profiler sub-modules
for mod in _find_child_modules(
"google.cloud.aiplatform.training_utils.cloud_profiler"
):
self.assertRaises(ImportError, reload, training_utils.cloud_profiler)
del sys.modules[mod]

# Modules to be mocked out
for mock_module in [
"tensorflow",
"tensorboard_plugin_profile.profile_plugin",
"werkzeug",
]:
with self.subTest():
with mock.patch.dict("sys.modules", {mock_module: None}):
with self.assertRaises(ImportError) as cm:
importlib.import_module(
"google.cloud.aiplatform.training_utils.cloud_profiler"
)
assert "Could not load the cloud profiler" in cm.exception.msg

def test_build_plugin_fail_initialize(self):
plugin = _create_mock_plugin()
Expand Down

0 comments on commit 0f124e9

Please sign in to comment.