feat: Added default AutoMLTabularTrainingJob column transformations #357

Merged
Changes from 16 commits
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/datasets/_datasources.py
@@ -86,9 +86,9 @@ def __init__(
raise ValueError("One of gcs_source or bq_source must be set.")

if gcs_source:
dataset_metadata = {"input_config": {"gcs_source": {"uri": gcs_source}}}
Contributor Author: These were wrong before

dataset_metadata = {"inputConfig": {"gcsSource": {"uri": gcs_source}}}
elif bq_source:
dataset_metadata = {"input_config": {"bigquery_source": {"uri": bq_source}}}
dataset_metadata = {"inputConfig": {"bigquerySource": {"uri": bq_source}}}

self._dataset_metadata = dataset_metadata

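For context on the fix above: the author's comment suggests the metadata keys must use the camelCase field names ("inputConfig", "gcsSource", "bigquerySource"), presumably because the dataset metadata is a JSON-style struct keyed by the API's field names. A minimal sketch of the struct the corrected code builds, with a hypothetical bucket path:

# Illustrative only: metadata built for a GCS source after the fix.
dataset_metadata = {
    "inputConfig": {
        # "uri" holds one or more GCS paths; this value is a placeholder.
        "gcsSource": {"uri": ["gs://example-bucket/path/to/data.csv"]}
    }
}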
155 changes: 155 additions & 0 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -15,16 +15,23 @@
# limitations under the License.
#

import csv
from typing import Optional, Sequence, Tuple, Union

from google.auth import credentials as auth_credentials

from google.cloud import bigquery
from google.cloud import storage

from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils

from typing import List
import logging
Member: Import should be up top with stdlib imports.

Contributor Author: Fixed



class TabularDataset(datasets._Dataset):
"""Managed tabular dataset resource for AI Platform"""
@@ -33,6 +40,154 @@ class TabularDataset(datasets._Dataset):
schema.dataset.metadata.tabular,
)

@property
def column_names(self) -> List[str]:
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
Google BigQuery source.

Returns:
List[str]
A list of column names

Raises:
RuntimeError: When no valid source is found.
"""

metadata = self._gca_resource.metadata

if metadata is None:
raise RuntimeError("No metadata found for dataset")

input_config = metadata.get("inputConfig")

if input_config is None:
raise RuntimeError("No inputConfig found for dataset")

gcs_source = input_config.get("gcsSource")
bq_source = input_config.get("bigquerySource")

if gcs_source:
gcs_source_uris = gcs_source.get("uri")

if gcs_source_uris and len(gcs_source_uris) > 0:
# Lexicographically sort the files
gcs_source_uris.sort()

# Get the first file in sorted list
return TabularDataset._retrieve_gcs_source_columns(
self.project, gcs_source_uris[0]
)
elif bq_source:
bq_table_uri = bq_source.get("uri")
if bq_table_uri:
return TabularDataset._retrieve_bq_source_columns(
self.project, bq_table_uri
)

raise RuntimeError("No valid CSV or BigQuery datasource found.")

@staticmethod
def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

Example Usage:

column_names = _retrieve_gcs_source_columns(
"project_id",
"gs://example-bucket/path/to/csv_file"
)

# column_names = ["column_1", "column_2"]

Args:
project (str):
Required. Project to initiate the Google Cloud Storage client with.
gcs_csv_file_path (str):
Required. A full path to a CSV file stored on Google Cloud Storage.
Must include "gs://" prefix.

Returns:
List[str]
A list of column names in the CSV file.

Raises:
RuntimeError: When the retrieved CSV file is invalid.
"""

gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
gcs_csv_file_path
)
client = storage.Client(project=project)
bucket = client.bucket(gcs_bucket)
blob = bucket.blob(gcs_blob)

# Incrementally download the CSV file until the header is retrieved
first_new_line_index = -1
start_index = 0
increment = 1000
line = ""

try:
logging.disable(logging.CRITICAL)
Member: Preference to be more precise by filtering the module logs we're trying to suppress. Like so:

logger = logging.getLogger("google.auth._default")
logging_warning_filter = utils.LoggingWarningFilter()
logger.addFilter(logging_warning_filter)
credentials, _ = google.auth.default()
logger.removeFilter(logging_warning_filter)

Contributor Author: Fixed
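The utils.LoggingWarningFilter referenced in that suggestion is not shown in this diff; as a rough sketch of the idea (an assumption, not the library's actual implementation), such a filter only needs to drop WARNING records:

import logging


class LoggingWarningFilter(logging.Filter):
    """Suppresses WARNING records emitted by a specific noisy logger."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False drops the record; all other levels pass through.
        return record.levelno != logging.WARNING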

while first_new_line_index == -1:
line += blob.download_as_bytes(
start=start_index, end=start_index + increment
).decode("utf-8")
first_new_line_index = line.find("\n")
start_index += increment

header_line = line[:first_new_line_index]

# Split to make it an iterable
header_line = header_line.split("\n")[:1]

csv_reader = csv.reader(header_line, delimiter=",")
except (ValueError, RuntimeError) as err:
raise RuntimeError(
"There was a problem extracting the headers from the CSV file at '{}': {}".format(
gcs_csv_file_path, err
)
)
finally:
logging.disable(logging.NOTSET)

return next(csv_reader)

@staticmethod
def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery

Example Usage:

column_names = _retrieve_bq_source_columns(
"project_id",
"bq://project_id.dataset.table"
)

# column_names = ["column_1", "column_2"]

Args:
project (str):
Required. Project to initiate the BigQuery client with.
bq_table_uri (str):
Required. A URI to a BigQuery table.
Can include "bq://" prefix but not required.

Returns:
List[str]
A list of column names in the BigQuery table.
"""

# Remove bq:// prefix
prefix = "bq://"
if bq_table_uri.startswith(prefix):
bq_table_uri = bq_table_uri[len(prefix) :]

client = bigquery.Client(project=project)
table = client.get_table(bq_table_uri)
fields = table.schema
return [field.name for field in fields]

@classmethod
def create(
cls,
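A minimal usage sketch of the new column_names property; the project, location, and dataset resource name below are placeholders:

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # hypothetical project

dataset = aiplatform.TabularDataset(
    # hypothetical dataset resource name
    "projects/my-project/locations/us-central1/datasets/1234567890123456789"
)

# Reads the header row of the lexicographically first CSV source file (or the
# BigQuery table schema) and returns the column names.
print(dataset.column_names)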
24 changes: 22 additions & 2 deletions google/cloud/aiplatform/training_jobs.py
@@ -130,7 +130,6 @@ def __init__(

super().__init__(project=project, location=location, credentials=credentials)
self._display_name = display_name
self._project = project
self._training_encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=training_encryption_spec_key_name
)
@@ -2918,10 +2917,31 @@ def _run(

training_task_definition = schema.training_job.definition.automl_tabular

if self._column_transformations is None:
Member: Please log here we are defaulting to auto for all columns as column_transformations was not provided.

Contributor Author: Makes sense, will add.

Contributor Author: Done

Contributor Author (ivanmkc, May 4, 2021):

INFO:google.cloud.aiplatform.training_jobs:No column transformations provided, so now retrieving columns from dataset in order to set default column transformations.
INFO:google.cloud.aiplatform.training_jobs:The column transformation of type 'auto' was set for the following columns': ['station_number', 'wban_number', 'year', 'month', 'day', 'num_mean_temp_samples', 'mean_dew_point', 'num_mean_dew_point_samples', 'mean_sealevel_pressure', 'num_mean_sealevel_pressure_samples', 'mean_station_pressure', 'num_mean_station_pressure_samples', 'mean_visibility', 'num_mean_visibility_samples', 'mean_wind_speed', 'num_mean_wind_speed_samples', 'max_sustained_wind_speed', 'max_gust_wind_speed', 'max_temperature', 'max_temperature_explicit', 'min_temperature', 'min_temperature_explicit', 'total_precipitation', 'snow_depth', 'fog', 'rain', 'snow', 'hail', 'thunder', 'tornado'].

Contributor Author: @sasha-gitg Does this look okay or is it too verbose?

Contributor Author: I thought it would be nice to show the names so the user can verify the columns.

Member: LGTM

_LOGGER.info(
"No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
)

column_names = [
column_name
for column_name in dataset.column_names
if column_name != target_column
]
column_transformations = [
{"auto": {"column_name": column_name}} for column_name in column_names
]

_LOGGER.info(
"The column transformation of type 'auto' was set for the following columns: %s."
% column_names
)
else:
column_transformations = self._column_transformations

training_task_inputs_dict = {
# required inputs
"targetColumn": target_column,
"transformations": self._column_transformations,
"transformations": column_transformations,
"trainBudgetMilliNodeHours": budget_milli_node_hours,
# optional inputs
"weightColumnName": weight_column,
2 changes: 1 addition & 1 deletion tests/system/aiplatform/test_dataset.py
@@ -25,8 +25,8 @@
from google.api_core import exceptions
from google.api_core import client_options

from google.cloud import storage
from google.cloud import aiplatform
from google.cloud import storage
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset
87 changes: 85 additions & 2 deletions tests/unit/aiplatform/test_automl_tabular_training_jobs.py
@@ -34,10 +34,16 @@
_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name"
_TEST_DATASET_NAME = "test-dataset-name"
_TEST_DISPLAY_NAME = "test-display-name"
_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image"
_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular
_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image

_TEST_TRAINING_COLUMN_NAMES = [
"sepal_width",
"sepal_length",
"petal_length",
"petal_width",
]

_TEST_TRAINING_COLUMN_TRANSFORMATIONS = [
{"auto": {"column_name": "sepal_width"}},
{"auto": {"column_name": "sepal_length"}},
@@ -169,7 +175,9 @@ def mock_dataset_tabular():
name=_TEST_DATASET_NAME,
metadata={},
)
return ds
ds.column_names = _TEST_TRAINING_COLUMN_NAMES

yield ds


@pytest.fixture
@@ -347,6 +355,81 @@ def test_run_call_pipeline_if_no_model_display_name(
training_pipeline=true_training_pipeline,
)

@pytest.mark.parametrize("sync", [True, False])
# This test checks that default transformations are used if no column transformations are provided
def test_run_call_pipeline_service_create_if_no_column_transformations(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_dataset_tabular,
mock_model_service_get,
sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_BUCKET_NAME,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

job = training_jobs.AutoMLTabularTrainingJob(
display_name=_TEST_DISPLAY_NAME,
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE,
column_transformations=None,
optimization_objective_recall_value=None,
optimization_objective_precision_value=None,
)

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
sync=sync,
)

if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
predefined_split=gca_training_pipeline.PredefinedSplit(
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
),
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
training_task_definition=schema.training_job.definition.automl_tabular,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

mock_pipeline_service_create.assert_called_once_with(
parent=initializer.global_config.common_location_path(),
training_pipeline=true_training_pipeline,
)

@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get",