Skip to content

Commit

Permalink
Add tests and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed May 8, 2024
1 parent b2066bf commit 4f73e0d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
10 changes: 7 additions & 3 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def update_column(self, table_name, column_name, **kwargs):
"""
self._validate_table_exists(table_name)
table = self.tables.get(table_name)
table.update_column(column_name, **kwargs)
table.update_column(str(column_name), **kwargs)

def update_columns(self, table_name, column_names, **kwargs):
"""Update multiple columns with the same metadata kwargs.
Expand Down Expand Up @@ -523,7 +523,9 @@ def _detect_relationships(self):

def detect_table_from_dataframe(self, table_name, data):
"""Detect the metadata for a table from a dataframe.
All data columns are converted to strings
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``,
for a specified table. All data columns are converted to strings.
Args:
table_name (str):
Expand All @@ -539,7 +541,9 @@ def detect_table_from_dataframe(self, table_name, data):

def detect_from_dataframes(self, data):
"""Detect the metadata for all tables in a dictionary of dataframes.
All data columns are converted to strings
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
All data columns are converted to strings.
Args:
data (dict):
Expand Down
5 changes: 3 additions & 2 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def detect_from_dataframe(self, data):
"""Detect the metadata from a ``pd.DataFrame`` object.
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
All data columns are converted to strings
All data columns are converted to strings.
Args:
data (pandas.DataFrame):
Expand Down Expand Up @@ -1236,7 +1236,8 @@ def load_from_dict(cls, metadata_dict):
Python dictionary representing a ``SingleTableMetadata`` object.
Returns:
Instance of ``SingleTableMetadata``.
Instance of ``SingleTableMetadata``. Column names are converted to
string type.
"""
instance = cls()
for key in instance._KEYS:
Expand Down
33 changes: 27 additions & 6 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1733,6 +1733,7 @@ def test_fit_and_sample_numerical_col_names():

def test_detect_from_dataframe_numerical_col():
"""Test that metadata detection of integer columns work."""
# Setup
parent_data = pd.DataFrame({
1: [1000, 1001, 1002],
2: [2, 3, 4],
Expand All @@ -1742,28 +1743,48 @@ def test_detect_from_dataframe_numerical_col():
3: [1000, 1001, 1000],
4: [1, 2, 3]
})

data = {
'parent_data': parent_data,
'child_data': child_data,
}

metadata = MultiTableMetadata()
metadata.detect_table_from_dataframe('parent_data', parent_data)
metadata.update_column('parent_data', '1', sdtype='id')
metadata.detect_table_from_dataframe('child_data', child_data)
metadata.update_column('parent_data', '1', sdtype='id')
metadata.update_column('child_data', '3', sdtype='id')
metadata.update_column('child_data', '4', sdtype='id')
metadata.set_primary_key('parent_data', '1')
metadata.set_primary_key('child_data', '4')
metadata.add_relationship(
parent_primary_key='1',
parent_table_name='parent_data',
child_foreign_key='3',
child_table_name='child_data'
)

# test_metadata = MultiTableMetadata()
# test_metadata.detect_from_dataframes(data)
test_metadata = MultiTableMetadata()
test_metadata.detect_from_dataframes(data)
test_metadata.update_column('parent_data', '1', sdtype='id')
test_metadata.update_column('child_data', '3', sdtype='id')
test_metadata.update_column('child_data', '4', sdtype='id')
test_metadata.set_primary_key('parent_data', '1')
test_metadata.set_primary_key('child_data', '4')
test_metadata.add_relationship(
parent_primary_key='1',
parent_table_name='parent_data',
child_foreign_key='3',
child_table_name='child_data'
)

# Run
instance = HMASynthesizer(metadata)
instance.fit(data)
sample = instance.sample(5)
assert sample.columns.tolist() == data.columns.tolist()

# Assert
assert test_metadata.to_dict() == metadata.to_dict()
assert sample['parent_data'].columns.tolist() == data['parent_data'].columns.tolist()
assert sample['child_data'].columns.tolist() == data['child_data'].columns.tolist()

test_metadata = MultiTableMetadata()
test_metadata.detect_from_dataframes(data)

0 comments on commit 4f73e0d

Please sign in to comment.