feat: Added default AutoMLTabularTrainingJob column transformations (#357)

* Added default column_transformation code

* Added docstrings

* Added tests and moved code to tabular_dataset

* Switched to using BigQuery.Table instead of custom SQL query

* Fixed bigquery unit test

* Added GCS test

* Fixed issues with incorrect input config parameter

* Added test for AutoMLTabularTrainingJob for no transformations

* Added comment

* Fixed test

* Ran linter

* Switched from classmethod to staticmethod where applicable and logged column names

* Added extra dataset tests

* Added logging suppression

* Fixed lint errors

* Switched logging filter method

Co-authored-by: sasha-gitg <44654632+sasha-gitg@users.noreply.github.com>
ivanmkc and sasha-gitg committed May 5, 2021
1 parent beb4032 commit 4fce8c4
Showing 8 changed files with 413 additions and 23 deletions.
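
As a quick orientation before the file-by-file diff: with this change, AutoMLTabularTrainingJob.run() no longer requires an explicit `column_transformations` list; when it is omitted (or `None`), the job reads the dataset's column names and applies an "auto" transformation to every non-target column. A hedged usage sketch, mirroring the constructor and run() arguments exercised in the unit tests below (project, bucket, dataset resource name, and target column are hypothetical):

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", staging_bucket="gs://my-staging-bucket")

# Hypothetical dataset resource name; any TabularDataset with a GCS or BigQuery source works.
dataset = aiplatform.TabularDataset(
    "projects/my-project/locations/us-central1/datasets/1234567890"
)

job = aiplatform.AutoMLTabularTrainingJob(
    display_name="iris-automl",
    optimization_prediction_type="classification",
    column_transformations=None,  # new: defaults to {"auto": ...} for every non-target column
)

model = job.run(
    dataset=dataset,
    target_column="species",  # hypothetical target column
    budget_milli_node_hours=1000,
)
```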
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/datasets/_datasources.py
@@ -86,9 +86,9 @@ def __init__(
raise ValueError("One of gcs_source or bq_source must be set.")

if gcs_source:
dataset_metadata = {"input_config": {"gcs_source": {"uri": gcs_source}}}
dataset_metadata = {"inputConfig": {"gcsSource": {"uri": gcs_source}}}
elif bq_source:
dataset_metadata = {"input_config": {"bigquery_source": {"uri": bq_source}}}
dataset_metadata = {"inputConfig": {"bigquerySource": {"uri": bq_source}}}

self._dataset_metadata = dataset_metadata

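The only change in this file is the switch from snake_case to camelCase metadata keys; those camelCase keys are exactly what the new `TabularDataset.column_names` property (next file) reads back. A minimal sketch of the two resulting metadata shapes, with hypothetical URIs:

```python
# Metadata written for a GCS-backed dataset (list-valued "uri", since column_names sorts the files).
gcs_metadata = {"inputConfig": {"gcsSource": {"uri": ["gs://example-bucket/data.csv"]}}}

# Metadata written for a BigQuery-backed dataset.
bq_metadata = {"inputConfig": {"bigquerySource": {"uri": "bq://example-project.dataset.table"}}}

# column_names reads the same keys back:
input_config = gcs_metadata.get("inputConfig")
gcs_uris = input_config.get("gcsSource", {}).get("uri")  # -> ["gs://example-bucket/data.csv"]
```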
159 changes: 158 additions & 1 deletion google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -15,10 +15,16 @@
# limitations under the License.
#

from typing import Optional, Sequence, Tuple, Union
import csv
import logging

from typing import List, Optional, Sequence, Tuple, Union

from google.auth import credentials as auth_credentials

from google.cloud import bigquery
from google.cloud import storage

from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
@@ -33,6 +39,157 @@ class TabularDataset(datasets._Dataset):
schema.dataset.metadata.tabular,
)

@property
def column_names(self) -> List[str]:
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
Google BigQuery source.
Returns:
List[str]
A list of column names.
Raises:
RuntimeError: When no valid source is found.
"""

metadata = self._gca_resource.metadata

if metadata is None:
raise RuntimeError("No metadata found for dataset")

input_config = metadata.get("inputConfig")

if input_config is None:
raise RuntimeError("No inputConfig found for dataset")

gcs_source = input_config.get("gcsSource")
bq_source = input_config.get("bigquerySource")

if gcs_source:
gcs_source_uris = gcs_source.get("uri")

if gcs_source_uris and len(gcs_source_uris) > 0:
# Lexicographically sort the files
gcs_source_uris.sort()

# Get the first file in sorted list
return TabularDataset._retrieve_gcs_source_columns(
self.project, gcs_source_uris[0]
)
elif bq_source:
bq_table_uri = bq_source.get("uri")
if bq_table_uri:
return TabularDataset._retrieve_bq_source_columns(
self.project, bq_table_uri
)

raise RuntimeError("No valid CSV or BigQuery datasource found.")

@staticmethod
def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
Example Usage:
column_names = _retrieve_gcs_source_columns(
"project_id",
"gs://example-bucket/path/to/csv_file"
)
# column_names = ["column_1", "column_2"]
Args:
project (str):
Required. Project to initiate the Google Cloud Storage client with.
gcs_csv_file_path (str):
Required. A full path to a CSV file stored on Google Cloud Storage.
Must include "gs://" prefix.
Returns:
List[str]
A list of column names in the CSV file.
Raises:
RuntimeError: When the retrieved CSV file is invalid.
"""

gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
gcs_csv_file_path
)
client = storage.Client(project=project)
bucket = client.bucket(gcs_bucket)
blob = bucket.blob(gcs_blob)

# Incrementally download the CSV file until the header is retrieved
first_new_line_index = -1
start_index = 0
increment = 1000
line = ""

try:
logger = logging.getLogger("google.resumable_media._helpers")
logging_warning_filter = utils.LoggingFilter(logging.INFO)
logger.addFilter(logging_warning_filter)

while first_new_line_index == -1:
line += blob.download_as_bytes(
start=start_index, end=start_index + increment
).decode("utf-8")
first_new_line_index = line.find("\n")
start_index += increment

header_line = line[:first_new_line_index]

# Split into a one-element list so csv.reader treats it as a single row
header_line = header_line.split("\n")[:1]

csv_reader = csv.reader(header_line, delimiter=",")
except (ValueError, RuntimeError) as err:
raise RuntimeError(
"There was a problem extracting the headers from the CSV file at '{}': {}".format(
gcs_csv_file_path, err
)
)
finally:
logger.removeFilter(logging_warning_filter)

return next(csv_reader)

@staticmethod
def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery
Example Usage:
column_names = _retrieve_bq_source_columns(
"project_id",
"bq://project_id.dataset.table"
)
# column_names = ["column_1", "column_2"]
Args:
project (str):
Required. Project to initiate the BigQuery client with.
bq_table_uri (str):
Required. A URI to a BigQuery table.
Can include the "bq://" prefix, but it is not required.
Returns:
List[str]
A list of column names in the BigQuery table.
"""

# Remove bq:// prefix
prefix = "bq://"
if bq_table_uri.startswith(prefix):
bq_table_uri = bq_table_uri[len(prefix) :]

client = bigquery.Client(project=project)
table = client.get_table(bq_table_uri)
schema = table.schema
return [field.name for field in schema]

@classmethod
def create(
cls,
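A short usage sketch of the new property (the resource name and the printed columns are hypothetical): for a GCS-backed dataset it returns the header row of the lexicographically first CSV file, and for a BigQuery-backed dataset the field names from the table schema.

```python
from google.cloud import aiplatform

# Hypothetical resource name; the dataset must have a gcsSource or bigquerySource in its metadata.
ds = aiplatform.TabularDataset(
    "projects/my-project/locations/us-central1/datasets/1234567890"
)

print(ds.column_names)
# e.g. ["sepal_width", "sepal_length", "petal_length", "petal_width", "species"]
```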
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/initializer.py
@@ -170,7 +170,7 @@ def credentials(self) -> Optional[auth_credentials.Credentials]:
if self._credentials:
return self._credentials
logger = logging.getLogger("google.auth._default")
logging_warning_filter = utils.LoggingWarningFilter()
logging_warning_filter = utils.LoggingFilter(logging.WARNING)
logger.addFilter(logging_warning_filter)
credentials, _ = google.auth.default()
logger.removeFilter(logging_warning_filter)
24 changes: 22 additions & 2 deletions google/cloud/aiplatform/training_jobs.py
@@ -130,7 +130,6 @@ def __init__(

super().__init__(project=project, location=location, credentials=credentials)
self._display_name = display_name
self._project = project
self._training_encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=training_encryption_spec_key_name
)
@@ -2955,10 +2954,31 @@ def _run(

training_task_definition = schema.training_job.definition.automl_tabular

if self._column_transformations is None:
_LOGGER.info(
"No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
)

column_names = [
column_name
for column_name in dataset.column_names
if column_name != target_column
]
column_transformations = [
{"auto": {"column_name": column_name}} for column_name in column_names
]

_LOGGER.info(
"The column transformation of type 'auto' was set for the following columns: %s."
% column_names
)
else:
column_transformations = self._column_transformations

training_task_inputs_dict = {
# required inputs
"targetColumn": target_column,
"transformations": self._column_transformations,
"transformations": column_transformations,
"trainBudgetMilliNodeHours": budget_milli_node_hours,
# optional inputs
"weightColumnName": weight_column,
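To make the default concrete, here is the same derivation in isolation, using the iris-style column names defined in the test file further down (the target column name is assumed for illustration):

```python
# Columns reported by dataset.column_names ("species" is an assumed target column).
column_names = ["sepal_width", "sepal_length", "petal_length", "petal_width", "species"]
target_column = "species"

# The target is excluded; every remaining column gets an "auto" transformation.
column_transformations = [
    {"auto": {"column_name": name}} for name in column_names if name != target_column
]
# -> [{"auto": {"column_name": "sepal_width"}},
#     {"auto": {"column_name": "sepal_length"}},
#     {"auto": {"column_name": "petal_length"}},
#     {"auto": {"column_name": "petal_width"}}]
```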
7 changes: 5 additions & 2 deletions google/cloud/aiplatform/utils.py
@@ -468,6 +468,9 @@ class PredictionClientWithOverride(ClientWithOverride):
)


class LoggingWarningFilter(logging.Filter):
class LoggingFilter(logging.Filter):
def __init__(self, warning_level: int):
self._warning_level = warning_level

def filter(self, record):
return record.levelname == logging.WARNING
return record.levelname == self._warning_level
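
The renamed filter is attached and detached around a single noisy call elsewhere in this commit (`initializer.credentials` and `_retrieve_gcs_source_columns`); a minimal sketch of that pattern, with the wrapped call left as a placeholder:

```python
import logging

from google.cloud.aiplatform import utils

logger = logging.getLogger("google.auth._default")
logging_filter = utils.LoggingFilter(logging.WARNING)
logger.addFilter(logging_filter)
try:
    pass  # placeholder for the call whose log records should be filtered
finally:
    logger.removeFilter(logging_filter)
```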
2 changes: 1 addition & 1 deletion tests/system/aiplatform/test_dataset.py
@@ -25,8 +25,8 @@
from google.api_core import exceptions
from google.api_core import client_options

from google.cloud import storage
from google.cloud import aiplatform
from google.cloud import storage
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset
87 changes: 85 additions & 2 deletions tests/unit/aiplatform/test_automl_tabular_training_jobs.py
@@ -34,10 +34,16 @@
_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name"
_TEST_DATASET_NAME = "test-dataset-name"
_TEST_DISPLAY_NAME = "test-display-name"
_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image"
_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular
_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image

_TEST_TRAINING_COLUMN_NAMES = [
"sepal_width",
"sepal_length",
"petal_length",
"petal_width",
]

_TEST_TRAINING_COLUMN_TRANSFORMATIONS = [
{"auto": {"column_name": "sepal_width"}},
{"auto": {"column_name": "sepal_length"}},
@@ -169,7 +175,9 @@ def mock_dataset_tabular():
name=_TEST_DATASET_NAME,
metadata={},
)
return ds
ds.column_names = _TEST_TRAINING_COLUMN_NAMES

yield ds


@pytest.fixture
@@ -347,6 +355,81 @@ def test_run_call_pipeline_if_no_model_display_name(
training_pipeline=true_training_pipeline,
)

@pytest.mark.parametrize("sync", [True, False])
# This test checks that default transformations are used if no column transformations are provided
def test_run_call_pipeline_service_create_if_no_column_transformations(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_BUCKET_NAME,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

job = training_jobs.AutoMLTabularTrainingJob(
display_name=_TEST_DISPLAY_NAME,
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE,
column_transformations=None,
optimization_objective_recall_value=None,
optimization_objective_precision_value=None,
)

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
sync=sync,
)

if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
predefined_split=gca_training_pipeline.PredefinedSplit(
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
),
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
training_task_definition=schema.training_job.definition.automl_tabular,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

mock_pipeline_service_create.assert_called_once_with(
parent=initializer.global_config.common_location_path(),
training_pipeline=true_training_pipeline,
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get",