Skip to content

Commit

Permalink
add nan foreign key logic
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Apr 26, 2024
1 parent b4b2fd4 commit 142e807
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 27 deletions.
66 changes: 52 additions & 14 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,16 +429,29 @@ def _get_rows_to_drop(data, metadata):
return table_to_idx_to_drop


def _get_idx_to_drop_nan_foreign_key_table(data, relationships, table):
    """Collect the index labels of the rows of ``table`` with a missing foreign key.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame).
        relationships (list):
            Relationships handed to ``_get_relationships_for_child`` together with
            ``table``; each relationship it returns must provide a
            ``'child_foreign_key'`` entry naming a column of ``data[table]``.
        table (str):
            Name of the table to inspect.

    Returns:
        set:
            Index labels of the rows holding a NaN value in at least one of the
            foreign key columns. Empty when no relationship is returned for the table.
    """
    nan_fk_index = set()
    for child_relationship in _get_relationships_for_child(relationships, table):
        foreign_key = child_relationship['child_foreign_key']
        # Looked up inside the loop so the table is only accessed when it
        # actually has at least one relationship to check.
        table_data = data[table]
        missing_mask = table_data[foreign_key].isna()
        nan_fk_index.update(table_data.loc[missing_mask].index)

    return nan_fk_index


def _drop_rows(data, metadata, drop_missing_values):
table_to_idx_to_drop = _get_rows_to_drop(data, metadata)
for table in sorted(metadata.tables):
idx_to_drop = table_to_idx_to_drop[table]
data[table] = data[table].drop(idx_to_drop)
if drop_missing_values:
relationships = _get_relationships_for_child(metadata.relationships, table)
for relationship in relationships:
child_column = relationship['child_foreign_key']
data[table] = data[table].dropna(subset=[child_column])
idx_with_nan_fk = _get_idx_to_drop_nan_foreign_key_table(
data, metadata.relationships, table
)
data[table] = data[table].drop(idx_with_nan_fk)

if data[table].empty:
raise InvalidDataError([
Expand All @@ -454,13 +467,37 @@ def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep):
for root in roots:
data[root] = data[root].sample(frac=ratio_to_keep)

_drop_rows(data, metadata, drop_missing_values=False)
_drop_rows(data, metadata, drop_missing_values=True)


def _subsample_table_and_descendants(data, metadata, table, num_rows):
    """Subsample the table and its descendants.

    The rows of ``table`` that have a NaN foreign key are dropped first. An error is
    raised when there are more such rows than rows that may be dropped while still
    keeping ``num_rows`` rows. Then ``num_rows`` rows are sampled from the table and
    ``_drop_rows`` is run with ``drop_missing_values=True`` on the whole dataset.

    The ``data`` dictionary is modified in place.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame).
        metadata (MultiTableMetadata):
            Metadata of the datasets.
        table (str):
            Name of the table.
        num_rows (int):
            Number of rows to keep in the table.

    Raises:
        SamplingError:
            If the number of rows with a NaN foreign key is greater than
            ``len(data[table]) - num_rows``.
    """
    idx_nan_fk = _get_idx_to_drop_nan_foreign_key_table(data, metadata.relationships, table)
    num_rows_to_drop = len(data[table]) - num_rows
    if len(idx_nan_fk) > num_rows_to_drop:
        # NOTE(review): this branch means fewer than ``num_rows`` rows have all their
        # foreign keys set, so "bigger number of rows" looks inverted. The text is
        # left untouched because callers and tests match on it -- confirm and fix
        # both together.
        raise SamplingError(
            f"Referential integrity cannot be reached for table '{table}' while keeping "
            f'{num_rows} rows. Please try again with a bigger number of rows.'
        )

    data[table] = data[table].drop(idx_nan_fk)
    data[table] = data[table].sample(num_rows)

    # A single pass is enough: drop_missing_values=True also removes the rows with
    # NaN foreign keys, which the sampling above cannot handle for the other tables.
    _drop_rows(data, metadata, drop_missing_values=True)


def _get_primary_keys_referenced(data, metadata):
Expand Down Expand Up @@ -489,7 +526,7 @@ def _get_primary_keys_referenced(data, metadata):


def _subsample_parent(parent_table, parent_primary_key, parent_pk_referenced_before,
unreferenced_pk_parent):
dereferenced_pk_parent):
"""Subsample the parent table.
The strategy here is to:
Expand All @@ -503,18 +540,18 @@ def _subsample_parent(parent_table, parent_primary_key, parent_pk_referenced_bef
Name of the primary key of the parent table.
parent_pk_referenced_before (set):
Set of the primary keys referenced before any subsampling.
unreferenced_pk_parent (set):
dereferenced_pk_parent (set):
Set of the primary keys that are no longer referenced by the descendants.
Returns:
pandas.DataFrame:
Subsampled parent table.
"""
total_referenced = len(parent_pk_referenced_before)
total_dropped = len(unreferenced_pk_parent)
total_dropped = len(dereferenced_pk_parent)
drop_proportion = total_dropped / total_referenced

parent_table = parent_table[~parent_table[parent_primary_key].isin(unreferenced_pk_parent)]
parent_table = parent_table[~parent_table[parent_primary_key].isin(dereferenced_pk_parent)]
unreferenced_data = parent_table[
~parent_table[parent_primary_key].isin(parent_pk_referenced_before)
]
Expand Down Expand Up @@ -556,12 +593,12 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced):
for parent in sorted(direct_parents):
parent_primary_key = metadata.tables[parent].primary_key
pk_referenced_before = primary_keys_referenced[parent]
unreferenced_primary_keys = pk_referenced_before - pk_referenced[parent]
dereferenced_primary_keys = pk_referenced_before - pk_referenced[parent]
data[parent] = _subsample_parent(
data[parent], parent_primary_key, pk_referenced_before,
unreferenced_primary_keys
dereferenced_primary_keys
)
if unreferenced_primary_keys:
if dereferenced_primary_keys:
primary_keys_referenced[parent] = pk_referenced[parent]

_subsample_ancestors(data, metadata, parent, primary_keys_referenced)
Expand Down Expand Up @@ -599,6 +636,7 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
_subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep)
_subsample_table_and_descendants(result, metadata, main_table_name, num_rows)
_subsample_ancestors(result, metadata, main_table_name, primary_keys_referenced)
_drop_rows(result, metadata, drop_missing_values=True) # Drop remaining NaN foreign keys
except InvalidDataError as error:
if 'All references in table' not in str(error.args[0]):
raise error
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,17 @@ def test_get_random_subset_disconnected_schema():
# Assert
assert len(result['Player']) == num_rows_to_keep
assert len(result['Team']) == int(len(real_data['Team']) * proportion_to_keep)


def test_get_random_subset_with_missing_values(metadata, data):
    """Test ``get_random_subset`` when there are missing values in the foreign keys."""
    # Setup
    data_with_nan = deepcopy(data)
    data_with_nan['child'].loc[4, 'parent_id'] = np.nan

    # Run
    subset = get_random_subset(data_with_nan, metadata, 'child', 3)

    # Assert
    child_subset = subset['child']
    assert len(child_subset) == 3
    assert not child_subset['parent_id'].isna().any()
92 changes: 82 additions & 10 deletions tests/unit/multi_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from sdv.metadata import MultiTableMetadata
from sdv.multi_table.utils import (
_drop_rows, _get_all_descendant_per_root_at_order_n, _get_ancestors,
_get_columns_to_drop_child, _get_disconnected_roots_from_table, _get_n_order_descendants,
_get_num_column_to_drop, _get_primary_keys_referenced, _get_relationships_for_child,
_get_relationships_for_parent, _get_rows_to_drop, _get_total_estimated_columns,
_print_simplified_schema_summary, _print_subsample_summary, _simplify_child,
_simplify_children, _simplify_data, _simplify_grandchildren, _simplify_metadata,
_simplify_relationships_and_tables, _subsample_ancestors, _subsample_data,
_subsample_disconnected_roots, _subsample_parent, _subsample_table_and_descendants)
_get_columns_to_drop_child, _get_disconnected_roots_from_table,
_get_idx_to_drop_nan_foreign_key_table, _get_n_order_descendants, _get_num_column_to_drop,
_get_primary_keys_referenced, _get_relationships_for_child, _get_relationships_for_parent,
_get_rows_to_drop, _get_total_estimated_columns, _print_simplified_schema_summary,
_print_subsample_summary, _simplify_child, _simplify_children, _simplify_data,
_simplify_grandchildren, _simplify_metadata, _simplify_relationships_and_tables,
_subsample_ancestors, _subsample_data, _subsample_disconnected_roots, _subsample_parent,
_subsample_table_and_descendants)


def test__get_relationships_for_child():
Expand Down Expand Up @@ -133,6 +134,44 @@ def test__get_rows_to_drop():
assert result == expected_result


def test__get_idx_to_drop_nan_foreign_key_table():
    """Test that ``_get_idx_to_drop_nan_foreign_key_table`` flags every NaN foreign key row."""
    # Setup
    relationship_keys = (
        'parent_table_name',
        'child_table_name',
        'parent_primary_key',
        'child_foreign_key',
    )
    relationship_values = [
        ('parent', 'child', 'id_parent', 'parent_foreign_key'),
        ('child', 'grandchild', 'id_child', 'child_foreign_key'),
        ('parent', 'grandchild', 'id_parent', 'parent_foreign_key'),
    ]
    relationships = [dict(zip(relationship_keys, values)) for values in relationship_values]
    grandchild = pd.DataFrame({
        'parent_foreign_key': [np.nan, 1, 2, 2, np.nan],
        'child_foreign_key': [9, np.nan, 11, 6, 4],
        'C': ['Yes', 'No', 'No', 'No', 'No'],
    })
    data = {'grandchild': grandchild}

    # Run
    idx_to_drop = _get_idx_to_drop_nan_foreign_key_table(data, relationships, 'grandchild')

    # Assert
    # Rows 0 and 4 miss the parent foreign key, row 1 misses the child foreign key.
    assert idx_to_drop == {0, 1, 4}


@patch('sdv.multi_table.utils._get_rows_to_drop')
def test__drop_rows(mock_get_rows_to_drop):
"""Test the ``_drop_rows`` method."""
Expand Down Expand Up @@ -1301,7 +1340,7 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo
mock_get_disconnected_roots_from_table.assert_called_once_with(
metadata.relationships, 'disconnected_root'
)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=False)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
for table_name in metadata.tables:
if table_name not in {'grandparent', 'other_root'}:
pd.testing.assert_frame_equal(data[table_name], expected_result[table_name])
Expand All @@ -1310,7 +1349,9 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo


@patch('sdv.multi_table.utils._drop_rows')
def test__subsample_table_and_descendants(mock_drop_rows):
@patch('sdv.multi_table.utils._get_idx_to_drop_nan_foreign_key_table')
def test__subsample_table_and_descendants(mock_get_idx_to_drop_nan_foreign_key_table,
mock_drop_rows):
"""Test the ``_subsample_table_and_descendants`` method."""
# Setup
data = {
Expand All @@ -1331,16 +1372,44 @@ def test__subsample_table_and_descendants(mock_drop_rows):
'col_8': [6, 7, 8, 9, 10],
}),
}
mock_get_idx_to_drop_nan_foreign_key_table.return_value = {0}
metadata = Mock()
metadata.relationships = Mock()

# Run
_subsample_table_and_descendants(data, metadata, 'parent', 3)

# Assert
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=False)
mock_get_idx_to_drop_nan_foreign_key_table.assert_called_once_with(
data, metadata.relationships, 'parent'
)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
assert len(data['parent']) == 3


@patch('sdv.multi_table.utils._get_idx_to_drop_nan_foreign_key_table')
def test__subsample_table_and_descendants_nan_fk(mock_get_idx_to_drop_nan_foreign_key_table):
    """Test ``_subsample_table_and_descendants`` when there are too many NaN foreign keys."""
    # Setup
    metadata = Mock()
    metadata.relationships = Mock()
    data = {'parent': [1, 2, 3, 4, 5, 6]}
    # Five of the six rows are reported as having a NaN foreign key, so keeping
    # three rows is not possible.
    mock_get_idx_to_drop_nan_foreign_key_table.return_value = {0, 1, 2, 3, 4}
    error_pattern = re.escape(
        "Referential integrity cannot be reached for table 'parent' while keeping "
        '3 rows. Please try again with a bigger number of rows.'
    )

    # Run
    with pytest.raises(SamplingError, match=error_pattern):
        _subsample_table_and_descendants(data, metadata, 'parent', 3)

    # Assert
    mock_get_idx_to_drop_nan_foreign_key_table.assert_called_once_with(
        data, metadata.relationships, 'parent'
    )


def test__get_primary_keys_referenced():
"""Test the ``_get_primary_keys_referenced`` method."""
data = {
Expand Down Expand Up @@ -1821,7 +1890,9 @@ def test__subsample_ancestors_schema_diamond_shape():
@patch('sdv.multi_table.utils._subsample_table_and_descendants')
@patch('sdv.multi_table.utils._subsample_ancestors')
@patch('sdv.multi_table.utils._get_primary_keys_referenced')
@patch('sdv.multi_table.utils._drop_rows')
def test__subsample_data(
mock_drop_rows,
mock_get_primary_keys_referenced,
mock_subsample_ancestors,
mock_subsample_table_and_descendants,
Expand Down Expand Up @@ -1852,6 +1923,7 @@ def test__subsample_data(
mock_subsample_ancestors.assert_called_once_with(
data, metadata, main_table, primary_key_reference
)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
assert result == data


Expand Down
7 changes: 4 additions & 3 deletions tests/unit/utils/test_poc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from collections import defaultdict
from unittest.mock import Mock, patch

import numpy as np
Expand Down Expand Up @@ -136,7 +137,7 @@ def test_drop_unknown_references_valid_data_mock(mock_stdout_write):
pd.testing.assert_frame_equal(table, data[table_name])


@patch('sdv.utils.poc._get_rows_to_drop')
@patch('sdv.multi_table.utils._get_rows_to_drop')
@patch('sdv.utils.poc._validate_foreign_keys_not_null')
def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_rows_to_drop):
"""Test ``drop_unknown_references`` whith NaNs and drop_missing_values True."""
Expand Down Expand Up @@ -217,7 +218,7 @@ def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_r
pd.testing.assert_frame_equal(table, expected_result[table_name])


@patch('sdv.utils.poc._get_rows_to_drop')
@patch('sdv.multi_table.utils._get_rows_to_drop')
def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop):
"""Test ``drop_unknown_references`` with NaNs and drop_missing_values False."""
# Setup
Expand Down Expand Up @@ -293,7 +294,7 @@ def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop
pd.testing.assert_frame_equal(table, expected_result[table_name])


@patch('sdv.utils.poc._get_rows_to_drop')
@patch('sdv.multi_table.utils._get_rows_to_drop')
def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop):
"""Test ``drop_unknown_references`` when all rows are dropped."""
# Setup
Expand Down

0 comments on commit 142e807

Please sign in to comment.