Skip to content

Commit

Permalink
Convert integer column names to strings to allow for default column names (#1976)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed May 8, 2024
1 parent aa5188d commit e1f787e
Show file tree
Hide file tree
Showing 12 changed files with 360 additions and 6 deletions.
2 changes: 2 additions & 0 deletions sdv/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def _validate_foreign_keys_not_null(metadata, data):
invalid_tables = defaultdict(list)
for table_name, table_data in data.items():
for foreign_key in metadata._get_all_foreign_keys(table_name):
if foreign_key not in table_data and int(foreign_key) in table_data:
foreign_key = int(foreign_key)
if table_data[foreign_key].isna().any():
invalid_tables[table_name].append(foreign_key)

Expand Down
7 changes: 6 additions & 1 deletion sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,12 @@ def _set_metadata_dict(self, metadata):
self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict)

for relationship in metadata.get('relationships', []):
self.relationships.append(relationship)
type_safe_relationships = {
key: str(value)
if not isinstance(value, str)
else value for key, value in relationship.items()
}
self.relationships.append(type_safe_relationships)

@classmethod
def load_from_dict(cls, metadata_dict):
Expand Down
6 changes: 6 additions & 0 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,12 @@ def load_from_dict(cls, metadata_dict):
for key in instance._KEYS:
value = deepcopy(metadata_dict.get(key))
if value:
if key == 'columns':
value = {
str(key)
if not isinstance(key, str)
else key: col for key, col in value.items()
}
setattr(instance, f'{key}', value)

return instance
Expand Down
24 changes: 24 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self.extended_columns = defaultdict(dict)
self._table_synthesizers = {}
self._table_parameters = defaultdict(dict)
self._original_table_columns = {}
if synthesizer_kwargs is not None:
warn_message = (
'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not '
Expand Down Expand Up @@ -326,6 +327,19 @@ def update_transformers(self, table_name, column_name_to_transformer):
self._validate_table_name(table_name)
self._table_synthesizers[table_name].update_transformers(column_name_to_transformer)

def _store_and_convert_original_cols(self, data):
list_of_changed_tables = []
for table, dataframe in data.items():
self._original_table_columns[table] = dataframe.columns
for column in dataframe.columns:
if isinstance(column, int):
dataframe.columns = dataframe.columns.astype(str)
list_of_changed_tables.append(table)
break

data[table] = dataframe
return list_of_changed_tables

def preprocess(self, data):
"""Transform the raw data to numerical space.
Expand All @@ -337,6 +351,8 @@ def preprocess(self, data):
dict:
A dictionary with the preprocessed data.
"""
list_of_changed_tables = self._store_and_convert_original_cols(data)

self.validate(data)
if self._fitted:
warnings.warn(
Expand All @@ -351,6 +367,9 @@ def preprocess(self, data):
self._assign_table_transformers(synthesizer, table_name, table_data)
processed_data[table_name] = synthesizer._preprocess(table_data)

for table in list_of_changed_tables:
data[table].columns = self._original_table_columns[table]

return processed_data

def _model_tables(self, augmented_data):
Expand Down Expand Up @@ -487,6 +506,11 @@ def sample(self, scale=1.0):
total_rows += len(table)
total_columns += len(table.columns)

table_columns = getattr(self, '_original_table_columns', {})
for table in sampled_data:
if table in table_columns:
sampled_data[table].columns = table_columns[table]

SYNTHESIZER_LOGGER.info(
'\nSample:\n'
' Timestamp: %s\n'
Expand Down
26 changes: 24 additions & 2 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
enforce_min_max_values=self.enforce_min_max_values,
locales=self.locales,
)
self._original_columns = pd.Index([])
self._fitted = False
self._random_state_set = False
self._update_default_transformers()
Expand Down Expand Up @@ -367,6 +368,16 @@ def _preprocess(self, data):
self._data_processor.fit(data)
return self._data_processor.transform(data)

def _store_and_convert_original_cols(self, data):
# Transform in place to avoid possible large copy of data
for column in data.columns:
if isinstance(column, int):
self._original_columns = data.columns
data.columns = data.columns.astype(str)
return True

return False

def preprocess(self, data):
"""Transform the raw data to numerical space.
Expand All @@ -384,7 +395,14 @@ def preprocess(self, data):
"please refit the model using 'fit' or 'fit_processed_data'."
)

return self._preprocess(data)
is_converted = self._store_and_convert_original_cols(data)

preprocess_data = self._preprocess(data)

if is_converted:
data.columns = self._original_columns

return preprocess_data

def _fit(self, processed_data):
"""Fit the model to the table.
Expand Down Expand Up @@ -455,7 +473,7 @@ def fit(self, data):
self._fitted = False
self._data_processor.reset_sampling()
self._random_state_set = False
processed_data = self._preprocess(data)
processed_data = self.preprocess(data)
self.fit_processed_data(processed_data)

def save(self, filepath):
Expand Down Expand Up @@ -891,6 +909,10 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
show_progress_bar=show_progress_bar
)

original_columns = getattr(self, '_original_columns', pd.Index([]))
if not original_columns.empty:
sampled_data.columns = self._original_columns

SYNTHESIZER_LOGGER.info(
'\nSample:\n'
' Timestamp: %s\n'
Expand Down
2 changes: 1 addition & 1 deletion sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def unflatten_dict(flat):

else:
subdict = unflattened.setdefault(key, {})
if subkey.isdigit():
if subkey.isdigit() and key != 'univariates':
subkey = int(subkey)

inner = subdict.setdefault(subkey, {})
Expand Down
50 changes: 50 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1679,3 +1679,53 @@ def test_hma_not_fit_raises_sampling_error():
)
with pytest.raises(SamplingError, match=error_msg):
synthesizer.sample(1)


def test_fit_and_sample_numerical_col_names():
    """Test fitting/sampling when column names are integers."""
    # Setup
    num_rows = 50
    num_cols = 10
    num_tables = 2
    data = {
        str(table_idx): pd.DataFrame({
            column: np.random.randint(0, 100, size=num_rows)
            for column in range(num_cols)
        })
        for table_idx in range(num_tables)
    }

    primary_key = pd.DataFrame({1: range(num_rows)})
    primary_key_2 = pd.DataFrame({2: range(num_rows)})
    data['0'][1] = primary_key
    data['1'][1] = primary_key
    data['1'][2] = primary_key_2
    metadata = MultiTableMetadata()
    metadata_dict = {'tables': {}}
    for table_idx in range(num_tables):
        columns = {column: {'sdtype': 'numerical'} for column in range(num_cols)}
        metadata_dict['tables'][str(table_idx)] = {'columns': columns}

    metadata_dict['tables']['0']['columns'][1] = {'sdtype': 'id'}
    metadata_dict['tables']['1']['columns'][2] = {'sdtype': 'id'}
    metadata_dict['relationships'] = [
        {
            'parent_table_name': '0',
            'parent_primary_key': 1,
            'child_table_name': '1',
            'child_foreign_key': 2,
        }
    ]
    metadata = MultiTableMetadata.load_from_dict(metadata_dict)
    metadata.set_primary_key('0', '1')

    # Run
    synthesizer = HMASynthesizer(metadata)
    synthesizer.fit(data)
    first_sample = synthesizer.sample()
    second_sample = synthesizer.sample()

    # Assert
    for sampled in (first_sample, second_sample):
        for table_name in ('0', '1'):
            assert sampled[table_name].columns.tolist() == data[table_name].columns.tolist()

    # Sampling is stochastic, so two independent draws should not be identical.
    with pytest.raises(AssertionError):
        pd.testing.assert_frame_equal(first_sample['0'], second_sample['0'])
38 changes: 38 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,44 @@ def test_fit_raises_version_error():
instance.fit(data)


# Synthesizer classes exercised by the parametrized single-table tests below.
SYNTHESIZERS_CLASSES = [
    pytest.param(synthesizer_class, id=synthesizer_class.__name__)
    for synthesizer_class in (
        CTGANSynthesizer,
        TVAESynthesizer,
        GaussianCopulaSynthesizer,
        CopulaGANSynthesizer,
    )
]


@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES)
def test_fit_and_sample_numerical_col_names(synthesizer_class):
    """Test fitting/sampling when column names are integers."""
    # Setup
    num_rows = 50
    num_cols = 10
    data = pd.DataFrame({
        column: np.random.randint(0, 100, size=num_rows)
        for column in range(num_cols)
    })
    metadata = SingleTableMetadata()
    metadata_dict = {
        'columns': {column: {'sdtype': 'numerical'} for column in range(num_cols)}
    }
    metadata = SingleTableMetadata.load_from_dict(metadata_dict)

    # Run
    synth = synthesizer_class(metadata)
    synth.fit(data)
    sample_1 = synth.sample(10)
    sample_2 = synth.sample(10)

    assert sample_1.columns.tolist() == data.columns.tolist()
    assert sample_2.columns.tolist() == data.columns.tolist()

    # Assert
    # Sampling is stochastic, so two independent draws should differ.
    with pytest.raises(AssertionError):
        pd.testing.assert_frame_equal(sample_1, sample_2)


@pytest.mark.parametrize('synthesizer', SYNTHESIZERS)
def test_sample_not_fitted(synthesizer):
"""Test that a synthesizer raises an error when trying to sample without fitting."""
Expand Down
79 changes: 79 additions & 0 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,85 @@ def test_load_from_dict(self, mock_singletablemetadata):
}
]

@patch('sdv.metadata.multi_table.SingleTableMetadata')
def test_load_from_dict_integer(self, mock_singletablemetadata):
    """Test that ``load_from_dict`` casts integer keys to strings.

    When the passed python ``dict`` uses integers for column names or
    relationship keys, the new ``MultiTableMetadata`` instance must store
    them as strings so the metadata is properly typed.

    Setup:
        - A dict representing a ``MultiTableMetadata`` with integer keys.
    Mock:
        - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table``.
    Output:
        - ``instance`` that contains ``instance.tables`` and
          ``instance.relationships`` with string keys.
    Side Effects:
        - ``SingleTableMetadata.load_from_dict`` has been called.
    """
    # Setup
    single_table_accounts = {
        '1': {'sdtype': 'numerical'},
        '2': {'sdtype': 'numerical'},
        'amount': {'sdtype': 'numerical'},
        'start_date': {'sdtype': 'datetime'},
        'owner': {'sdtype': 'id'},
    }
    single_table_branches = {
        '1': {'sdtype': 'numerical'},
        'name': {'sdtype': 'id'},
    }
    multitable_metadata = {
        'tables': {
            'accounts': {
                1: {'sdtype': 'numerical'},
                2: {'sdtype': 'numerical'},
                'amount': {'sdtype': 'numerical'},
                'start_date': {'sdtype': 'datetime'},
                'owner': {'sdtype': 'id'},
            },
            'branches': {
                1: {'sdtype': 'numerical'},
                'name': {'sdtype': 'id'},
            },
        },
        'relationships': [
            {
                'parent_table_name': 'accounts',
                'parent_primary_key': 1,
                'child_table_name': 'branches',
                'child_foreign_key': 1,
            }
        ],
    }
    # Tables are loaded in insertion order: accounts first, then branches.
    mock_singletablemetadata.load_from_dict.side_effect = [
        single_table_accounts,
        single_table_branches,
    ]

    # Run
    instance = MultiTableMetadata.load_from_dict(multitable_metadata)

    # Assert
    expected_tables = {
        'accounts': single_table_accounts,
        'branches': single_table_branches,
    }
    assert instance.tables == expected_tables

    expected_relationships = [
        {
            'parent_table_name': 'accounts',
            'parent_primary_key': '1',
            'child_table_name': 'branches',
            'child_foreign_key': '1',
        }
    ]
    assert instance.relationships == expected_relationships

@patch('sdv.metadata.multi_table.json')
def test___repr__(self, mock_json):
"""Test that the ``__repr__`` method.
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2695,6 +2695,34 @@ def test_load_from_dict(self):
assert instance.sequence_index is None
assert instance._version == 'SINGLE_TABLE_V1'

def test_load_from_dict_integer(self):
    """Test ``load_from_dict`` converts integer column names to strings.

    If the metadata dict contains integer column keys (e.g. due to missing
    column names from a CSV), they must be cast to strings so the metadata
    is parsed properly; every other field is loaded unchanged.
    """
    # Setup
    metadata_dict = {
        'columns': {1: 'value'},
        'primary_key': 'pk',
        'alternate_keys': [],
        'sequence_key': None,
        'sequence_index': None,
        'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
    }

    # Run
    instance = SingleTableMetadata.load_from_dict(metadata_dict)

    # Assert
    assert instance.columns == {'1': 'value'}
    assert instance.primary_key == 'pk'
    assert instance.alternate_keys == []
    assert instance.sequence_key is None
    assert instance.sequence_index is None
    assert instance._version == 'SINGLE_TABLE_V1'

@patch('sdv.metadata.utils.Path')
def test_load_from_json_path_does_not_exist(self, mock_path):
"""Test the ``load_from_json`` method.
Expand Down

0 comments on commit e1f787e

Please sign in to comment.