diff --git a/google/cloud/aiplatform/datasets/__init__.py b/google/cloud/aiplatform/datasets/__init__.py index b297530955..0f6b7f42fa 100644 --- a/google/cloud/aiplatform/datasets/__init__.py +++ b/google/cloud/aiplatform/datasets/__init__.py @@ -16,6 +16,7 @@ # from google.cloud.aiplatform.datasets.dataset import _Dataset +from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset from google.cloud.aiplatform.datasets.time_series_dataset import TimeSeriesDataset from google.cloud.aiplatform.datasets.image_dataset import ImageDataset @@ -25,6 +26,7 @@ __all__ = ( "_Dataset", + "_ColumnNamesDataset", "TabularDataset", "TimeSeriesDataset", "ImageDataset", diff --git a/google/cloud/aiplatform/datasets/column_names_dataset.py b/google/cloud/aiplatform/datasets/column_names_dataset.py new file mode 100644 index 0000000000..e455642be5 --- /dev/null +++ b/google/cloud/aiplatform/datasets/column_names_dataset.py @@ -0,0 +1,250 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import csv +import logging +from typing import List, Optional, Set +from google.auth import credentials as auth_credentials + +from google.cloud import bigquery +from google.cloud import storage + +from google.cloud.aiplatform import utils +from google.cloud.aiplatform import datasets + + +class _ColumnNamesDataset(datasets._Dataset): + @property + def column_names(self) -> List[str]: + """Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or + Google BigQuery source. + + Returns: + List[str] + A list of columns names + + Raises: + RuntimeError: When no valid source is found. 
+ """ + + self._assert_gca_resource_is_available() + + metadata = self._gca_resource.metadata + + if metadata is None: + raise RuntimeError("No metadata found for dataset") + + input_config = metadata.get("inputConfig") + + if input_config is None: + raise RuntimeError("No inputConfig found for dataset") + + gcs_source = input_config.get("gcsSource") + bq_source = input_config.get("bigquerySource") + + if gcs_source: + gcs_source_uris = gcs_source.get("uri") + + if gcs_source_uris and len(gcs_source_uris) > 0: + # Lexicographically sort the files + gcs_source_uris.sort() + + # Get the first file in sorted list + # TODO(b/193044977): Return as Set instead of List + return list( + self._retrieve_gcs_source_columns( + project=self.project, + gcs_csv_file_path=gcs_source_uris[0], + credentials=self.credentials, + ) + ) + elif bq_source: + bq_table_uri = bq_source.get("uri") + if bq_table_uri: + # TODO(b/193044977): Return as Set instead of List + return list( + self._retrieve_bq_source_columns( + project=self.project, + bq_table_uri=bq_table_uri, + credentials=self.credentials, + ) + ) + + raise RuntimeError("No valid CSV or BigQuery datasource found.") + + @staticmethod + def _retrieve_gcs_source_columns( + project: str, + gcs_csv_file_path: str, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> Set[str]: + """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage + + Example Usage: + + column_names = _retrieve_gcs_source_columns( + "project_id", + "gs://example-bucket/path/to/csv_file" + ) + + # column_names = {"column_1", "column_2"} + + Args: + project (str): + Required. Project to initiate the Google Cloud Storage client with. + gcs_csv_file_path (str): + Required. A full path to a CSV files stored on Google Cloud Storage. + Must include "gs://" prefix. + credentials (auth_credentials.Credentials): + Credentials to use to with GCS Client. + Returns: + Set[str] + A set of columns names in the CSV file. + + Raises: + RuntimeError: When the retrieved CSV file is invalid. + """ + + gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path( + gcs_csv_file_path + ) + client = storage.Client(project=project, credentials=credentials) + bucket = client.bucket(gcs_bucket) + blob = bucket.blob(gcs_blob) + + # Incrementally download the CSV file until the header is retrieved + first_new_line_index = -1 + start_index = 0 + increment = 1000 + line = "" + + try: + logger = logging.getLogger("google.resumable_media._helpers") + logging_warning_filter = utils.LoggingFilter(logging.INFO) + logger.addFilter(logging_warning_filter) + + while first_new_line_index == -1: + line += blob.download_as_bytes( + start=start_index, end=start_index + increment - 1 + ).decode("utf-8") + + first_new_line_index = line.find("\n") + start_index += increment + + header_line = line[:first_new_line_index] + + # Split to make it an iterable + header_line = header_line.split("\n")[:1] + + csv_reader = csv.reader(header_line, delimiter=",") + except (ValueError, RuntimeError) as err: + raise RuntimeError( + "There was a problem extracting the headers from the CSV file at '{}': {}".format( + gcs_csv_file_path, err + ) + ) + finally: + logger.removeFilter(logging_warning_filter) + + return set(next(csv_reader)) + + @staticmethod + def _get_bq_schema_field_names_recursively( + schema_field: bigquery.SchemaField, + ) -> Set[str]: + """Retrieve the name for a schema field along with ancestor fields. + Nested schema fields are flattened and concatenated with a ".". 
+ Schema fields with child fields are not included, but the children are. + + Args: + project (str): + Required. Project to initiate the BigQuery client with. + bq_table_uri (str): + Required. A URI to a BigQuery table. + Can include "bq://" prefix but not required. + credentials (auth_credentials.Credentials): + Credentials to use with BQ Client. + + Returns: + Set[str] + A set of columns names in the BigQuery table. + """ + + ancestor_names = { + nested_field_name + for field in schema_field.fields + for nested_field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively( + field + ) + } + + # Only return "leaf nodes", basically any field that doesn't have children + if len(ancestor_names) == 0: + return {schema_field.name} + else: + return {f"{schema_field.name}.{name}" for name in ancestor_names} + + @staticmethod + def _retrieve_bq_source_columns( + project: str, + bq_table_uri: str, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> Set[str]: + """Retrieve the column names from a table on Google BigQuery + Nested schema fields are flattened and concatenated with a ".". + Schema fields with child fields are not included, but the children are. + + Example Usage: + + column_names = _retrieve_bq_source_columns( + "project_id", + "bq://project_id.dataset.table" + ) + + # column_names = {"column_1", "column_2", "column_3.nested_field"} + + Args: + project (str): + Required. Project to initiate the BigQuery client with. + bq_table_uri (str): + Required. A URI to a BigQuery table. + Can include "bq://" prefix but not required. + credentials (auth_credentials.Credentials): + Credentials to use with BQ Client. + + Returns: + Set[str] + A set of column names in the BigQuery table. + """ + + # Remove bq:// prefix + prefix = "bq://" + if bq_table_uri.startswith(prefix): + bq_table_uri = bq_table_uri[len(prefix) :] + + client = bigquery.Client(project=project, credentials=credentials) + table = client.get_table(bq_table_uri) + schema = table.schema + + return { + field_name + for field in schema + for field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively( + field + ) + } diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py index 741a2cc643..57ad827b31 100644 --- a/google/cloud/aiplatform/datasets/tabular_dataset.py +++ b/google/cloud/aiplatform/datasets/tabular_dataset.py @@ -15,16 +15,10 @@ # limitations under the License. # -import csv -import logging - -from typing import Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Dict, Optional, Sequence, Tuple, Union from google.auth import credentials as auth_credentials -from google.cloud import bigquery -from google.cloud import storage - from google.cloud.aiplatform import datasets from google.cloud.aiplatform.datasets import _datasources from google.cloud.aiplatform import initializer @@ -32,233 +26,13 @@ from google.cloud.aiplatform import utils -class TabularDataset(datasets._Dataset): +class TabularDataset(datasets._ColumnNamesDataset): """Managed tabular dataset resource for Vertex AI.""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( schema.dataset.metadata.tabular, ) - @property - def column_names(self) -> List[str]: - """Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or - Google BigQuery source. - - Returns: - List[str] - A list of columns names - - Raises: - RuntimeError: When no valid source is found. 
- """ - - self._assert_gca_resource_is_available() - - metadata = self._gca_resource.metadata - - if metadata is None: - raise RuntimeError("No metadata found for dataset") - - input_config = metadata.get("inputConfig") - - if input_config is None: - raise RuntimeError("No inputConfig found for dataset") - - gcs_source = input_config.get("gcsSource") - bq_source = input_config.get("bigquerySource") - - if gcs_source: - gcs_source_uris = gcs_source.get("uri") - - if gcs_source_uris and len(gcs_source_uris) > 0: - # Lexicographically sort the files - gcs_source_uris.sort() - - # Get the first file in sorted list - # TODO(b/193044977): Return as Set instead of List - return list( - self._retrieve_gcs_source_columns( - project=self.project, - gcs_csv_file_path=gcs_source_uris[0], - credentials=self.credentials, - ) - ) - elif bq_source: - bq_table_uri = bq_source.get("uri") - if bq_table_uri: - # TODO(b/193044977): Return as Set instead of List - return list( - self._retrieve_bq_source_columns( - project=self.project, - bq_table_uri=bq_table_uri, - credentials=self.credentials, - ) - ) - - raise RuntimeError("No valid CSV or BigQuery datasource found.") - - @staticmethod - def _retrieve_gcs_source_columns( - project: str, - gcs_csv_file_path: str, - credentials: Optional[auth_credentials.Credentials] = None, - ) -> Set[str]: - """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage - - Example Usage: - - column_names = _retrieve_gcs_source_columns( - "project_id", - "gs://example-bucket/path/to/csv_file" - ) - - # column_names = {"column_1", "column_2"} - - Args: - project (str): - Required. Project to initiate the Google Cloud Storage client with. - gcs_csv_file_path (str): - Required. A full path to a CSV files stored on Google Cloud Storage. - Must include "gs://" prefix. - credentials (auth_credentials.Credentials): - Credentials to use to with GCS Client. - Returns: - Set[str] - A set of columns names in the CSV file. - - Raises: - RuntimeError: When the retrieved CSV file is invalid. - """ - - gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path( - gcs_csv_file_path - ) - client = storage.Client(project=project, credentials=credentials) - bucket = client.bucket(gcs_bucket) - blob = bucket.blob(gcs_blob) - - # Incrementally download the CSV file until the header is retrieved - first_new_line_index = -1 - start_index = 0 - increment = 1000 - line = "" - - try: - logger = logging.getLogger("google.resumable_media._helpers") - logging_warning_filter = utils.LoggingFilter(logging.INFO) - logger.addFilter(logging_warning_filter) - - while first_new_line_index == -1: - line += blob.download_as_bytes( - start=start_index, end=start_index + increment - 1 - ).decode("utf-8") - - first_new_line_index = line.find("\n") - start_index += increment - - header_line = line[:first_new_line_index] - - # Split to make it an iterable - header_line = header_line.split("\n")[:1] - - csv_reader = csv.reader(header_line, delimiter=",") - except (ValueError, RuntimeError) as err: - raise RuntimeError( - "There was a problem extracting the headers from the CSV file at '{}': {}".format( - gcs_csv_file_path, err - ) - ) - finally: - logger.removeFilter(logging_warning_filter) - - return set(next(csv_reader)) - - @staticmethod - def _get_bq_schema_field_names_recursively( - schema_field: bigquery.SchemaField, - ) -> Set[str]: - """Retrieve the name for a schema field along with ancestor fields. - Nested schema fields are flattened and concatenated with a ".". 
- Schema fields with child fields are not included, but the children are. - - Args: - project (str): - Required. Project to initiate the BigQuery client with. - bq_table_uri (str): - Required. A URI to a BigQuery table. - Can include "bq://" prefix but not required. - credentials (auth_credentials.Credentials): - Credentials to use with BQ Client. - - Returns: - Set[str] - A set of columns names in the BigQuery table. - """ - - ancestor_names = { - nested_field_name - for field in schema_field.fields - for nested_field_name in TabularDataset._get_bq_schema_field_names_recursively( - field - ) - } - - # Only return "leaf nodes", basically any field that doesn't have children - if len(ancestor_names) == 0: - return {schema_field.name} - else: - return {f"{schema_field.name}.{name}" for name in ancestor_names} - - @staticmethod - def _retrieve_bq_source_columns( - project: str, - bq_table_uri: str, - credentials: Optional[auth_credentials.Credentials] = None, - ) -> Set[str]: - """Retrieve the column names from a table on Google BigQuery - Nested schema fields are flattened and concatenated with a ".". - Schema fields with child fields are not included, but the children are. - - Example Usage: - - column_names = _retrieve_bq_source_columns( - "project_id", - "bq://project_id.dataset.table" - ) - - # column_names = {"column_1", "column_2", "column_3.nested_field"} - - Args: - project (str): - Required. Project to initiate the BigQuery client with. - bq_table_uri (str): - Required. A URI to a BigQuery table. - Can include "bq://" prefix but not required. - credentials (auth_credentials.Credentials): - Credentials to use with BQ Client. - - Returns: - Set[str] - A set of column names in the BigQuery table. - """ - - # Remove bq:// prefix - prefix = "bq://" - if bq_table_uri.startswith(prefix): - bq_table_uri = bq_table_uri[len(prefix) :] - - client = bigquery.Client(project=project, credentials=credentials) - table = client.get_table(bq_table_uri) - schema = table.schema - - return { - field_name - for field in schema - for field_name in TabularDataset._get_bq_schema_field_names_recursively( - field - ) - } - @classmethod def create( cls, diff --git a/google/cloud/aiplatform/datasets/time_series_dataset.py b/google/cloud/aiplatform/datasets/time_series_dataset.py index 5bad36b896..aab96eda90 100644 --- a/google/cloud/aiplatform/datasets/time_series_dataset.py +++ b/google/cloud/aiplatform/datasets/time_series_dataset.py @@ -26,7 +26,7 @@ from google.cloud.aiplatform import utils -class TimeSeriesDataset(datasets._Dataset): +class TimeSeriesDataset(datasets._ColumnNamesDataset): """Managed time series dataset resource for Vertex AI""" _supported_metadata_schema_uris: Optional[Tuple[str]] = ( diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 8d8583f850..9436f19cfe 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -18,7 +18,6 @@ import datetime import time from typing import Dict, List, Optional, Sequence, Tuple, Union -import warnings import abc @@ -42,6 +41,7 @@ from google.cloud.aiplatform.utils import _timestamped_gcs_dir from google.cloud.aiplatform.utils import source_utils from google.cloud.aiplatform.utils import worker_spec_utils +from google.cloud.aiplatform.utils import column_transformations_utils from google.cloud.aiplatform.v1.schema.trainingjob import ( definition_v1 as training_job_inputs, @@ -2997,7 +2997,7 @@ def __init__( optimization_prediction_type: str, 
optimization_objective: Optional[str] = None, column_specs: Optional[Dict[str, str]] = None, - column_transformations: Optional[Union[Dict, List[Dict]]] = None, + column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, optimization_objective_recall_value: Optional[float] = None, optimization_objective_precision_value: Optional[float] = None, project: Optional[str] = None, @@ -3070,7 +3070,7 @@ def __init__( ignored by the training, except for the targetColumn, which should have no transformations defined on. Only one of column_transformations or column_specs should be passed. - column_transformations (Union[Dict, List[Dict]]): + column_transformations (List[Dict[str, Dict[str, str]]]): Optional. Transformations to apply to the input columns (i.e. columns other than the targetColumn). Each transformation may produce multiple result values from the column's value, and all are used for training. @@ -3136,8 +3136,8 @@ def __init__( Overrides encryption_spec_key_name set in aiplatform.init. - Raises: - ValueError: When both column_transforations and column_specs were passed + Raises: + ValueError: If both column_transformations and column_specs were provided. """ super().__init__( display_name=display_name, @@ -3148,26 +3148,11 @@ def __init__( training_encryption_spec_key_name=training_encryption_spec_key_name, model_encryption_spec_key_name=model_encryption_spec_key_name, ) - # user populated transformations - if column_transformations is not None and column_specs is not None: - raise ValueError( - "Both column_transformations and column_specs were passed. Only one is allowed." - ) - if column_transformations is not None: - self._column_transformations = column_transformations - warnings.simplefilter("always", DeprecationWarning) - warnings.warn( - "consider using column_specs instead. column_transformations will be deprecated in the future.", - DeprecationWarning, - stacklevel=2, - ) - elif column_specs is not None: - self._column_transformations = [ - {transformation: {"column_name": column_name}} - for column_name, transformation in column_specs.items() - ] - else: - self._column_transformations = None + + self._column_transformations = column_transformations_utils.validate_and_get_column_transformations( + column_specs, column_transformations + ) + self._optimization_objective = optimization_objective self._optimization_prediction_type = optimization_prediction_type self._optimization_objective_recall_value = optimization_objective_recall_value @@ -3523,14 +3508,12 @@ def _run( "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations." ) - column_names = [ - column_name - for column_name in dataset.column_names - if column_name != target_column - ] - self._column_transformations = [ - {"auto": {"column_name": column_name}} for column_name in column_names - ] + ( + self._column_transformations, + column_names, + ) = column_transformations_utils.get_default_column_transformations( + dataset=dataset, target_column=target_column + ) _LOGGER.info( "The column transformation of type 'auto' was set for the following columns: %s." 
@@ -3647,28 +3630,21 @@ class AutoMLForecastingTrainingJob(_TrainingJob): def __init__( self, display_name: str, - labels: Optional[Dict[str, str]] = None, optimization_objective: Optional[str] = None, - column_transformations: Optional[Union[Dict, List[Dict]]] = None, + column_specs: Optional[Dict[str, str]] = None, + column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None, project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, + labels: Optional[Dict[str, str]] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, ): """Constructs a AutoML Forecasting Training Job. Args: display_name (str): Required. The user-defined name of this TrainingPipeline. - labels (Dict[str, str]): - Optional. The labels with user-defined metadata to - organize TrainingPipelines. - Label keys and values can be no longer than 64 - characters (Unicode codepoints), can only - contain lowercase letters, numeric characters, - underscores and dashes. International characters - are allowed. - See https://goo.gl/xmQnxf for more information - and examples of labels. optimization_objective (str): Optional. Objective function the model is to be optimized towards. The training process creates a Model that optimizes the value of the objective @@ -3681,15 +3657,29 @@ def __init__( and mean-absolute-error (MAE). "minimize-quantile-loss" - Minimize the quantile loss at the defined quantiles. (Set this objective to build quantile forecasts.) - column_transformations (Optional[Union[Dict, List[Dict]]]): + column_specs (Dict[str, str]): + Optional. Alternative to column_transformations where the keys of the dict + are column names and their respective values are one of + AutoMLTabularTrainingJob.column_data_types. + When creating transformation for BigQuery Struct column, the column + should be flattened using "." as the delimiter. Only columns with no child + should have a transformation. + If an input column has no transformations on it, such a column is + ignored by the training, except for the targetColumn, which should have + no transformations defined on. + Only one of column_transformations or column_specs should be passed. + column_transformations (List[Dict[str, Dict[str, str]]]): Optional. Transformations to apply to the input columns (i.e. columns other than the targetColumn). Each transformation may produce multiple result values from the column's value, and all are used for training. When creating transformation for BigQuery Struct column, the column - should be flattened using "." as the delimiter. + should be flattened using "." as the delimiter. Only columns with no child + should have a transformation. If an input column has no transformations on it, such a column is ignored by the training, except for the targetColumn, which should have no transformations defined on. + Only one of column_transformations or column_specs should be passed. + Consider using column_specs as column_transformations will be deprecated eventually. project (str): Optional. Project to run training in. Overrides project set in aiplatform.init. location (str): @@ -3697,15 +3687,59 @@ def __init__( credentials (auth_credentials.Credentials): Optional. Custom credentials to use to run call training service. Overrides credentials set in aiplatform.init. + labels (Dict[str, str]): + Optional. The labels with user-defined metadata to + organize TrainingPipelines. 
+ Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + + Raises: + ValueError: If both column_transformations and column_specs were provided. """ super().__init__( display_name=display_name, - labels=labels, project=project, location=location, credentials=credentials, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, ) - self._column_transformations = column_transformations + + self._column_transformations = column_transformations_utils.validate_and_get_column_transformations( + column_specs, column_transformations + ) + self._optimization_objective = optimization_objective self._additional_experiments = [] @@ -3720,6 +3754,9 @@ def run( forecast_horizon: int, data_granularity_unit: str, data_granularity_count: int, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, predefined_split_column_name: Optional[str] = None, weight_column: Optional[str] = None, time_series_attribute_columns: Optional[List[str]] = None, @@ -3736,8 +3773,25 @@ def run( ) -> models.Model: """Runs the training job and returns a model. - The training data splits are set by default: Roughly 80% will be used for training, - 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Predefined splits: + Assigns input data to training, validation, and test sets based on the value of a provided key. + If using predefined splits, ``predefined_split_column_name`` must be provided. + Supported only for tabular Datasets. + + Timestamp splits: + Assigns input data to training, validation, and test sets + based on a provided timestamps. 
The youngest data pieces are + assigned to training set, next to validation set, and the oldest + to the test set. + Supported only for tabular Datasets. Args: dataset (datasets.Dataset): @@ -3896,6 +3950,9 @@ def run( forecast_horizon=forecast_horizon, data_granularity_unit=data_granularity_unit, data_granularity_count=data_granularity_count, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, predefined_split_column_name=predefined_split_column_name, weight_column=weight_column, time_series_attribute_columns=time_series_attribute_columns, @@ -4119,6 +4176,9 @@ def _run( forecast_horizon: int, data_granularity_unit: str, data_granularity_count: int, + training_fraction_split: Optional[float] = None, + validation_fraction_split: Optional[float] = None, + test_fraction_split: Optional[float] = None, predefined_split_column_name: Optional[str] = None, weight_column: Optional[str] = None, time_series_attribute_columns: Optional[List[str]] = None, @@ -4135,8 +4195,25 @@ def _run( ) -> models.Model: """Runs the training job and returns a model. - The training data splits are set by default: Roughly 80% will be used for training, - 10% for validation, and 10% for test. + If training on a Vertex AI dataset, you can use one of the following split configurations: + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by Vertex AI. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Predefined splits: + Assigns input data to training, validation, and test sets based on the value of a provided key. + If using predefined splits, ``predefined_split_column_name`` must be provided. + Supported only for tabular Datasets. + + Timestamp splits: + Assigns input data to training, validation, and test sets + based on a provided timestamps. The youngest data pieces are + assigned to training set, next to validation set, and the oldest + to the test set. + Supported only for tabular Datasets. Args: dataset (datasets.Dataset): @@ -4173,11 +4250,20 @@ def _run( Required. The number of data granularity units between data points in the training data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other values of [data_granularity_unit], must be 1. + training_fraction_split (float): + Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Optional. The fraction of the input data that is to be used to validate + the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. predefined_split_column_name (str): Optional. The key is a name of one of the Dataset's data columns. The value of the key (either the label's value or - value in the column) must be one of {``TRAIN``, - ``VALIDATE``, ``TEST``}, and it defines to which set the + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the given piece of data is assigned. 
If for a piece of data the
                 key is not present or has an invalid value, that piece is
                 ignored by the pipeline.
@@ -4270,6 +4356,24 @@ def _run(
 
         training_task_definition = schema.training_job.definition.automl_forecasting
 
+        # auto-populate transformations
+        if self._column_transformations is None:
+            _LOGGER.info(
+                "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
+            )
+
+            (
+                self._column_transformations,
+                column_names,
+            ) = column_transformations_utils.get_default_column_transformations(
+                dataset=dataset, target_column=target_column
+            )
+
+            _LOGGER.info(
+                "The column transformation of type 'auto' was set for the following columns: %s."
+                % column_names
+            )
+
         training_task_inputs_dict = {
             # required inputs
             "targetColumn": target_column,
@@ -4313,16 +4417,18 @@ def _run(
 
         model = gca_model.Model(
             display_name=model_display_name or self._display_name,
             labels=model_labels or self._labels,
+            encryption_spec=self._model_encryption_spec,
         )
 
         return self._run_job(
             training_task_definition=training_task_definition,
             training_task_inputs=training_task_inputs_dict,
             dataset=dataset,
-            training_fraction_split=None,
-            validation_fraction_split=None,
-            test_fraction_split=None,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
+            test_fraction_split=test_fraction_split,
             predefined_split_column_name=predefined_split_column_name,
+            timestamp_split_column_name=None,  # Not supported by AutoMLForecasting
             model=model,
         )
diff --git a/google/cloud/aiplatform/utils/column_transformations_utils.py b/google/cloud/aiplatform/utils/column_transformations_utils.py
new file mode 100644
index 0000000000..f0fc581b31
--- /dev/null
+++ b/google/cloud/aiplatform/utils/column_transformations_utils.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict, List, Optional, Tuple
+import warnings
+
+from google.cloud.aiplatform import datasets
+
+
+def get_default_column_transformations(
+    dataset: datasets._ColumnNamesDataset, target_column: str,
+) -> Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
+    """Get default column transformations from the column names, while omitting the target column.
+
+    Args:
+        dataset (_ColumnNamesDataset):
+            Required. The dataset
+        target_column (str):
+            Required. The name of the column values of which the Model is to predict.
+
+    Returns:
+        Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
+            The default column transformations and the default column names.
+ """ + + column_names = [ + column_name + for column_name in dataset.column_names + if column_name != target_column + ] + column_transformations = [ + {"auto": {"column_name": column_name}} for column_name in column_names + ] + + return (column_transformations, column_names) + + +def validate_and_get_column_transformations( + column_specs: Optional[Dict[str, str]], + column_transformations: Optional[List[Dict[str, Dict[str, str]]]], +) -> List[Dict[str, Dict[str, str]]]: + """Validates column specs and transformations, then returns processed transformations. + + Args: + column_specs (Dict[str, str]): + Optional. Alternative to column_transformations where the keys of the dict + are column names and their respective values are one of + AutoMLTabularTrainingJob.column_data_types. + When creating transformation for BigQuery Struct column, the column + should be flattened using "." as the delimiter. Only columns with no child + should have a transformation. + If an input column has no transformations on it, such a column is + ignored by the training, except for the targetColumn, which should have + no transformations defined on. + Only one of column_transformations or column_specs should be passed. + column_transformations (List[Dict[str, Dict[str, str]]]): + Optional. Transformations to apply to the input columns (i.e. columns other + than the targetColumn). Each transformation may produce multiple + result values from the column's value, and all are used for training. + When creating transformation for BigQuery Struct column, the column + should be flattened using "." as the delimiter. Only columns with no child + should have a transformation. + If an input column has no transformations on it, such a column is + ignored by the training, except for the targetColumn, which should have + no transformations defined on. + Only one of column_transformations or column_specs should be passed. + Consider using column_specs as column_transformations will be deprecated eventually. + + Returns: + List[Dict[str, Dict[str, str]]]: + The column transformations. + + Raises: + ValueError: If both column_transformations and column_specs were provided. + """ + # user populated transformations + if column_transformations is not None and column_specs is not None: + raise ValueError( + "Both column_transformations and column_specs were passed. Only one is allowed." + ) + if column_transformations is not None: + warnings.simplefilter("always", DeprecationWarning) + warnings.warn( + "consider using column_specs instead. 
column_transformations will be deprecated in the future.", + DeprecationWarning, + stacklevel=2, + ) + + return column_transformations + elif column_specs is not None: + return [ + {transformation: {"column_name": column_name}} + for column_name, transformation in column_specs.items() + ] + else: + return None diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py index 142301f98b..dc2e00b658 100644 --- a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -19,6 +19,7 @@ model as gca_model, pipeline_state as gca_pipeline_state, training_pipeline as gca_training_pipeline, + encryption_spec as gca_encryption_spec, ) from google.protobuf import json_format from google.protobuf import struct_pb2 @@ -115,6 +116,18 @@ "projects/my-project/locations/us-central1/trainingPipelines/12345" ) +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_FRACTION_SPLIT_TRAINING = 0.6 +_TEST_FRACTION_SPLIT_VALIDATION = 0.2 +_TEST_FRACTION_SPLIT_TEST = 0.2 + +_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split" + @pytest.fixture def mock_pipeline_service_create(): @@ -615,3 +628,240 @@ def test_raises_before_run_is_called(self, mock_pipeline_service_create): with pytest.raises(RuntimeError): job.state + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_fraction( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_time_series, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + model_from_job = job.run( + dataset=mock_dataset_time_series, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + 
quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_dataset_time_series.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_forecasting, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_predefined( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_time_series, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + model_from_job = job.run( + dataset=mock_dataset_time_series, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_split = gca_training_pipeline.PredefinedSplit( + key=_TEST_SPLIT_PREDEFINED_COLUMN_NAME + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = 
gca_training_pipeline.InputDataConfig( + predefined_split=true_split, dataset_id=mock_dataset_time_series.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_forecasting, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_splits_default( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_time_series, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = AutoMLForecastingTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + model_from_job = job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + dataset_id=mock_dataset_time_series.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_forecasting, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + )
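
Not part of the patch: the short sketches below are editorial illustrations of the refactored API, under the assumption of a google-cloud-aiplatform build that already contains this change; every project, key, and column name in them is a placeholder. This first one shows the leaf-field flattening that the new _ColumnNamesDataset helpers document for BigQuery STRUCT columns. It only builds SchemaField objects locally, so no API call is made.

    # Illustrative sketch: nested BigQuery schema fields are flattened with "."
    # and only leaf fields are returned, as the new helper's docstring states.
    from google.cloud import bigquery
    from google.cloud.aiplatform.datasets import column_names_dataset

    address = bigquery.SchemaField(
        "address",
        "RECORD",
        fields=(
            bigquery.SchemaField("city", "STRING"),
            bigquery.SchemaField("zip_code", "STRING"),
        ),
    )

    names = column_names_dataset._ColumnNamesDataset._get_bq_schema_field_names_recursively(
        address
    )
    print(names)  # {"address.city", "address.zip_code"} -- the parent "address" is omitted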
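
A second sketch covers the shared helper in column_transformations_utils, which both AutoMLTabularTrainingJob and AutoMLForecastingTrainingJob now call from __init__: a column_specs dict is converted into the stored transformation list, passing both arguments raises ValueError, and passing column_transformations alone still works but emits the DeprecationWarning defined above. The column names and data types here are invented for illustration.

    # Illustrative sketch of the new shared validation helper; names are made up.
    from google.cloud.aiplatform.utils import column_transformations_utils

    column_specs = {"sale_dollars": "numeric", "city": "categorical"}

    transformations = column_transformations_utils.validate_and_get_column_transformations(
        column_specs, None
    )
    # [{"numeric": {"column_name": "sale_dollars"}},
    #  {"categorical": {"column_name": "city"}}]

    # Passing both arguments is rejected, matching the check previously inlined
    # in AutoMLTabularTrainingJob.__init__.
    try:
        column_transformations_utils.validate_and_get_column_transformations(
            column_specs, transformations
        )
    except ValueError as err:
        print(err)  # Both column_transformations and column_specs were passed. ...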
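
Last, a sketch of how the widened AutoMLForecastingTrainingJob surface might be exercised end to end: column_specs, labels, and the CMEK key names on the constructor, plus the fraction-split arguments that run() now forwards to the training pipeline (the behaviour the new test_splits_* tests assert). Running it requires a real project and TimeSeriesDataset; every identifier below is a placeholder.

    # Illustrative sketch only; every resource name below is a placeholder.
    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    dataset = aiplatform.TimeSeriesDataset(
        "projects/my-project/locations/us-central1/datasets/1234567890"
    )

    job = aiplatform.AutoMLForecastingTrainingJob(
        display_name="sales-forecast",
        optimization_objective="minimize-rmse",
        # New on the forecasting constructor in this change:
        column_specs={"sale_dollars": "numeric", "city": "categorical"},
        labels={"team": "forecasting"},
        training_encryption_spec_key_name="projects/p/locations/l/keyRings/kr/cryptoKeys/k",
        model_encryption_spec_key_name="projects/p/locations/l/keyRings/kr/cryptoKeys/k",
    )

    model = job.run(
        dataset=dataset,
        target_column="sale_dollars",
        time_column="date",
        time_series_identifier_column="store_name",
        unavailable_at_forecast_columns=["sale_dollars"],
        available_at_forecast_columns=["date"],
        forecast_horizon=30,
        data_granularity_unit="day",
        data_granularity_count=1,
        # New in this change: explicit fraction splits are passed through to _run_job.
        training_fraction_split=0.8,
        validation_fraction_split=0.1,
        test_fraction_split=0.1,
    )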