-
Notifications
You must be signed in to change notification settings - Fork 316
/
column_transformations_utils.py
107 lines (93 loc) · 4.39 KB
/
column_transformations_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset
from typing import Dict, List, Optional, Tuple
import warnings
def get_default_column_transformations(
    dataset: _ColumnNamesDataset, target_column: str,
) -> Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
    """Build an "auto" transformation for every dataset column except the target.

    Args:
        dataset (_ColumnNamesDataset):
            Required. The dataset whose ``column_names`` are read.
        target_column (str):
            Required. The name of the column values of which the Model is to predict.
            Any column whose name equals this value is left out of both results.

    Returns:
        Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
            A pair ``(column_transformations, column_names)``. The first item
            holds one ``{"auto": {"column_name": <name>}}`` entry per kept
            column; the second holds the kept column names in the same order.
    """
    kept_names: List[str] = []
    auto_transformations: List[Dict[str, Dict[str, str]]] = []

    # Single pass: both outputs stay index-aligned with each other and follow
    # the order in which the dataset reports its columns.
    for name in dataset.column_names:
        if name != target_column:
            kept_names.append(name)
            auto_transformations.append({"auto": {"column_name": name}})

    return auto_transformations, kept_names
def validate_and_get_column_transformations(
    column_specs: Optional[Dict[str, str]],
    column_transformations: Optional[List[Dict[str, Dict[str, str]]]],
) -> Optional[List[Dict[str, Dict[str, str]]]]:
    """Validates column specs and transformations, then returns processed transformations.

    Args:
        column_specs (Dict[str, str]):
            Optional. Alternative to column_transformations where the keys of the dict
            are column names and their respective values are one of
            AutoMLTabularTrainingJob.column_data_types.
            When creating transformation for BigQuery Struct column, the column
            should be flattened using "." as the delimiter. Only columns with no child
            should have a transformation.
            If an input column has no transformations on it, such a column is
            ignored by the training, except for the targetColumn, which should have
            no transformations defined on it.
            Only one of column_transformations or column_specs should be passed.
        column_transformations (List[Dict[str, Dict[str, str]]]):
            Optional. Transformations to apply to the input columns (i.e. columns other
            than the targetColumn). Each transformation may produce multiple
            result values from the column's value, and all are used for training.
            When creating transformation for BigQuery Struct column, the column
            should be flattened using "." as the delimiter. Only columns with no child
            should have a transformation.
            If an input column has no transformations on it, such a column is
            ignored by the training, except for the targetColumn, which should have
            no transformations defined on it.
            Only one of column_transformations or column_specs should be passed.
            Consider using column_specs as column_transformations will be deprecated eventually.

    Returns:
        Optional[List[Dict[str, Dict[str, str]]]]:
            The column transformations. ``column_transformations`` is returned
            unchanged when it is given; ``column_specs`` is converted to a list
            of ``{<transformation>: {"column_name": <column name>}}`` entries;
            ``None`` is returned when neither argument is given.

    Raises:
        ValueError: If both column_transformations and column_specs are passed.
    """
    # user populated transformations
    if column_transformations is not None and column_specs is not None:
        raise ValueError(
            "Both column_transformations and column_specs were passed. Only one is allowed."
        )

    if column_transformations is not None:
        # Make sure the deprecation notice is emitted for this call regardless of
        # the active filters, but scope the filter override to this block so the
        # process-wide warning configuration is restored afterwards instead of
        # being permanently switched to "always" for every DeprecationWarning.
        # NOTE: warnings.catch_warnings() is not thread-safe (same caveat as any
        # mutation of the global warning filters).
        with warnings.catch_warnings():
            warnings.simplefilter("always", DeprecationWarning)
            warnings.warn(
                "consider using column_specs instead. column_transformations will be deprecated in the future.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller of this function
            )
        return column_transformations

    if column_specs is not None:
        # Each spec maps column name -> transformation type; the service-facing
        # form is keyed by transformation type instead.
        return [
            {transformation: {"column_name": column_name}}
            for column_name, transformation in column_specs.items()
        ]

    return None