-
Notifications
You must be signed in to change notification settings - Fork 287
/
_utils.py
492 lines (389 loc) · 17.2 KB
/
_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
"""Miscellaneous utility functions."""
import operator
import os
import pickle
import uuid
import warnings
from collections import Counter, defaultdict
from collections.abc import Iterable
from datetime import datetime
from pathlib import Path
import numpy as np
import pandas as pd
from pandas.core.tools.datetimes import _guess_datetime_format_for_array
from sklearn.discriminant_analysis import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sdv import version
from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError
def _cast_to_iterable(value):
"""Return a ``list`` if the input object is not a ``list`` or ``tuple``."""
if isinstance(value, (list, tuple)):
return value
return [value]
def _get_datetime_format(value):
    """Guess the ``strftime`` format of ``value``.

    Uses pandas' internal ``_guess_datetime_format_for_array`` to detect the
    ``strftime`` format of the given values.

    Args:
        value (pandas.Series, np.ndarray, list, or str):
            Input to attempt detecting the format.

    Return:
        String representing the datetime format in ``strftime`` format or ``None`` if not detected.
    """
    series = value if isinstance(value, pd.Series) else pd.Series(value)
    # Drop missing entries first; the pandas helper expects an array of strings.
    candidates = series[~series.isna()].astype(str).to_numpy()
    return _guess_datetime_format_for_array(candidates)
def _is_datetime_type(value):
    """Determine if the input is a datetime type or not.

    If a ``pandas.Series`` or ``list`` is passed, it will return ``True`` if the first
    thousand values are datetime. Otherwise, it will check if the value is a datetime.

    Note: it will return ``False`` if ``value`` is a string representing
    a date before the year 1677.

    Args:
        value (array-like iterable, int, str or datetime):
            Input to evaluate.

    Returns:
        bool:
            True if the input is a datetime type, False if not.
    """
    if isinstance(value, str) or not isinstance(value, Iterable):
        value = _cast_to_iterable(value)

    # Only inspect the first 1000 non-null values so this stays under ~1 second.
    sample = pd.Series(value)
    sample = sample[~sample.isna()].head(1000)
    return all(
        isinstance(item, (pd.Timestamp, datetime)) or bool(_get_datetime_format([item]))
        for item in sample
    )
def _is_numerical_type(value):
"""Determine if the input is numerical or not.
Args:
value (int, str, datetime, bool):
Input to evaluate.
Returns:
bool:
True if the input is numerical, False if not.
"""
return pd.isna(value) | pd.api.types.is_float(value) | pd.api.types.is_integer(value)
def _is_boolean_type(value):
"""Determine if the input is a boolean or not.
Args:
value (int, str, datetime, bool):
Input to evaluate.
Returns:
bool:
True if the input is a boolean, False if not.
"""
return True if pd.isna(value) | (value is True) | (value is False) else False
def _validate_datetime_format(column, datetime_format):
"""Determine the values of the column that match the datetime format.
Args:
column (pd.Series):
Column to evaluate.
datetime_format (str):
The datetime format.
Returns:
pd.Series:
Series of booleans, with True if the value matches the format, False if not.
"""
pandas_datetime_format = datetime_format.replace('%-', '%')
datetime_column = pd.to_datetime(
column,
errors='coerce',
format=pandas_datetime_format
)
valid = pd.isna(column) | ~pd.isna(datetime_column)
return set(column[~valid])
def _convert_to_timedelta(column):
"""Convert a ``pandas.Series`` to one with dtype ``timedelta``.
``pd.to_timedelta`` does not handle nans, so this function masks the nans, converts and then
reinserts them.
Args:
column (pandas.Series):
Column to convert.
Returns:
pandas.Series:
The column converted to timedeltas.
"""
nan_mask = pd.isna(column)
column[nan_mask] = 0
column = pd.to_timedelta(column)
column[nan_mask] = pd.NaT
return column
def _load_data_from_csv(filepath, read_csv_parameters=None):
"""Load DataFrame from a filepath.
Args:
filepath (str):
String that represents the ``path`` to the ``csv`` file.
read_csv_parameters (dict):
A python dictionary of with string and value accepted by ``pandas.read_csv``
function. Defaults to ``None``.
"""
filepath = Path(filepath)
read_csv_parameters = read_csv_parameters or {}
data = pd.read_csv(filepath, **read_csv_parameters)
return data
def _groupby_list(list_to_check):
"""Return the first element of the list if the length is 1 else the entire list."""
return list_to_check[0] if len(list_to_check) == 1 else list_to_check
def _create_unique_name(name, list_names):
"""Modify the ``name`` parameter if it already exists in the list of names."""
result = name
while result in list_names:
result += '_'
return result
def _format_invalid_values_string(invalid_values, num_values):
"""Convert ``invalid_values`` into a string of invalid values.
Args:
invalid_values (pd.DataFrame, set):
Object of values to be converted into string.
num_values (int):
Maximum number of values of the object to show.
Returns:
str:
A stringified version of the object.
"""
if isinstance(invalid_values, pd.DataFrame):
if len(invalid_values) > num_values:
return f'{invalid_values.head(num_values)}\n+{len(invalid_values) - num_values} more'
if isinstance(invalid_values, set):
invalid_values = sorted(invalid_values, key=lambda x: str(x))
if len(invalid_values) > num_values:
extra_missing_values = [f'+ {len(invalid_values) - num_values} more']
return f'{invalid_values[:num_values] + extra_missing_values}'
return f'{invalid_values}'
def _validate_foreign_keys_not_null(metadata, data):
"""Validate that the foreign keys in the data don't have null values."""
invalid_tables = defaultdict(list)
for table_name, table_data in data.items():
for foreign_key in metadata._get_all_foreign_keys(table_name):
if table_data[foreign_key].isna().any():
invalid_tables[table_name].append(foreign_key)
if invalid_tables:
err_msg = (
'The data contains null values in foreign key columns. '
'This feature is currently unsupported. Please remove '
'null values to fit the synthesizer.\n'
'\n'
'Affected columns:\n'
)
for table_name, invalid_columns in invalid_tables.items():
err_msg += f"Table '{table_name}', column(s) {invalid_columns}\n"
raise SynthesizerInputError(err_msg)
def check_sdv_versions_and_warn(synthesizer):
    """Check if the current SDV and SDV Enterprise versions mismatch.

    Args:
        synthesizer (BaseSynthesizer or BaseMultiTableSynthesizer):
            An SDV model instance to check versions against.

    Raises:
        SDVVersionWarning:
            If the current SDV or SDV Enterprise version does not match the version used to fit
            the synthesizer.
    """
    # Only fitted synthesizers carry the versions they were trained on.
    if not synthesizer._fitted:
        return

    current_public_version = getattr(version, 'public', None)
    current_enterprise_version = getattr(version, 'enterprise', None)
    fitted_public_version = getattr(synthesizer, '_fitted_sdv_version', None)
    fitted_enterprise_version = getattr(synthesizer, '_fitted_sdv_enterprise_version', None)
    public_missmatch = current_public_version != fitted_public_version
    enterprise_missmatch = current_enterprise_version != fitted_enterprise_version
    if not (public_missmatch or enterprise_missmatch):
        return

    if public_missmatch and enterprise_missmatch:
        message = (
            'You are currently on SDV version '
            f'{current_public_version} and SDV Enterprise version '
            f'{current_enterprise_version} but this synthesizer was created on '
            f'SDV version {synthesizer._fitted_sdv_version} and SDV Enterprise version '
            f'{synthesizer._fitted_sdv_enterprise_version}.'
        )
    elif public_missmatch:
        message = (
            'You are currently on SDV version '
            f'{current_public_version} but this synthesizer was created on '
            f'version {synthesizer._fitted_sdv_version}.'
        )
    else:
        message = (
            'You are currently on SDV Enterprise version '
            f'{current_enterprise_version} but this synthesizer was created on '
            f'version {synthesizer._fitted_sdv_enterprise_version}.'
        )

    static_message = (
        'The latest bug fixes and features may not be available for this synthesizer. '
        'To see these enhancements, create and train a new synthesizer on this version.'
    )
    warnings.warn(f'{message} {static_message}', SDVVersionWarning)
def _compare_versions(current_version, synthesizer_version, compare_operator=operator.gt):
"""Compare two versions.
Given a ``compare_operator`` compare two versions using that operator to determine if one is
greater than the other or vice-versa.
Args:
current_version (str):
The current version to compare against, formatted as a string with major, minor, and
revision parts separated by periods (e.g., "1.0.0").
synthesizer_version (str):
The synthesizer version to compare, formatted as a string with major, minor, and
revision parts separated by periods (e.g., "1.0.0")
compare_operator (operator):
Operator function to evaluate with. Defaults to ``operator.gt``.
Returns:
bool:
Depending on the ``operator`` function it will return ``True`` or ``False`` if
``current_version`` is bigger or lower than ``synthesizer_version``.
"""
if None in (current_version, synthesizer_version):
return False
current_version = current_version.split('.')
synthesizer_version = synthesizer_version.split('.')
for current_v, synth_v in zip(current_version, synthesizer_version):
try:
current_v = int(current_v)
synth_v = int(synth_v)
if compare_operator(current_v, synth_v):
return False
if compare_operator(synth_v, current_v):
return True
except Exception:
pass
return False
def check_synthesizer_version(synthesizer, is_fit_method=False, compare_operator=operator.gt):
    """Check if the current synthesizer version is greater than the package version.

    Args:
        synthesizer (BaseSynthesizer or BaseMultiTableSynthesizer):
            An SDV model instance to check versions against.
        is_fit_method (bool):
            Whether or not this function is being called by a ``fit`` function.
        compare_operator (operator):
            Operator function to evaluate with. Defaults to ``operator.gt``.

    Raises:
        VersionError:
            If the current version of the software is lower than the synthesizer's version.
    """
    current_public_version = getattr(version, 'public', None)
    current_enterprise_version = getattr(version, 'enterprise', None)
    if is_fit_method:
        static_message = (
            'Fitting this synthesizer again is not supported. '
            'Please create a new synthesizer.'
        )
    else:
        static_message = 'Downgrading your SDV version is not supported.'

    fit_public_version = getattr(synthesizer, '_fitted_sdv_version', None)
    fit_enterprise_version = getattr(synthesizer, '_fitted_sdv_enterprise_version', None)
    is_public_lower = _compare_versions(
        current_public_version, fit_public_version, compare_operator
    )
    is_enterprise_lower = _compare_versions(
        current_enterprise_version, fit_enterprise_version, compare_operator
    )
    if is_public_lower and is_enterprise_lower:
        raise VersionError(
            f'You are currently on SDV version {current_public_version} and SDV Enterprise '
            f'version {current_enterprise_version} but this '
            f'synthesizer was created on SDV version {fit_public_version} and SDV '
            f'Enterprise version {fit_enterprise_version}. {static_message}'
        )

    if is_public_lower:
        raise VersionError(
            f'You are currently on SDV version {current_public_version} but this '
            f'synthesizer was created on version {fit_public_version}. {static_message}'
        )

    if is_enterprise_lower:
        raise VersionError(
            f'You are currently on SDV Enterprise version {current_enterprise_version} but '
            f'this synthesizer was created on version {fit_enterprise_version}. '
            f'{static_message}'
        )
def _get_root_tables(relationships):
parent_tables = {rel['parent_table_name'] for rel in relationships}
child_tables = {rel['child_table_name'] for rel in relationships}
return parent_tables - child_tables
def generate_synthesizer_id(synthesizer):
    """Generate a unique identifier for the synthesizer instance.

    The identifier combines the synthesizer's class name, the public SDV
    version and a random UUID4 with its hyphens removed.

    Args:
        synthesizer (BaseSynthesizer or BaseMultiTableSynthesizer):
            An SDV model instance to check versions against.

    Returns:
        ID:
            A unique identifier for this synthesizer.
    """
    unique_id = str(uuid.uuid4()).replace('-', '')
    return f'{synthesizer.__class__.__name__}_{version.public}_{unique_id}'
def _generate_feature_vector(data, foreign_key):
    # Build a 5-element feature vector describing a candidate foreign-key link.
    # ``foreign_key`` is a 4-tuple: (parent_table, parent_column, child_table,
    # child_column); ``data`` maps table names to DataFrames.
    parent_name = foreign_key[0]
    parent_col, child_col = data[foreign_key[0]][foreign_key[1]], data[foreign_key[2]][foreign_key[3]]
    parent_set, child_set = set(parent_col), set(child_col)
    return [
        # Ratio of distinct child values to distinct parent values (1e-5 avoids div-by-zero).
        len(child_set) / (len(parent_set) + 1e-5),
        # Fraction of child values that are distinct (low for true FKs with repeats).
        len(child_set) / (len(child_col) + 1e-5),
        # Parent and child columns share the exact same name.
        1.0 if parent_col.name == child_col.name else 0.0,
        # Child column name ends with an identifier-like suffix.
        1.0 if child_col.name.lower().endswith('id') or child_col.name.lower().endswith('key') else 0.0,
        # NOTE(review): ``in child_col`` tests membership in the Series *index*
        # (pandas semantics), not the column name or its values — this was
        # possibly meant to be ``in child_col.name`` (singular parent table name
        # inside the child column name). Confirm before changing: the pickled
        # detector model was trained with the current behavior.
        1.0 if parent_name[:-1] in child_col else 0.0,
    ]
def confusion_matrix(set1, set2):
    """Partition predicted keys (``set1``) against true keys (``set2``).

    Returns a dict with the keys that appear in both collections
    ('True Positive'), only in ``set1`` ('False Positive') and only in
    ``set2`` ('False Negative').
    """
    predicted = set(set1)
    actual = set(set2)
    return {
        'True Positive': predicted & actual,
        'False Positive': predicted - actual,
        'False Negative': actual - predicted,
    }
def train_foreign_key_detector():
    """Generate a foreign key detection model using logistic regression and pickle it.

    This function is used to create and train a foreign key detection model.
    It reads ``test_set/<demo>/`` directories containing the tables as csv
    files plus a pickled list of true relationships, and ``predicted/<demo>/``
    directories with the pickled predicted relationships. Correctly predicted
    keys become positive examples and spurious predictions negative ones.
    The fitted pipeline is written to ``trained_model.pkl``.
    """
    # Accumulate rows in Python lists and stack once at the end;
    # np.vstack/np.append inside the loop reallocates the arrays every
    # iteration (quadratic behavior).
    feature_rows = []
    labels = []
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('detector', LogisticRegression())
    ])

    # Load the data
    for demo_name in os.listdir('test_set'):
        with open(f'test_set/{demo_name}/relationships.pkl', 'rb') as f:
            true_relationships = pickle.load(f)
        with open(f'predicted/{demo_name}/relationships.pkl', 'rb') as f:
            predicted_relationships = pickle.load(f)

        data = {}
        for table_name in os.listdir(f'test_set/{demo_name}'):
            if table_name.endswith('.csv'):
                data[table_name[:-4]] = pd.read_csv(
                    f'test_set/{demo_name}/{table_name}', low_memory=False
                )

        cm = confusion_matrix(predicted_relationships, true_relationships)
        for foreign_key in cm['True Positive']:
            feature_rows.append(_generate_feature_vector(data, foreign_key))
            labels.append(1.)
        for foreign_key in cm['False Positive']:
            feature_rows.append(_generate_feature_vector(data, foreign_key))
            labels.append(0.)

    # reshape keeps the (0, 5) shape when no examples were collected.
    features = np.array(feature_rows, dtype=float).reshape(-1, 5)
    target = np.array(labels, dtype=float)
    pipeline.fit(features, target)
    with open('trained_model.pkl', 'wb') as f:
        pickle.dump(pipeline, f)
def predict_foreign_keys(data, parent_candidate, primary_key, child_candidate, column_name, threshold):
    """Predict whether a candidate column is a foreign key into a parent table.

    Args:
        data (dict):
            Mapping of table name to ``pandas.DataFrame``.
        parent_candidate (str):
            Candidate parent table name.
        primary_key (str):
            Primary key column of the parent table.
        child_candidate (str):
            Candidate child table name.
        column_name (str):
            Candidate foreign key column of the child table.
        threshold (float):
            Minimum predicted probability to accept the candidate.

    Returns:
        bool:
            Whether the trained detector scores the pair above ``threshold``.
    """
    foreign_key = (parent_candidate, primary_key, child_candidate, column_name)
    features = np.array(_generate_feature_vector(data, foreign_key)).reshape(1, -1)
    # Use a context manager so the model file handle is closed deterministically
    # (the previous ``pickle.load(open(...))`` leaked the handle).
    with open('trained_model.pkl', 'rb') as model_file:
        trained_model = pickle.load(model_file)

    return bool(trained_model.predict_proba(features)[0, 1] > threshold)