Skip to content

Commit

Permalink
Switched from classmethod to staticmethod where applicable and logged…
Browse files Browse the repository at this point in the history
… column names
  • Loading branch information
ivanmkc committed May 4, 2021
1 parent 332e3e2 commit 9e66508
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
12 changes: 5 additions & 7 deletions google/cloud/aiplatform/datasets/tabular_dataset.py
Expand Up @@ -85,10 +85,8 @@ def column_names(self) -> List[str]:

raise RuntimeError("No valid CSV or BigQuery datasource found.")

@classmethod
def _retrieve_gcs_source_columns(
cls, project: str, gcs_csv_file_path: str
) -> List[str]:
@staticmethod
def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
Example Usage:
Expand Down Expand Up @@ -139,7 +137,7 @@ def _retrieve_gcs_source_columns(
header_line = line[:first_new_line_index]

# Split to make it an iterable
header_line = header_line.split("\n")
header_line = header_line.split("\n")[:1]

csv_reader = csv.reader(header_line, delimiter=",")
except:
Expand All @@ -149,8 +147,8 @@ def _retrieve_gcs_source_columns(

return next(csv_reader)

@classmethod
def _retrieve_bq_source_columns(cls, project: str, bq_table_uri: str) -> List[str]:
@staticmethod
def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
"""Retrieve the columns from a table on Google BigQuery
Example Usage:
Expand Down
16 changes: 14 additions & 2 deletions google/cloud/aiplatform/training_jobs.py
Expand Up @@ -2918,11 +2918,23 @@ def _run(
training_task_definition = schema.training_job.definition.automl_tabular

if self._column_transformations is None:
column_transformations = [
{"auto": {"column_name": column_name}}
_LOGGER.info(
"No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
)

column_names = [
column_name
for column_name in dataset.column_names
if column_name != target_column
]
column_transformations = [
{"auto": {"column_name": column_name}} for column_name in column_names
]

_LOGGER.info(
"The column transformation of type 'auto' was set for the following columns: %s."
% column_names
)
else:
column_transformations = self._column_transformations

Expand Down

0 comments on commit 9e66508

Please sign in to comment.