From 857f63d475463189ebb89b25d0ca08d9544c3bf3 Mon Sep 17 00:00:00 2001 From: sina chavoshi Date: Mon, 3 May 2021 06:57:16 -0700 Subject: [PATCH] fix(aiplatform): Fix doc formatting (#359) --- google/cloud/aiplatform/base.py | 44 +++++---- .../cloud/aiplatform/datasets/_datasources.py | 15 +-- google/cloud/aiplatform/datasets/dataset.py | 4 +- .../aiplatform/datasets/image_dataset.py | 7 +- .../aiplatform/datasets/tabular_dataset.py | 3 +- .../cloud/aiplatform/datasets/text_dataset.py | 7 +- .../aiplatform/datasets/video_dataset.py | 7 +- google/cloud/aiplatform/initializer.py | 8 +- google/cloud/aiplatform/jobs.py | 30 +++--- google/cloud/aiplatform/models.py | 56 +++++------ google/cloud/aiplatform/training_jobs.py | 95 +++++++++---------- google/cloud/aiplatform/training_utils.py | 2 +- google/cloud/aiplatform/utils.py | 50 +++++----- 13 files changed, 165 insertions(+), 163 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 907397b7e8..f46db9c47e 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -94,7 +94,6 @@ def log_create_complete( resource (proto.Message): AI Platform Resourc proto.Message variable_name (str): Name of variable to use for code snippet - """ self._logger.info(f"{cls.__name__} created. Resource name: {resource.name}") self._logger.info(f"To use this {cls.__name__} in another session:") @@ -181,7 +180,8 @@ def _raise_future_exception(self): raise self._exception def _complete_future(self, future: futures.Future): - """Checks for exception of future and removes the pointer if it's still latest. + """Checks for exception of future and removes the pointer if it's still + latest. Args: future (futures.Future): Required. A future to complete. @@ -215,13 +215,14 @@ def wait(self): @property def _latest_future(self) -> Optional[futures.Future]: - """Get the latest future if it exists""" + """Get the latest future if it exists.""" with self.__latest_future_lock: return self.__latest_future @_latest_future.setter def _latest_future(self, future: Optional[futures.Future]): - """Optionally set the latest future and add a complete_future callback.""" + """Optionally set the latest future and add a complete_future + callback.""" with self.__latest_future_lock: self.__latest_future = future if future: @@ -260,7 +261,8 @@ def wait_for_dependencies_and_invoke( kwargs: Dict[str, Any], internal_callbacks: Iterable[Callable[[Any], Any]], ) -> Any: - """Wrapper method to wait on any dependencies before submitting method. + """Wrapper method to wait on any dependencies before submitting + method. Args: deps (Sequence[futures.Future]): @@ -272,7 +274,6 @@ def wait_for_dependencies_and_invoke( Required. The keyword arguments to call the method with. internal_callbacks: (Callable[[Any], Any]): Callbacks that take the result of method. 
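Editor's sketch: the dependency-wait behavior documented here can be illustrated standalone. This is a minimal version with a plain ThreadPoolExecutor and invented names, not the SDK's internal executor:

    import concurrent.futures as futures
    from typing import Any, Callable, Dict, Iterable, Sequence

    def wait_for_deps_and_invoke(
        deps: Sequence[futures.Future],
        method: Callable[..., Any],
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        internal_callbacks: Iterable[Callable[[Any], Any]] = (),
    ) -> Any:
        # Block until every upstream future resolves (or raises).
        for dep in set(deps):
            dep.result()
        # Only invoke the wrapped method once all dependencies succeeded.
        result = method(*args, **kwargs)
        # Let internal callbacks sync state from the result.
        for callback in internal_callbacks:
            callback(result)
        return result

    with futures.ThreadPoolExecutor(max_workers=2) as pool:
        upstream = pool.submit(lambda: 41)
        downstream = pool.submit(
            wait_for_deps_and_invoke,
            deps=[upstream],
            method=lambda x: x + 1,
            args=(),
            kwargs={"x": 41},
        )
        print(downstream.result())  # 42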
- """ for future in set(deps): @@ -342,12 +343,14 @@ def wait_for_dependencies_and_invoke( @classmethod @abc.abstractmethod def _empty_constructor(cls) -> "FutureManager": - """Should construct object with all non FutureManager attributes as None""" + """Should construct object with all non FutureManager attributes as + None.""" pass @abc.abstractmethod def _sync_object_with_future_result(self, result: "FutureManager"): - """Should sync the object from _empty_constructor with result of future.""" + """Should sync the object from _empty_constructor with result of + future.""" def __repr__(self) -> str: if self._exception: @@ -375,7 +378,8 @@ class AiPlatformResourceNoun(metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod def client_class(cls) -> Type[utils.AiPlatformServiceClientWithOverride]: - """Client class required to interact with resource with optional overrides.""" + """Client class required to interact with resource with optional + overrides.""" pass @property @@ -388,7 +392,8 @@ def _is_client_prediction_client(cls) -> bool: @property @abc.abstractmethod def _getter_method(cls) -> str: - """Name of getter method of client class for retrieving the resource.""" + """Name of getter method of client class for retrieving the + resource.""" pass @property @@ -400,7 +405,7 @@ def _delete_method(cls) -> str: @property @abc.abstractmethod def _resource_noun(cls) -> str: - """Resource noun""" + """Resource noun.""" pass def __init__( @@ -547,7 +552,8 @@ def optional_sync( return_input_arg: Optional[str] = None, bind_future_to_self: bool = True, ): - """Decorator for AiPlatformResourceNounWithFutureManager with optional sync support. + """Decorator for AiPlatformResourceNounWithFutureManager with optional sync + support. Methods with this decorator should include a "sync" argument that defaults to True. If called with sync=False this decorator will launch the method as a @@ -681,7 +687,8 @@ def wrapper(*args, **kwargs): class AiPlatformResourceNounWithFutureManager(AiPlatformResourceNoun, FutureManager): - """Allows optional asynchronous calls to this AI Platform Resource Nouns.""" + """Allows optional asynchronous calls to this AI Platform Resource + Nouns.""" def __init__( self, @@ -816,7 +823,8 @@ def _list( credentials: Optional[auth_credentials.Credentials] = None, ) -> List[AiPlatformResourceNoun]: """Private method to list all instances of this AI Platform Resource, - takes a `cls_filter` arg to filter to a particular SDK resource subclass. + takes a `cls_filter` arg to filter to a particular SDK resource + subclass. Args: cls_filter (Callable[[proto.Message], bool]): @@ -884,8 +892,9 @@ def _list_with_local_order( credentials: Optional[auth_credentials.Credentials] = None, ) -> List[AiPlatformResourceNoun]: """Private method to list all instances of this AI Platform Resource, - takes a `cls_filter` arg to filter to a particular SDK resource subclass. - Provides client-side sorting when a list API doesn't support `order_by`. + takes a `cls_filter` arg to filter to a particular SDK resource + subclass. Provides client-side sorting when a list API doesn't support + `order_by`. Args: cls_filter (Callable[[proto.Message], bool]): @@ -986,7 +995,8 @@ def list( @optional_sync() def delete(self, sync: bool = True) -> None: - """Deletes this AI Platform resource. WARNING: This deletion is permament. + """Deletes this AI Platform resource. WARNING: This deletion is + permament. 
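Editor's sketch: the client-side sorting that _list_with_local_order provides when a list API lacks order_by support can be restated standalone. The order_by syntax ("field desc, field2") follows the docstring above; the record type is invented:

    import collections

    Resource = collections.namedtuple("Resource", ["display_name", "create_time"])

    def local_order(items, order_by):
        # Apply clauses right-to-left so the leftmost key wins; sorted() is
        # stable, which preserves earlier orderings for ties.
        for clause in reversed([c.strip() for c in order_by.split(",")]):
            field, _, direction = clause.partition(" ")
            items = sorted(
                items,
                key=lambda item: getattr(item, field),
                reverse=direction.strip().lower() == "desc",
            )
        return items

    rows = [Resource("b", 2), Resource("a", 2), Resource("c", 1)]
    print(local_order(rows, "create_time desc, display_name"))
    # display_name breaks the tie: 'a' (2), 'b' (2), then 'c' (1)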
Args: sync (bool): diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py index eefd1b04fd..23a89cc157 100644 --- a/google/cloud/aiplatform/datasets/_datasources.py +++ b/google/cloud/aiplatform/datasets/_datasources.py @@ -26,7 +26,7 @@ class Datasource(abc.ABC): - """An abstract class that sets dataset_metadata""" + """An abstract class that sets dataset_metadata.""" @property @abc.abstractmethod @@ -36,7 +36,7 @@ def dataset_metadata(self): class DatasourceImportable(abc.ABC): - """An abstract class that sets import_data_config""" + """An abstract class that sets import_data_config.""" @property @abc.abstractmethod @@ -46,14 +46,14 @@ def import_data_config(self): class TabularDatasource(Datasource): - """Datasource for creating a tabular dataset for AI Platform""" + """Datasource for creating a tabular dataset for AI Platform.""" def __init__( self, gcs_source: Optional[Union[str, Sequence[str]]] = None, bq_source: Optional[str] = None, ): - """Creates a tabular datasource + """Creates a tabular datasource. Args: gcs_source (Union[str, Sequence[str]]): @@ -99,7 +99,7 @@ def dataset_metadata(self) -> Optional[Dict]: class NonTabularDatasource(Datasource): - """Datasource for creating an empty non-tabular dataset for AI Platform""" + """Datasource for creating an empty non-tabular dataset for AI Platform.""" @property def dataset_metadata(self) -> Optional[Dict]: @@ -107,7 +107,8 @@ def dataset_metadata(self) -> Optional[Dict]: class NonTabularDatasourceImportable(NonTabularDatasource, DatasourceImportable): - """Datasource for creating a non-tabular dataset for AI Platform and importing data to the dataset""" + """Datasource for creating a non-tabular dataset for AI Platform and + importing data to the dataset.""" def __init__( self, @@ -115,7 +116,7 @@ def __init__( import_schema_uri: str, data_item_labels: Optional[Dict] = None, ): - """Creates a non-tabular datasource + """Creates a non-tabular datasource. Args: gcs_source (Union[str, Sequence[str]]): diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 922ce8930b..4bb98cbd77 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -36,7 +36,7 @@ class _Dataset(base.AiPlatformResourceNounWithFutureManager): - """Managed dataset resource for AI Platform""" + """Managed dataset resource for AI Platform.""" client_class = utils.DatasetClientWithOverride _is_client_prediction_client = False @@ -70,7 +70,6 @@ def __init__( credentials (auth_credentials.Credentials): Custom credentials to use to upload this model. Overrides credentials set in aiplatform.init. - """ super().__init__( @@ -195,7 +194,6 @@ def create( Returns: dataset (Dataset): Instantiated representation of the managed dataset resource. 
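Editor's sketch: how a tabular datasource folds gcs_source/bq_source into dataset_metadata, per the TabularDatasource docstrings above, can be approximated as follows. The metadata keys are assumptions, not the service's authoritative schema:

    from typing import Dict, Optional, Sequence, Union

    def tabular_dataset_metadata(
        gcs_source: Optional[Union[str, Sequence[str]]] = None,
        bq_source: Optional[str] = None,
    ) -> Dict:
        # Exactly one source must be supplied.
        if bool(gcs_source) == bool(bq_source):
            raise ValueError("Provide exactly one of gcs_source or bq_source.")
        if gcs_source:
            uris = [gcs_source] if isinstance(gcs_source, str) else list(gcs_source)
            return {"input_config": {"gcs_source": {"uri": uris}}}
        return {"input_config": {"bigquery_source": {"uri": bq_source}}}

    print(tabular_dataset_metadata(gcs_source="gs://my-bucket/data.csv"))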
- """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/datasets/image_dataset.py b/google/cloud/aiplatform/datasets/image_dataset.py index cea13014d8..fdc6c99a79 100644 --- a/google/cloud/aiplatform/datasets/image_dataset.py +++ b/google/cloud/aiplatform/datasets/image_dataset.py @@ -27,7 +27,7 @@ class ImageDataset(datasets._Dataset): - """Managed image dataset resource for AI Platform""" + """Managed image dataset resource for AI Platform.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.image, @@ -47,8 +47,8 @@ def create( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> "ImageDataset": - """Creates a new image dataset and optionally imports data into dataset when - source and import_schema_uri are passed. + """Creates a new image dataset and optionally imports data into dataset + when source and import_schema_uri are passed. Args: display_name (str): @@ -114,7 +114,6 @@ def create( Returns: image_dataset (ImageDataset): Instantiated representation of the managed image dataset resource. - """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py index 3dd217aad7..06ba4a3394 100644 --- a/google/cloud/aiplatform/datasets/tabular_dataset.py +++ b/google/cloud/aiplatform/datasets/tabular_dataset.py @@ -27,7 +27,7 @@ class TabularDataset(datasets._Dataset): - """Managed tabular dataset resource for AI Platform""" + """Managed tabular dataset resource for AI Platform.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.tabular, @@ -95,7 +95,6 @@ def create( Returns: tabular_dataset (TabularDataset): Instantiated representation of the managed tabular dataset resource. - """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/datasets/text_dataset.py b/google/cloud/aiplatform/datasets/text_dataset.py index 2b791e5c82..568edc9e47 100644 --- a/google/cloud/aiplatform/datasets/text_dataset.py +++ b/google/cloud/aiplatform/datasets/text_dataset.py @@ -27,7 +27,7 @@ class TextDataset(datasets._Dataset): - """Managed text dataset resource for AI Platform""" + """Managed text dataset resource for AI Platform.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.text, @@ -47,8 +47,8 @@ def create( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> "TextDataset": - """Creates a new text dataset and optionally imports data into dataset when - source and import_schema_uri are passed. + """Creates a new text dataset and optionally imports data into dataset + when source and import_schema_uri are passed. Example Usage: ds = aiplatform.TextDataset.create( @@ -121,7 +121,6 @@ def create( Returns: text_dataset (TextDataset): Instantiated representation of the managed text dataset resource. 
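Editor's sketch: by analogy with the TextDataset example shown in this patch, an ImageDataset.create call would look roughly like this; the project, bucket path, and schema choice are placeholders following the pattern of the text example:

    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    ds = aiplatform.ImageDataset.create(
        display_name="my-image-dataset",  # must be 128 characters or fewer
        gcs_source="gs://my-bucket/annotations.jsonl",  # placeholder URI
        import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
        sync=True,  # block until creation (and import, if any) completes
    )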
- """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/datasets/video_dataset.py b/google/cloud/aiplatform/datasets/video_dataset.py index c50298f99a..4115365c64 100644 --- a/google/cloud/aiplatform/datasets/video_dataset.py +++ b/google/cloud/aiplatform/datasets/video_dataset.py @@ -27,7 +27,7 @@ class VideoDataset(datasets._Dataset): - """Managed video dataset resource for AI Platform""" + """Managed video dataset resource for AI Platform.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.video, @@ -47,8 +47,8 @@ def create( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> "VideoDataset": - """Creates a new video dataset and optionally imports data into dataset when - source and import_schema_uri are passed. + """Creates a new video dataset and optionally imports data into dataset + when source and import_schema_uri are passed. Args: display_name (str): @@ -114,7 +114,6 @@ def create( Returns: video_dataset (VideoDataset): Instantiated representation of the managed video dataset resource. - """ utils.validate_display_name(display_name) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index b84a006d02..41a3b06d7f 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -107,8 +107,9 @@ def get_encryption_spec( gca_encryption_spec_v1beta1.EncryptionSpec, ] ]: - """Creates a gca_encryption_spec.EncryptionSpec instance from the given key name. - If the provided key name is None, it uses the default key name if provided. + """Creates a gca_encryption_spec.EncryptionSpec instance from the given + key name. If the provided key name is None, it uses the default key + name if provided. Args: encryption_spec_key_name (Optional[str]): The default encryption key name to use when creating resources. @@ -241,7 +242,8 @@ def create_client( location_override: Optional[str] = None, prediction_client: bool = False, ) -> utils.AiPlatformServiceClientWithOverride: - """Instantiates a given AiPlatformServiceClient with optional overrides. + """Instantiates a given AiPlatformServiceClient with optional + overrides. Args: client_class (utils.AiPlatformServiceClientWithOverride): diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index b428240555..ee6d46dde9 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -64,8 +64,7 @@ class _Job(base.AiPlatformResourceNounWithFutureManager): - """ - Class that represents a general Job resource in AI Platform (Unified). + """Class that represents a general Job resource in AI Platform (Unified). Cannot be directly instantiated. Serves as base class to specific Job types, i.e. BatchPredictionJob or @@ -89,8 +88,8 @@ def __init__( location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ): - """ - Retrives Job subclass resource by calling a subclass-specific getter method. + """Retrives Job subclass resource by calling a subclass-specific getter + method. 
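Editor's sketch: the key-name fallback that get_encryption_spec describes (an explicit key wins, otherwise the default set at init time is used) reduces to a few lines; the dict here stands in for the gca_encryption_spec.EncryptionSpec message:

    from typing import Optional

    def resolve_encryption_spec(
        encryption_spec_key_name: Optional[str],
        default_key_name: Optional[str],
    ) -> Optional[dict]:
        key = encryption_spec_key_name or default_key_name
        if key is None:
            return None  # no CMEK requested anywhere
        return {"kms_key_name": key}

    print(resolve_encryption_spec(
        None, "projects/p/locations/l/keyRings/r/cryptoKeys/k"))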
Args: job_name (str): @@ -142,7 +141,8 @@ def _cancel_method(cls) -> str: pass def _dashboard_uri(self) -> Optional[str]: - """Helper method to compose the dashboard uri where job can be viewed.""" + """Helper method to compose the dashboard uri where job can be + viewed.""" fields = utils.extract_fields_from_resource_name(self.resource_name) url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}" return url @@ -152,7 +152,6 @@ def _block_until_complete(self): Raises: RuntimeError: If job failed or cancelled. - """ # Used these numbers so failures surface fast @@ -232,8 +231,11 @@ def list( ) def cancel(self) -> None: - """Cancels this Job. Success of cancellation is not guaranteed. Use `Job.state` - property to verify if cancellation was successful.""" + """Cancels this Job. + + Success of cancellation is not guaranteed. Use `Job.state` + property to verify if cancellation was successful. + """ _LOGGER.log_action_start_against_resource("Cancelling", "run", self) getattr(self.api_client, self._cancel_method)(name=self.resource_name) @@ -255,8 +257,8 @@ def __init__( location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ): - """ - Retrieves a BatchPredictionJob resource and instantiates its representation. + """Retrieves a BatchPredictionJob resource and instantiates its + representation. Args: batch_prediction_job_name (str): @@ -463,7 +465,6 @@ def create( Returns: (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. - """ utils.validate_display_name(job_display_name) @@ -655,7 +656,6 @@ def _create( If no or multiple source or destinations are provided. Also, if provided instances_format or predictions_format are not supported by AI Platform. - """ # select v1beta1 if explain else use default v1 if generate_explanation: @@ -687,9 +687,9 @@ def _create( def iter_outputs( self, bq_max_results: Optional[int] = 100 ) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]: - """Returns an Iterable object to traverse the output files, either a list - of GCS Blobs or a BigQuery RowIterator depending on the output config set - when the BatchPredictionJob was created. + """Returns an Iterable object to traverse the output files, either a + list of GCS Blobs or a BigQuery RowIterator depending on the output + config set when the BatchPredictionJob was created. Args: bq_max_results: Optional[int] = 100 diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 5a139b90aa..ea8b154a20 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -217,8 +217,8 @@ def _create( encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, sync=True, ) -> "Endpoint": - """ - Creates a new endpoint by calling the API client. + """Creates a new endpoint by calling the API client. + Args: api_client (EndpointServiceClient): Required. An instance of EndpointServiceClient with the correct @@ -296,9 +296,8 @@ def _create( def _allocate_traffic( traffic_split: Dict[str, int], traffic_percentage: int, ) -> Dict[str, int]: - """ - Allocates desired traffic to new deployed model and scales traffic of - older deployed models. + """Allocates desired traffic to new deployed model and scales traffic + of older deployed models. 
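Editor's sketch: the proportional reallocation _allocate_traffic performs can be sketched as below; the placeholder key for the incoming model and the truncation-based rounding are assumptions (the SDK may round differently):

    from typing import Dict

    def allocate_traffic(
        traffic_split: Dict[str, int], traffic_percentage: int
    ) -> Dict[str, int]:
        new_split = {}
        old_total = sum(traffic_split.values()) or 1
        scale = (100 - traffic_percentage) / old_total
        for deployed_model_id, percent in traffic_split.items():
            new_split[deployed_model_id] = int(percent * scale)
        # Hand any remainder to the new model so the split sums to 100.
        new_split["new-deployed-model"] = 100 - sum(new_split.values())
        return new_split

    print(allocate_traffic({"a": 50, "b": 50}, 40))
    # {'a': 30, 'b': 30, 'new-deployed-model': 40}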
Args: traffic_split (Dict[str, int]): @@ -333,9 +332,8 @@ def _allocate_traffic( def _unallocate_traffic( traffic_split: Dict[str, int], deployed_model_id: str, ) -> Dict[str, int]: - """ - Sets deployed model id's traffic to 0 and scales the traffic of other - deployed models. + """Sets deployed model id's traffic to 0 and scales the traffic of + other deployed models. Args: traffic_split (Dict[str, int]): @@ -431,11 +429,11 @@ def _validate_deploy_args( For more details, see `Ref docs ` Raises: - ValueError if Min or Max replica is negative. Traffic percentage > 100 or - < 0. Or if traffic_split does not sum to 100. + ValueError: if Min or Max replica is negative. Traffic percentage > 100 or + < 0. Or if traffic_split does not sum to 100. - ValueError if either explanation_metadata or explanation_parameters - but not both are specified. + ValueError: if either explanation_metadata or explanation_parameters + but not both are specified. """ if min_replica_count < 0: raise ValueError("Min replica cannot be negative.") @@ -483,8 +481,7 @@ def deploy( metadata: Optional[Sequence[Tuple[str, str]]] = (), sync=True, ) -> None: - """ - Deploys a Model to the Endpoint. + """Deploys a Model to the Endpoint. Args: model (aiplatform.Model): @@ -602,8 +599,7 @@ def _deploy( metadata: Optional[Sequence[Tuple[str, str]]] = (), sync=True, ) -> None: - """ - Deploys a Model to the Endpoint. + """Deploys a Model to the Endpoint. Args: model (aiplatform.Model): @@ -795,9 +791,9 @@ def _deploy_call( will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. Raises: - ValueError if there is not current traffic split and traffic percentage + ValueError: If there is not current traffic split and traffic percentage is not 0 or 100. - ValueError if only `explanation_metadata` or `explanation_parameters` + ValueError: If only `explanation_metadata` or `explanation_parameters` is specified. """ @@ -987,7 +983,8 @@ def _instantiate_prediction_client( credentials: Optional[auth_credentials.Credentials] = None, ) -> utils.PredictionClientWithOverride: - """Helper method to instantiates prediction client with optional overrides for this endpoint. + """Helper method to instantiates prediction client with optional + overrides for this endpoint. Args: location (str): The location of this endpoint. @@ -1030,7 +1027,6 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict ``parameters_schema_uri``. Returns: prediction: Prediction with returned predictions and Model Id. - """ self.wait() @@ -1282,7 +1278,8 @@ def upload( encryption_spec_key_name: Optional[str] = None, sync=True, ) -> "Model": - """Uploads a model and returns a Model representing the uploaded Model resource. + """Uploads a model and returns a Model representing the uploaded Model + resource. Example usage: @@ -1415,7 +1412,7 @@ def upload( Returns: model: Instantiated representation of the uploaded model resource. Raises: - ValueError if only `explanation_metadata` or `explanation_parameters` + ValueError: If only `explanation_metadata` or `explanation_parameters` is specified. """ utils.validate_display_name(display_name) @@ -1523,8 +1520,7 @@ def deploy( encryption_spec_key_name: Optional[str] = None, sync=True, ) -> Endpoint: - """ - Deploys model to endpoint. Endpoint will be created if unspecified. + """Deploys model to endpoint. Endpoint will be created if unspecified. 
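Editor's sketch: the argument checks _validate_deploy_args documents (non-negative replica counts, traffic percentage within [0, 100], traffic_split summing to 100, and explanation metadata/parameters supplied together) can be restated standalone:

    from typing import Dict, Optional

    def validate_deploy_args(
        min_replica_count: int,
        max_replica_count: int,
        traffic_percentage: Optional[int],
        traffic_split: Optional[Dict[str, int]],
        explanation_metadata=None,
        explanation_parameters=None,
    ) -> None:
        if min_replica_count < 0 or max_replica_count < 0:
            raise ValueError("Replica counts cannot be negative.")
        if traffic_split is None:
            if traffic_percentage is not None and not 0 <= traffic_percentage <= 100:
                raise ValueError("traffic_percentage must be within [0, 100].")
        elif sum(traffic_split.values()) != 100:
            raise ValueError("Sum of all traffic_split percentages must be 100.")
        if bool(explanation_metadata) != bool(explanation_parameters):
            raise ValueError(
                "explanation_metadata and explanation_parameters must be "
                "specified together."
            )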
Args: endpoint ("Endpoint"): @@ -1608,7 +1604,6 @@ def deploy( Returns: endpoint ("Endpoint"): Endpoint with the deployed model. - """ Endpoint._validate_deploy_args( @@ -1660,8 +1655,7 @@ def _deploy( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> Endpoint: - """ - Deploys model to endpoint. Endpoint will be created if unspecified. + """Deploys model to endpoint. Endpoint will be created if unspecified. Args: endpoint ("Endpoint"): @@ -1807,9 +1801,10 @@ def batch_predict( encryption_spec_key_name: Optional[str] = None, sync: bool = True, ) -> jobs.BatchPredictionJob: - """Creates a batch prediction job using this Model and outputs prediction - results to the provided destination prefix in the specified - `predictions_format`. One source and one destination prefix are required. + """Creates a batch prediction job using this Model and outputs + prediction results to the provided destination prefix in the specified + `predictions_format`. One source and one destination prefix are + required. Example usage: @@ -1960,7 +1955,6 @@ def batch_predict( Returns: (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. - """ self.wait() diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index fe3cce059e..441f91ca39 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -143,7 +143,7 @@ def __init__( @classmethod @abc.abstractmethod def _supported_training_schemas(cls) -> Tuple[str]: - """List of supported schemas for this training job""" + """List of supported schemas for this training job.""" pass @@ -211,7 +211,10 @@ def _model_upload_fail_string(self) -> str: @abc.abstractmethod def run(self) -> Optional[models.Model]: - """Runs the training job. Should call _run_job internally""" + """Runs the training job. + + Should call _run_job internally + """ pass @staticmethod @@ -530,7 +533,8 @@ def _run_job( return model def _is_waiting_to_run(self) -> bool: - """Returns True if the Job is pending on upstream tasks False otherwise.""" + """Returns True if the Job is pending on upstream tasks False + otherwise.""" self._raise_future_exception() if self._latest_future: _LOGGER.info( @@ -563,7 +567,7 @@ def get_model(self, sync=True) -> models.Model: model: AI Platform Model produced by this training Raises: - RuntimeError if training failed or if a model was not produced by this training. + RuntimeError: If training failed or if a model was not produced by this training. """ self._assert_has_run() @@ -586,7 +590,7 @@ def _force_get_model(self, sync: bool = True) -> models.Model: model: AI Platform Model produced by this training Raises: - RuntimeError if training failed or if a model was not produced by this training. + RuntimeError: If training failed or if a model was not produced by this training. """ model = self._get_model() @@ -603,7 +607,7 @@ def _get_model(self) -> Optional[models.Model]: Model. None otherwise. Raises: - RuntimeError if Training failed. + RuntimeError: If Training failed. """ self._block_until_complete() @@ -662,19 +666,24 @@ def _raise_failure(self): """Helper method to raise failure if TrainingPipeline fails. Raises: - RuntimeError: If training failed.""" + RuntimeError: If training failed. + """ if self._gca_resource.error.code != code_pb2.OK: raise RuntimeError("Training failed with:\n%s" % self._gca_resource.error) @property def has_failed(self) -> bool: - """Returns True if training has failed. 
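Editor's sketch: a batch_predict call matching the description above might look like the following; the model ID, display name, and bucket paths are placeholders, and argument names not visible in this diff should be treated as assumptions:

    from google.cloud import aiplatform

    model = aiplatform.Model(model_name="1234567890")  # hypothetical model ID

    job = model.batch_predict(
        job_display_name="my-batch-prediction",       # placeholder
        gcs_source="gs://my-bucket/instances.jsonl",  # placeholder input
        gcs_destination_prefix="gs://my-bucket/out",  # placeholder output prefix
        sync=False,  # return immediately; the job runs in a concurrent Future
    )

    job.wait()
    for output in job.iter_outputs():
        print(output)  # GCS Blobs or BigQuery rows, per the output config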
False otherwise.""" + """Returns True if training has failed. + + False otherwise. + """ self._assert_has_run() return self.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED def _dashboard_uri(self) -> str: - """Helper method to compose the dashboard uri where training can be viewed.""" + """Helper method to compose the dashboard uri where training can be + viewed.""" fields = utils.extract_fields_from_resource_name(self.resource_name) url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/training/{fields.id}?project={fields.project}" return url @@ -762,7 +771,7 @@ def cancel(self) -> None: becomes a job with state set to `CANCELLED`. Raises: - RuntimeError if this TrainingJob has not started running. + RuntimeError: If this TrainingJob has not started running. """ if not self._has_run: raise RuntimeError( @@ -838,10 +847,10 @@ def _timestamped_copy_to_gcs( def _get_python_executable() -> str: """Returns Python executable. - Raises: - EnvironmentError if Python executable is not found. Returns: Python executable to use for setuptools packaging. + Raises: + EnvironmentError: If Python executable is not found. """ python_executable = sys.executable @@ -852,7 +861,8 @@ def _get_python_executable() -> str: class _TrainingScriptPythonPackager: - """Converts a Python script into Python package suitable for aiplatform training. + """Converts a Python script into Python package suitable for aiplatform + training. Copies the script to specified location. @@ -879,7 +889,6 @@ class _TrainingScriptPythonPackager: The package after installed can be executed as: python -m aiplatform_custom_trainer_script.task - """ _TRAINER_FOLDER = "trainer" @@ -917,14 +926,15 @@ def __init__(self, script_path: str, requirements: Optional[Sequence[str]] = Non self.requirements = requirements or [] def make_package(self, package_directory: str) -> str: - """Converts script into a Python package suitable for python module execution. + """Converts script into a Python package suitable for python module + execution. Args: package_directory (str): Directory to build package in. Returns: source_distribution_path (str): Path to built package. Raises: - RunTimeError if package creation fails. + RunTimeError: If package creation fails. """ # The root folder to builder the package in package_path = pathlib.Path(package_directory) @@ -1126,7 +1136,6 @@ class _DistributedTrainingSpec(NamedTuple): accelerator_type='NVIDIA_TESLA_K80' ) ) - """ chief_spec: _MachineSpec = _MachineSpec() @@ -1138,7 +1147,8 @@ class _DistributedTrainingSpec(NamedTuple): def pool_specs( self, ) -> List[Dict[str, Union[int, str, Dict[str, Union[int, str]]]]]: - """Return each pools spec in correct order for AI Platform as a list of dicts. + """Return each pools spec in correct order for AI Platform as a list of + dicts. Also removes specs if they are empty but leaves specs in if there unusual specifications to not break the ordering in AI Platform Training. @@ -1215,8 +1225,7 @@ def chief_worker_pool( class _CustomTrainingJob(_TrainingJob): - """ABC for Custom Training Pipelines.. - """ + """ABC for Custom Training Pipelines..""" _supported_training_schemas = (schema.training_job.definition.custom_task,) @@ -1448,7 +1457,8 @@ def _prepare_and_validate_run( accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", accelerator_count: int = 0, ) -> Tuple[_DistributedTrainingSpec, Optional[gca_model.Model]]: - """Create worker pool specs and managed model as well validating the run. 
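Editor's sketch: the ordering contract pool_specs describes (chief, then worker, then parameter server, then evaluator, trimming only trailing empty pools so intermediate slots keep their position) can be shown with illustrative field names:

    from typing import Dict, List, Optional

    def machine_pool(replica_count: int, machine_type: str = "n1-standard-4") -> Dict:
        return {
            "replica_count": replica_count,
            "machine_spec": {"machine_type": machine_type},
        }

    def pool_specs(
        chief: Dict,
        worker: Optional[Dict] = None,
        ps: Optional[Dict] = None,
        evaluator: Optional[Dict] = None,
    ) -> List[Dict]:
        pools = [chief, worker or {}, ps or {}, evaluator or {}]
        # Trim empty specs from the end only; gaps in the middle keep their
        # slot so AI Platform Training sees pools in the expected order.
        while pools and not pools[-1].get("replica_count"):
            pools.pop()
        return pools

    print(pool_specs(machine_pool(1), machine_pool(4)))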
+ """Create worker pool specs and managed model as well validating the + run. Args: model_display_name (str): @@ -1473,9 +1483,8 @@ def _prepare_and_validate_run( Worker pools specs and managed model for run. Raises: - RuntimeError if Training job has already been run or model_display_name was - provided but required arguments were not provided in constructor. - + RuntimeError: If Training job has already been run or model_display_name was + provided but required arguments were not provided in constructor. """ if self._is_waiting_to_run(): @@ -1567,8 +1576,8 @@ def _model_upload_fail_string(self) -> str: class CustomTrainingJob(_CustomTrainingJob): """Class to launch a Custom Training Job in AI Platform using a script. - Takes a training implementation as a python script and executes that script - in Cloud AI Platform Training. + Takes a training implementation as a python script and executes that + script in Cloud AI Platform Training. """ def __init__( @@ -2100,7 +2109,8 @@ def _run( class CustomContainerTrainingJob(_CustomTrainingJob): - """Class to launch a Custom Training Job in AI Platform using a Container.""" + """Class to launch a Custom Training Job in AI Platform using a + Container.""" def __init__( self, @@ -2351,14 +2361,7 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset ( - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ): + dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset]): AI Platform to fit this training against. Custom training script should retrieve datasets through passed in environment variables uris: @@ -2461,7 +2464,7 @@ def run( produce an AI Platform Model. Raises: - RuntimeError if Training job has already been run, staging_bucket has not + RuntimeError: If Training job has already been run, staging_bucket has not been set, or model_display_name was provided but required arguments were not provided in constructor. """ @@ -2834,7 +2837,7 @@ def run( produce an AI Platform Model. Raises: - RuntimeError if Training job has already been run or is waiting to run. + RuntimeError: If Training job has already been run or is waiting to run. """ if self._is_waiting_to_run(): @@ -3364,10 +3367,11 @@ def _model_upload_fail_string(self) -> str: class CustomPythonPackageTrainingJob(_CustomTrainingJob): - """Class to launch a Custom Training Job in AI Platform using a Python Package. + """Class to launch a Custom Training Job in AI Platform using a Python + Package. - Takes a training implementation as a python package and executes that package - in Cloud AI Platform Training. + Takes a training implementation as a python package and executes + that package in Cloud AI Platform Training. """ def __init__( @@ -3627,14 +3631,7 @@ def run( of data will be used for training, 10% for validation, and 10% for test. Args: - dataset ( - Union[ - datasets.ImageDataset, - datasets.TabularDataset, - datasets.TextDataset, - datasets.VideoDataset, - ] - ): + dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset,]): AI Platform to fit this training against. Custom training script should retrieve datasets through passed in environement variables uris: @@ -4308,7 +4305,7 @@ def run( model: The trained AI Platform Model resource. Raises: - RuntimeError if Training job has already been run or is waiting to run. 
+ RuntimeError: If Training job has already been run or is waiting to run. """ if self._is_waiting_to_run(): diff --git a/google/cloud/aiplatform/training_utils.py b/google/cloud/aiplatform/training_utils.py index a93ecaa1ce..fea60c5005 100644 --- a/google/cloud/aiplatform/training_utils.py +++ b/google/cloud/aiplatform/training_utils.py @@ -22,7 +22,7 @@ class EnvironmentVariables: - """Passes on OS' environment variables""" + """Passes on OS' environment variables.""" @property def training_data_uri(self) -> Optional[str]: diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py index 3154d9568d..22991290da 100644 --- a/google/cloud/aiplatform/utils.py +++ b/google/cloud/aiplatform/utils.py @@ -79,7 +79,8 @@ def _match_to_fields(match: Match) -> Optional[Fields]: - """Normalize RegEx groups from resource name pattern Match to class Fields""" + """Normalize RegEx groups from resource name pattern Match to class + Fields.""" if not match: return None @@ -92,15 +93,15 @@ def _match_to_fields(match: Match) -> Optional[Fields]: def validate_id(resource_id: str) -> bool: - """Validate int64 resource ID number""" + """Validate int64 resource ID number.""" return bool(RESOURCE_ID_PATTERN.match(resource_id)) def extract_fields_from_resource_name( resource_name: str, resource_noun: Optional[str] = None ) -> Optional[Fields]: - """Validates and returns extracted fields from a fully-qualified resource name. - Returns None if name is invalid. + """Validates and returns extracted fields from a fully-qualified resource + name. Returns None if name is invalid. Args: resource_name (str): @@ -133,8 +134,7 @@ def full_resource_name( project: Optional[str] = None, location: Optional[str] = None, ) -> str: - """ - Returns fully qualified resource name. + """Returns fully qualified resource name. Args: resource_name (str): @@ -217,7 +217,7 @@ def validate_project(project: str) -> bool: # TODO(b/172932277) verify display name only contains utf-8 chars def validate_display_name(display_name: str): - """Verify display name is at most 128 chars + """Verify display name is at most 128 chars. Args: display_name: display name to verify @@ -253,7 +253,8 @@ def validate_region(region: str) -> bool: def validate_accelerator_type(accelerator_type: str) -> bool: - """Validates user provided accelerator_type string for training and prediction + """Validates user provided accelerator_type string for training and + prediction. Args: accelerator_type (str): @@ -307,7 +308,8 @@ def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optiona class ClientWithOverride: class WrappedClient: - """Wrapper class for client that creates client at API invocation time.""" + """Wrapper class for client that creates client at API invocation + time.""" def __init__( self, @@ -318,14 +320,15 @@ def __init__( ): """Stores parameters needed to instantiate client. - client_class (AiPlatformServiceClient): - Required. Class of the client to use. - client_options (client_options.ClientOptions): - Required. Client options to pass to client. - client_info (gapic_v1.client_info.ClientInfo): - Required. Client info to pass to client. - credentials (auth_credentials.credentials): - Optional. Client credentials to pass to client. + Args: + client_class (AiPlatformServiceClient): + Required. Class of the client to use. + client_options (client_options.ClientOptions): + Required. Client options to pass to client. + client_info (gapic_v1.client_info.ClientInfo): + Required. Client info to pass to client. 
+ credentials (auth_credentials.Credentials):
+ Optional. Client credentials to pass to client.
 """

 self._clients = {
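Editor's sketch: finally, a standalone version of extract_bucket_and_prefix_from_gcs_path from utils.py above; the behavior is inferred from the name and return type, so the real implementation may handle edge cases differently:

    from typing import Optional, Tuple

    def extract_bucket_and_prefix_from_gcs_path(
        gcs_path: str,
    ) -> Tuple[str, Optional[str]]:
        # "gs://bucket/some/prefix" -> ("bucket", "some/prefix")
        path = gcs_path[len("gs://"):] if gcs_path.startswith("gs://") else gcs_path
        bucket, _, prefix = path.strip("/").partition("/")
        return bucket, prefix or None

    assert extract_bucket_and_prefix_from_gcs_path(
        "gs://my-bucket/path/to/obj") == ("my-bucket", "path/to/obj")
    assert extract_bucket_and_prefix_from_gcs_path(
        "gs://my-bucket") == ("my-bucket", None)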