diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py index 687f3a51ab..058926f386 100644 --- a/google/cloud/aiplatform/datasets/tabular_dataset.py +++ b/google/cloud/aiplatform/datasets/tabular_dataset.py @@ -85,10 +85,8 @@ def column_names(self) -> List[str]: raise RuntimeError("No valid CSV or BigQuery datasource found.") - @classmethod - def _retrieve_gcs_source_columns( - cls, project: str, gcs_csv_file_path: str - ) -> List[str]: + @staticmethod + def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]: """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage Example Usage: @@ -139,7 +137,7 @@ def _retrieve_gcs_source_columns( header_line = line[:first_new_line_index] # Split to make it an iterable - header_line = header_line.split("\n") + header_line = header_line.split("\n")[:1] csv_reader = csv.reader(header_line, delimiter=",") except: @@ -149,8 +147,8 @@ def _retrieve_gcs_source_columns( return next(csv_reader) - @classmethod - def _retrieve_bq_source_columns(cls, project: str, bq_table_uri: str) -> List[str]: + @staticmethod + def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]: """Retrieve the columns from a table on Google BigQuery Example Usage: diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 7cf5341d02..a1c85127d7 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -2918,11 +2918,23 @@ def _run( training_task_definition = schema.training_job.definition.automl_tabular if self._column_transformations is None: - column_transformations = [ - {"auto": {"column_name": column_name}} + _LOGGER.info( + "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations." + ) + + column_names = [ + column_name for column_name in dataset.column_names if column_name != target_column ] + column_transformations = [ + {"auto": {"column_name": column_name}} for column_name in column_names + ] + + _LOGGER.info( + "The column transformation of type 'auto' was set for the following columns: %s." + % column_names + ) else: column_transformations = self._column_transformations