feat: Added default AutoMLTabularTrainingJob column transformations #357

Merged

Changes from 9 commits
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/datasets/_datasources.py
@@ -86,9 +86,9 @@ def __init__(
        raise ValueError("One of gcs_source or bq_source must be set.")

        if gcs_source:
            dataset_metadata = {"input_config": {"gcs_source": {"uri": gcs_source}}}

Contributor Author: These were wrong before

            dataset_metadata = {"inputConfig": {"gcsSource": {"uri": gcs_source}}}
        elif bq_source:
            dataset_metadata = {"input_config": {"bigquery_source": {"uri": bq_source}}}
            dataset_metadata = {"inputConfig": {"bigquerySource": {"uri": bq_source}}}

        self._dataset_metadata = dataset_metadata

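For context, a minimal sketch of how this metadata path is exercised from the public API; the display name, bucket, and file name are placeholders. Creating a tabular dataset from a CSV source builds the camelCase inputConfig shown above:

from google.cloud import aiplatform

# Illustrative only: dataset creation from a GCS CSV source; internally the SDK
# assembles {"inputConfig": {"gcsSource": {"uri": [...]}}} as the dataset metadata.
dataset = aiplatform.TabularDataset.create(
    display_name="my-tabular-dataset",
    gcs_source=["gs://my-bucket/data.csv"],
)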
151 changes: 151 additions & 0 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -15,16 +15,22 @@
# limitations under the License.
#

import csv
from typing import Optional, Sequence, Tuple, Union

from google.auth import credentials as auth_credentials

from google.cloud import bigquery
from google.cloud import storage

from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils

from typing import List
ivanmkc marked this conversation as resolved.


class TabularDataset(datasets._Dataset):
"""Managed tabular dataset resource for AI Platform"""
Expand All @@ -33,6 +39,151 @@ class TabularDataset(datasets._Dataset):
schema.dataset.metadata.tabular,
)

    @property
    def column_names(self) -> List[str]:
sasha-gitg marked this conversation as resolved.
        """Retrieve the columns for the dataset by extracting them from the Google Cloud Storage or
        Google BigQuery source.

        Returns:
            List[str]
                A list of column names.

        Raises:
            RuntimeError: When no valid source is found.
        """

        metadata = self._gca_resource.metadata

        if metadata is None:
            raise RuntimeError("No metadata found for dataset")

        input_config = metadata.get("inputConfig")

        if input_config is None:
            raise RuntimeError("No inputConfig found for dataset")
ivanmkc marked this conversation as resolved.

        gcs_source = input_config.get("gcsSource")
        bq_source = input_config.get("bigquerySource")

        if gcs_source:
            gcs_source_uris = gcs_source.get("uri")

            if gcs_source_uris and len(gcs_source_uris) > 0:
                # Lexicographically sort the files
                gcs_source_uris.sort()

                # Get the first file in sorted list
                return TabularDataset._retrieve_gcs_source_columns(
                    self.project, gcs_source_uris[0]
                )
        elif bq_source:
            bq_table_uri = bq_source.get("uri")
            if bq_table_uri:
                return TabularDataset._retrieve_bq_source_columns(
                    self.project, bq_table_uri
                )

        raise RuntimeError("No valid CSV or BigQuery datasource found.")
ivanmkc marked this conversation as resolved.

    @classmethod
    def _retrieve_gcs_source_columns(
ivanmkc marked this conversation as resolved.
        cls, project: str, gcs_csv_file_path: str
    ) -> List[str]:
        """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage

        Example Usage:

            column_names = _retrieve_gcs_source_columns(
                "project_id",
                "gs://example-bucket/path/to/csv_file"
            )

            # column_names = ["column_1", "column_2"]

        Args:
            project (str):
                Required. Project to initiate the Google Cloud Storage client with.
            gcs_csv_file_path (str):
                Required. A full path to a CSV file stored on Google Cloud Storage.
                Must include the "gs://" prefix.

        Returns:
            List[str]
                A list of column names in the CSV file.

        Raises:
            RuntimeError: When the retrieved CSV file is invalid.
        """

        gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
            gcs_csv_file_path
        )
        client = storage.Client(project=project)
        bucket = client.bucket(gcs_bucket)
        blob = bucket.blob(gcs_blob)

        # Incrementally download the CSV file until the header is retrieved
        first_new_line_index = -1
        start_index = 0
        increment = 1000
        line = ""

        try:
            while first_new_line_index == -1:
                line += blob.download_as_bytes(
                    start=start_index, end=start_index + increment
                ).decode("utf-8")
                first_new_line_index = line.find("\n")
                start_index += increment

            header_line = line[:first_new_line_index]

            # Split to make it an iterable
            header_line = header_line.split("\n")
Member: It may be safer to only include the first line, header_line.split("\n")[:1], to avoid possible parsing errors downstream.

Contributor Author: Sure, will do.

Contributor Author: Done


            csv_reader = csv.reader(header_line, delimiter=",")
sasha-gitg marked this conversation as resolved.
        except:
            raise RuntimeError(
                f"There was a problem extracting the headers from the CSV file at: { gcs_csv_file_path }"
            )

        return next(csv_reader)
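A minimal sketch of the hardening suggested in the thread above (illustrative only; per the author's "Done" it was applied in a later commit not shown in this view):

            # Keep only the first physical line so the CSV reader never parses
            # anything beyond the header row.
            header_line = line[:first_new_line_index].split("\n")[:1]
            csv_reader = csv.reader(header_line, delimiter=",")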

    @classmethod
    def _retrieve_bq_source_columns(cls, project: str, bq_table_uri: str) -> List[str]:
        """Retrieve the columns from a table on Google BigQuery

        Example Usage:

            column_names = _retrieve_bq_source_columns(
                "project_id",
                "bq://project_id.dataset.table"
            )

            # column_names = ["column_1", "column_2"]

        Args:
            project (str):
                Required. Project to initiate the BigQuery client with.
            bq_table_uri (str):
                Required. A URI to a BigQuery table.
                Can include the "bq://" prefix, but it is not required.

        Returns:
            List[str]
                A list of column names in the BigQuery table.
        """

        # Remove bq:// prefix
        prefix = "bq://"
        if bq_table_uri.startswith(prefix):
            bq_table_uri = bq_table_uri[len(prefix) :]

        client = bigquery.Client(project=project)
        table = client.get_table(bq_table_uri)
        schema = table.schema
        return [field.name for field in schema]

    @classmethod
    def create(
        cls,
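For illustration, a sketch of how the new column_names property might be used; the project and dataset ID are placeholders:

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# Load an existing tabular dataset and read its column names, which are
# extracted from the first CSV file (or the BigQuery table) backing it.
dataset = aiplatform.TabularDataset(
    "projects/my-project/locations/us-central1/datasets/1234567890"
)
print(dataset.column_names)  # e.g. ["sepal_width", "sepal_length", ...]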
11 changes: 9 additions & 2 deletions google/cloud/aiplatform/training_jobs.py
@@ -130,7 +130,6 @@ def __init__(

        super().__init__(project=project, location=location, credentials=credentials)
        self._display_name = display_name
        self._project = project
        self._training_encryption_spec = initializer.global_config.get_encryption_spec(
            encryption_spec_key_name=training_encryption_spec_key_name
        )
@@ -2918,10 +2917,18 @@ def _run(

        training_task_definition = schema.training_job.definition.automl_tabular

        if self._column_transformations is None:
Member: Please log here we are defaulting to auto for all columns, as column_transformations was not provided.

Contributor Author: Makes sense, will add.

Contributor Author: Done

Contributor Author (@ivanmkc, May 4, 2021):

INFO:google.cloud.aiplatform.training_jobs:No column transformations provided, so now retrieving columns from dataset in order to set default column transformations.
INFO:google.cloud.aiplatform.training_jobs:The column transformation of type 'auto' was set for the following columns': ['station_number', 'wban_number', 'year', 'month', 'day', 'num_mean_temp_samples', 'mean_dew_point', 'num_mean_dew_point_samples', 'mean_sealevel_pressure', 'num_mean_sealevel_pressure_samples', 'mean_station_pressure', 'num_mean_station_pressure_samples', 'mean_visibility', 'num_mean_visibility_samples', 'mean_wind_speed', 'num_mean_wind_speed_samples', 'max_sustained_wind_speed', 'max_gust_wind_speed', 'max_temperature', 'max_temperature_explicit', 'min_temperature', 'min_temperature_explicit', 'total_precipitation', 'snow_depth', 'fog', 'rain', 'snow', 'hail', 'thunder', 'tornado'].

Contributor Author: @sasha-gitg Does this look okay or is it too verbose?

Contributor Author: I thought it would be nice to show the names so the user can verify the columns.

Member: LGTM

            column_transformations = [
                {"auto": {"column_name": column_name}}
                for column_name in dataset.column_names
ivanmkc marked this conversation as resolved.
            ]
        else:
            column_transformations = self._column_transformations

        training_task_inputs_dict = {
            # required inputs
            "targetColumn": target_column,
            "transformations": self._column_transformations,
            "transformations": column_transformations,
            "trainBudgetMilliNodeHours": budget_milli_node_hours,
            # optional inputs
            "weightColumnName": weight_column,
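A minimal, self-contained sketch of the defaulting-plus-logging behavior discussed in the thread above; the helper name _default_column_transformations and the logger setup are illustrative, not the SDK's actual internals:

import logging

_LOGGER = logging.getLogger("google.cloud.aiplatform.training_jobs")


def _default_column_transformations(dataset, column_transformations=None):
    # When the caller supplies no transformations, default every dataset
    # column to an "auto" transformation and log what was chosen.
    if column_transformations is None:
        _LOGGER.info(
            "No column transformations provided, so now retrieving columns from "
            "dataset in order to set default column transformations."
        )
        column_names = list(dataset.column_names)
        column_transformations = [
            {"auto": {"column_name": name}} for name in column_names
        ]
        _LOGGER.info(
            "The column transformation of type 'auto' was set for the following "
            "columns: %s.",
            column_names,
        )
    return column_transformations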
95 changes: 93 additions & 2 deletions tests/unit/aiplatform/test_automl_tabular_training_jobs.py
@@ -34,10 +34,16 @@
_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name"
_TEST_DATASET_NAME = "test-dataset-name"
_TEST_DISPLAY_NAME = "test-display-name"
_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image"
_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular
_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image

_TEST_TRAINING_COLUMN_NAMES = [
    "sepal_width",
    "sepal_length",
    "petal_length",
    "petal_width",
]

_TEST_TRAINING_COLUMN_TRANSFORMATIONS = [
    {"auto": {"column_name": "sepal_width"}},
    {"auto": {"column_name": "sepal_length"}},
@@ -169,7 +175,17 @@ def mock_dataset_tabular():
        name=_TEST_DATASET_NAME,
        metadata={},
    )
    return ds

    yield ds


@pytest.fixture
def mock_dataset_tabular_column_names(mock_dataset_tabular):
    with mock.patch.object(
        mock_dataset_tabular, "column_names", new_callable=mock.PropertyMock
    ) as mock_dataset_tabular_column_names:
        mock_dataset_tabular_column_names.return_value = _TEST_TRAINING_COLUMN_NAMES
        yield mock_dataset_tabular_column_names


@pytest.fixture
@@ -347,6 +363,81 @@ def test_run_call_pipeline_if_no_model_display_name(
            training_pipeline=true_training_pipeline,
        )

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_if_no_column_transformations(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_dataset_tabular,
mock_dataset_tabular_column_names,
mock_model_service_get,
sync,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_BUCKET_NAME,
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
)

job = training_jobs.AutoMLTabularTrainingJob(
display_name=_TEST_DISPLAY_NAME,
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE,
column_transformations=None,
optimization_objective_recall_value=None,
optimization_objective_precision_value=None,
)

model_from_job = job.run(
dataset=mock_dataset_tabular,
target_column=_TEST_TRAINING_TARGET_COLUMN,
model_display_name=_TEST_MODEL_DISPLAY_NAME,
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
sync=sync,
)

if not sync:
model_from_job.wait()

true_fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
test_fraction=_TEST_TEST_FRACTION_SPLIT,
)

true_managed_model = gca_model.Model(
display_name=_TEST_MODEL_DISPLAY_NAME,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=true_fraction_split,
predefined_split=gca_training_pipeline.PredefinedSplit(
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME
),
dataset_id=mock_dataset_tabular.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
training_task_definition=schema.training_job.definition.automl_tabular,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
)

mock_pipeline_service_create.assert_called_once_with(
parent=initializer.global_config.common_location_path(),
training_pipeline=true_training_pipeline,
)

    @pytest.mark.usefixtures(
        "mock_pipeline_service_create",
        "mock_pipeline_service_get",
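To round out the change, a sketch of the user-facing behavior this PR enables — running an AutoML tabular training job without specifying column_transformations; the project, dataset ID, and target column are placeholders:

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

dataset = aiplatform.TabularDataset(
    "projects/my-project/locations/us-central1/datasets/1234567890"
)

# column_transformations is omitted, so every column in the dataset
# defaults to an "auto" transformation.
job = aiplatform.AutoMLTabularTrainingJob(
    display_name="my-automl-tabular-job",
    optimization_prediction_type="regression",
)

model = job.run(
    dataset=dataset,
    target_column="mean_temp",
)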