New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Added default AutoMLTabularTrainingJob column transformations #357
Changes from 11 commits
c2caaa6
29bcc70
5ce67e2
4b96837
b68e58c
6a0ac30
af0b990
ea5ef12
3300faa
783d2ea
ae5dfa1
332e3e2
9e66508
17e9f37
c4f9d6a
819cda8
c2ece02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,16 +15,22 @@ | |
# limitations under the License. | ||
# | ||
|
||
import csv | ||
from typing import Optional, Sequence, Tuple, Union | ||
|
||
from google.auth import credentials as auth_credentials | ||
|
||
from google.cloud import bigquery | ||
from google.cloud import storage | ||
|
||
from google.cloud.aiplatform import datasets | ||
from google.cloud.aiplatform.datasets import _datasources | ||
from google.cloud.aiplatform import initializer | ||
from google.cloud.aiplatform import schema | ||
from google.cloud.aiplatform import utils | ||
|
||
from typing import List | ||
ivanmkc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class TabularDataset(datasets._Dataset): | ||
"""Managed tabular dataset resource for AI Platform""" | ||
|
@@ -33,6 +39,151 @@ class TabularDataset(datasets._Dataset): | |
schema.dataset.metadata.tabular, | ||
) | ||
|
||
@property | ||
def column_names(self) -> List[str]: | ||
sasha-gitg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or | ||
Google BigQuery source. | ||
|
||
Returns: | ||
List[str] | ||
A list of columns names | ||
|
||
Raises: | ||
RuntimeError: When no valid source is found. | ||
""" | ||
|
||
metadata = self._gca_resource.metadata | ||
|
||
if metadata is None: | ||
raise RuntimeError("No metadata found for dataset") | ||
|
||
input_config = metadata.get("inputConfig") | ||
|
||
if input_config is None: | ||
raise RuntimeError("No inputConfig found for dataset") | ||
ivanmkc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
gcs_source = input_config.get("gcsSource") | ||
bq_source = input_config.get("bigquerySource") | ||
|
||
if gcs_source: | ||
gcs_source_uris = gcs_source.get("uri") | ||
|
||
if gcs_source_uris and len(gcs_source_uris) > 0: | ||
# Lexicographically sort the files | ||
gcs_source_uris.sort() | ||
|
||
# Get the first file in sorted list | ||
return TabularDataset._retrieve_gcs_source_columns( | ||
self.project, gcs_source_uris[0] | ||
) | ||
elif bq_source: | ||
bq_table_uri = bq_source.get("uri") | ||
if bq_table_uri: | ||
return TabularDataset._retrieve_bq_source_columns( | ||
self.project, bq_table_uri | ||
) | ||
|
||
raise RuntimeError("No valid CSV or BigQuery datasource found.") | ||
ivanmkc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@classmethod | ||
def _retrieve_gcs_source_columns( | ||
ivanmkc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cls, project: str, gcs_csv_file_path: str | ||
) -> List[str]: | ||
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage | ||
|
||
Example Usage: | ||
|
||
column_names = _retrieve_gcs_source_columns( | ||
"project_id", | ||
"gs://example-bucket/path/to/csv_file" | ||
) | ||
|
||
# column_names = ["column_1", "column_2"] | ||
|
||
Args: | ||
project (str): | ||
Required. Project to initiate the Google Cloud Storage client with. | ||
gcs_csv_file_path (str): | ||
Required. A full path to a CSV files stored on Google Cloud Storage. | ||
Must include "gs://" prefix. | ||
|
||
Returns: | ||
List[str] | ||
A list of columns names in the CSV file. | ||
|
||
Raises: | ||
RuntimeError: When the retrieved CSV file is invalid. | ||
""" | ||
|
||
gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path( | ||
gcs_csv_file_path | ||
) | ||
client = storage.Client(project=project) | ||
bucket = client.bucket(gcs_bucket) | ||
blob = bucket.blob(gcs_blob) | ||
|
||
# Incrementally download the CSV file until the header is retrieved | ||
first_new_line_index = -1 | ||
start_index = 0 | ||
increment = 1000 | ||
line = "" | ||
|
||
try: | ||
while first_new_line_index == -1: | ||
line += blob.download_as_bytes( | ||
start=start_index, end=start_index + increment | ||
).decode("utf-8") | ||
first_new_line_index = line.find("\n") | ||
start_index += increment | ||
|
||
header_line = line[:first_new_line_index] | ||
|
||
# Split to make it an iterable | ||
header_line = header_line.split("\n") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It may be safer to only include the first line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, will do. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
csv_reader = csv.reader(header_line, delimiter=",") | ||
sasha-gitg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
except: | ||
raise RuntimeError( | ||
f"There was a problem extracting the headers from the CSV file at: { gcs_csv_file_path }" | ||
) | ||
|
||
return next(csv_reader) | ||
|
||
@classmethod | ||
def _retrieve_bq_source_columns(cls, project: str, bq_table_uri: str) -> List[str]: | ||
"""Retrieve the columns from a table on Google BigQuery | ||
|
||
Example Usage: | ||
|
||
column_names = _retrieve_bq_source_columns( | ||
"project_id", | ||
"bq://project_id.dataset.table" | ||
) | ||
|
||
# column_names = ["column_1", "column_2"] | ||
|
||
Args: | ||
project (str): | ||
Required. Project to initiate the BigQuery client with. | ||
bq_table_uri (str): | ||
Required. A URI to a BigQuery table. | ||
Can include "bq://" prefix but not required. | ||
|
||
Returns: | ||
List[str] | ||
A list of columns names in the BigQuery table. | ||
""" | ||
|
||
# Remove bq:// prefix | ||
prefix = "bq://" | ||
if bq_table_uri.startswith(prefix): | ||
bq_table_uri = bq_table_uri[len(prefix) :] | ||
|
||
client = bigquery.Client(project=project) | ||
table = client.get_table(bq_table_uri) | ||
schema = table.schema | ||
return [schema.name for schema in schema] | ||
|
||
@classmethod | ||
def create( | ||
cls, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,7 +130,6 @@ def __init__( | |
|
||
super().__init__(project=project, location=location, credentials=credentials) | ||
self._display_name = display_name | ||
self._project = project | ||
self._training_encryption_spec = initializer.global_config.get_encryption_spec( | ||
encryption_spec_key_name=training_encryption_spec_key_name | ||
) | ||
|
@@ -2918,10 +2917,18 @@ def _run( | |
|
||
training_task_definition = schema.training_job.definition.automl_tabular | ||
|
||
if self._column_transformations is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please log here we are defaulting to auto for all columns as column_transformations was not provided. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, will add. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. INFO:google.cloud.aiplatform.training_jobs:No column transformations provided, so now retrieving columns from dataset in order to set default column transformations. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sasha-gitg Does this look okay or is it too verbose? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought it would be nice to show the names so the user can verify the columns. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM |
||
column_transformations = [ | ||
{"auto": {"column_name": column_name}} | ||
for column_name in dataset.column_names if column_name != target_column | ||
] | ||
else: | ||
column_transformations = self._column_transformations | ||
|
||
training_task_inputs_dict = { | ||
# required inputs | ||
"targetColumn": target_column, | ||
"transformations": self._column_transformations, | ||
"transformations": column_transformations, | ||
"trainBudgetMilliNodeHours": budget_milli_node_hours, | ||
# optional inputs | ||
"weightColumnName": weight_column, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These were wrong before