Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_random_subset poc utility function #1928

Merged
merged 14 commits into from
Apr 30, 2024
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ dependencies = [
"numpy>=1.21.0;python_version<'3.10'",
"numpy>=1.23.3,<2;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0,<2;python_version>='3.12'",
"pandas>=1.4.0;python_version<'3.10'",
"pandas>=1.4.0;python_version>='3.10' and python_version<'3.11'",
"pandas>=1.4.0;python_version<'3.11'",
"pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
"pandas>=2.1.1;python_version>='3.12'",
'tqdm>=4.29',
Expand Down
80 changes: 64 additions & 16 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Utility functions for the MultiTable models."""
import math
import warnings
from collections import defaultdict
from copy import deepcopy

import numpy as np
import pandas as pd

from sdv._utils import _get_root_tables
from sdv.errors import InvalidDataError, SamplingError
from sdv._utils import _get_root_tables, _validate_foreign_keys_not_null
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
from sdv.multi_table import HMASynthesizer
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS

Expand Down Expand Up @@ -429,16 +430,29 @@ def _get_rows_to_drop(data, metadata):
return table_to_idx_to_drop


def _get_idx_to_drop_nan_foreign_key_table(data, relationships, table):
    """Collect the indexes of rows in ``table`` that hold a NaN foreign key.

    Args:
        data (dict):
            Mapping of table name to ``pandas.DataFrame``.
        relationships (list[dict]):
            Relationships of the metadata.
        table (str):
            Name of the child table to inspect.

    Returns:
        set:
            Indexes of the rows with a NaN value in at least one foreign
            key column.
    """
    table_data = data[table]
    nan_indexes = set()
    for child_relationship in _get_relationships_for_child(relationships, table):
        foreign_key = child_relationship['child_foreign_key']
        nan_mask = table_data[foreign_key].isna()
        nan_indexes.update(table_data.index[nan_mask])

    return nan_indexes


def _drop_rows(data, metadata, drop_missing_values):
table_to_idx_to_drop = _get_rows_to_drop(data, metadata)
for table in sorted(metadata.tables):
idx_to_drop = table_to_idx_to_drop[table]
data[table] = data[table].drop(idx_to_drop)
if drop_missing_values:
relationships = _get_relationships_for_child(metadata.relationships, table)
for relationship in relationships:
child_column = relationship['child_foreign_key']
data[table] = data[table].dropna(subset=[child_column])
idx_with_nan_fk = _get_idx_to_drop_nan_foreign_key_table(
data, metadata.relationships, table
)
data[table] = data[table].drop(idx_with_nan_fk)

if data[table].empty:
raise InvalidDataError([
Expand All @@ -454,13 +468,37 @@ def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep):
for root in roots:
data[root] = data[root].sample(frac=ratio_to_keep)

_drop_rows(data, metadata, drop_missing_values=False)
_drop_rows(data, metadata, drop_missing_values=True)


def _subsample_table_and_descendants(data, metadata, table, num_rows):
    """Subsample the table and its descendants.

    The strategy is to first drop the rows of the main table that hold a NaN
    foreign key, since those rows can never satisfy referential integrity.
    If dropping them leaves fewer than ``num_rows`` rows, an error is raised.
    Afterwards, rows of the descendants are dropped to restore referential
    integrity with the subsampled table.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame).
        metadata (MultiTableMetadata):
            Metadata of the datasets.
        table (str):
            Name of the table.
        num_rows (int):
            Number of rows to keep in the main table.

    Raises:
        SamplingError:
            If there are more rows with NaN foreign keys than rows that can
            be dropped while still keeping ``num_rows`` rows.
    """
    idx_nan_fk = _get_idx_to_drop_nan_foreign_key_table(data, metadata.relationships, table)
    num_rows_to_drop = len(data[table]) - num_rows
    if len(idx_nan_fk) > num_rows_to_drop:
        raise SamplingError(
            f"Referential integrity cannot be reached for table '{table}' while keeping "
            f'{num_rows} rows. Please try again with a bigger number of rows.'
        )

    # Guard clause above removes the need for an else branch here.
    data[table] = data[table].drop(idx_nan_fk)
    data[table] = data[table].sample(num_rows)
    _drop_rows(data, metadata, drop_missing_values=True)


def _get_primary_keys_referenced(data, metadata):
Expand Down Expand Up @@ -489,7 +527,7 @@ def _get_primary_keys_referenced(data, metadata):


def _subsample_parent(parent_table, parent_primary_key, parent_pk_referenced_before,
unreferenced_pk_parent):
dereferenced_pk_parent):
"""Subsample the parent table.

The strategy here is to:
Expand All @@ -503,18 +541,18 @@ def _subsample_parent(parent_table, parent_primary_key, parent_pk_referenced_bef
Name of the primary key of the parent table.
parent_pk_referenced_before (set):
Set of the primary keys referenced before any subsampling.
unreferenced_pk_parent (set):
dereferenced_pk_parent (set):
Set of the primary keys that are no longer referenced by the descendants.

Returns:
pandas.DataFrame:
Subsampled parent table.
"""
total_referenced = len(parent_pk_referenced_before)
total_dropped = len(unreferenced_pk_parent)
total_dropped = len(dereferenced_pk_parent)
drop_proportion = total_dropped / total_referenced

parent_table = parent_table[~parent_table[parent_primary_key].isin(unreferenced_pk_parent)]
parent_table = parent_table[~parent_table[parent_primary_key].isin(dereferenced_pk_parent)]
unreferenced_data = parent_table[
~parent_table[parent_primary_key].isin(parent_pk_referenced_before)
]
Expand Down Expand Up @@ -556,12 +594,12 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced):
for parent in sorted(direct_parents):
parent_primary_key = metadata.tables[parent].primary_key
pk_referenced_before = primary_keys_referenced[parent]
unreferenced_primary_keys = pk_referenced_before - pk_referenced[parent]
dereferenced_primary_keys = pk_referenced_before - pk_referenced[parent]
data[parent] = _subsample_parent(
data[parent], parent_primary_key, pk_referenced_before,
unreferenced_primary_keys
dereferenced_primary_keys
)
if unreferenced_primary_keys:
if dereferenced_primary_keys:
primary_keys_referenced[parent] = pk_referenced[parent]

_subsample_ancestors(data, metadata, parent, primary_keys_referenced)
Expand Down Expand Up @@ -595,10 +633,20 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
result = deepcopy(data)
primary_keys_referenced = _get_primary_keys_referenced(result, metadata)
ratio_to_keep = num_rows / len(result[main_table_name])
try:
_validate_foreign_keys_not_null(metadata, result)
except SynthesizerInputError:
warnings.warn(
'The data contains null values in foreign key columns. '
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc'
' to drop these rows before using ``get_random_subset``.'
)

try:
_subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep)
_subsample_table_and_descendants(result, metadata, main_table_name, num_rows)
_subsample_ancestors(result, metadata, main_table_name, primary_keys_referenced)
_drop_rows(result, metadata, drop_missing_values=True) # Drop remaining NaN foreign keys
except InvalidDataError as error:
if 'All references in table' not in str(error.args[0]):
raise error
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,23 @@ def test_get_random_subset_disconnected_schema():
# Assert
assert len(result['Player']) == num_rows_to_keep
assert len(result['Team']) == int(len(real_data['Team']) * proportion_to_keep)


def test_get_random_subset_with_missing_values(metadata, data):
    """Test ``get_random_subset`` warns and cleans when foreign keys contain NaN."""
    # Setup
    data = deepcopy(data)
    data['child'].loc[4, 'parent_id'] = np.nan
    warning_message = (
        'The data contains null values in foreign key columns. '
        'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc'
        ' to drop these rows before using ``get_random_subset``.'
    )

    # Run
    with pytest.warns(UserWarning, match=re.escape(warning_message)):
        cleaned_data = get_random_subset(data, metadata, 'child', 3)

    # Assert
    assert len(cleaned_data['child']) == 3
    assert cleaned_data['child']['parent_id'].notna().all()
103 changes: 91 additions & 12 deletions tests/unit/multi_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from sdv.metadata import MultiTableMetadata
from sdv.multi_table.utils import (
_drop_rows, _get_all_descendant_per_root_at_order_n, _get_ancestors,
_get_columns_to_drop_child, _get_disconnected_roots_from_table, _get_n_order_descendants,
_get_num_column_to_drop, _get_primary_keys_referenced, _get_relationships_for_child,
_get_relationships_for_parent, _get_rows_to_drop, _get_total_estimated_columns,
_print_simplified_schema_summary, _print_subsample_summary, _simplify_child,
_simplify_children, _simplify_data, _simplify_grandchildren, _simplify_metadata,
_simplify_relationships_and_tables, _subsample_ancestors, _subsample_data,
_subsample_disconnected_roots, _subsample_parent, _subsample_table_and_descendants)
_get_columns_to_drop_child, _get_disconnected_roots_from_table,
_get_idx_to_drop_nan_foreign_key_table, _get_n_order_descendants, _get_num_column_to_drop,
_get_primary_keys_referenced, _get_relationships_for_child, _get_relationships_for_parent,
_get_rows_to_drop, _get_total_estimated_columns, _print_simplified_schema_summary,
_print_subsample_summary, _simplify_child, _simplify_children, _simplify_data,
_simplify_grandchildren, _simplify_metadata, _simplify_relationships_and_tables,
_subsample_ancestors, _subsample_data, _subsample_disconnected_roots, _subsample_parent,
_subsample_table_and_descendants)


def test__get_relationships_for_child():
Expand Down Expand Up @@ -133,6 +134,44 @@ def test__get_rows_to_drop():
assert result == expected_result


def test__get_idx_to_drop_nan_foreign_key_table():
    """Test the ``_get_idx_to_drop_nan_foreign_key_table`` method."""
    # Setup
    def _relationship(parent, child, parent_pk, child_fk):
        return {
            'parent_table_name': parent,
            'child_table_name': child,
            'parent_primary_key': parent_pk,
            'child_foreign_key': child_fk,
        }

    relationships = [
        _relationship('parent', 'child', 'id_parent', 'parent_foreign_key'),
        _relationship('child', 'grandchild', 'id_child', 'child_foreign_key'),
        _relationship('parent', 'grandchild', 'id_parent', 'parent_foreign_key'),
    ]
    grandchild = pd.DataFrame({
        'parent_foreign_key': [np.nan, 1, 2, 2, np.nan],
        'child_foreign_key': [9, np.nan, 11, 6, 4],
        'C': ['Yes', 'No', 'No', 'No', 'No'],
    })
    data = {'grandchild': grandchild}

    # Run
    result = _get_idx_to_drop_nan_foreign_key_table(data, relationships, 'grandchild')

    # Assert
    assert result == {0, 1, 4}


@patch('sdv.multi_table.utils._get_rows_to_drop')
def test__drop_rows(mock_get_rows_to_drop):
"""Test the ``_drop_rows`` method."""
Expand Down Expand Up @@ -1301,7 +1340,7 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo
mock_get_disconnected_roots_from_table.assert_called_once_with(
metadata.relationships, 'disconnected_root'
)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=False)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
for table_name in metadata.tables:
if table_name not in {'grandparent', 'other_root'}:
pd.testing.assert_frame_equal(data[table_name], expected_result[table_name])
Expand All @@ -1310,7 +1349,9 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo


@patch('sdv.multi_table.utils._drop_rows')
def test__subsample_table_and_descendants(mock_drop_rows):
@patch('sdv.multi_table.utils._get_idx_to_drop_nan_foreign_key_table')
def test__subsample_table_and_descendants(mock_get_idx_to_drop_nan_foreign_key_table,
mock_drop_rows):
"""Test the ``_subsample_table_and_descendants`` method."""
# Setup
data = {
Expand All @@ -1331,16 +1372,44 @@ def test__subsample_table_and_descendants(mock_drop_rows):
'col_8': [6, 7, 8, 9, 10],
}),
}
mock_get_idx_to_drop_nan_foreign_key_table.return_value = {0}
metadata = Mock()
metadata.relationships = Mock()

# Run
_subsample_table_and_descendants(data, metadata, 'parent', 3)

# Assert
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=False)
mock_get_idx_to_drop_nan_foreign_key_table.assert_called_once_with(
data, metadata.relationships, 'parent'
)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
assert len(data['parent']) == 3


@patch('sdv.multi_table.utils._get_idx_to_drop_nan_foreign_key_table')
def test__subsample_table_and_descendants_nan_fk(mock_get_nan_fk_idx):
    """Test ``_subsample_table_and_descendants`` when there are too many NaN foreign keys."""
    # Setup
    data = {'parent': [1, 2, 3, 4, 5, 6]}
    mock_get_nan_fk_idx.return_value = {0, 1, 2, 3, 4}
    metadata = Mock()
    metadata.relationships = Mock()
    error_message = (
        "Referential integrity cannot be reached for table 'parent' while keeping "
        '3 rows. Please try again with a bigger number of rows.'
    )

    # Run
    with pytest.raises(SamplingError, match=re.escape(error_message)):
        _subsample_table_and_descendants(data, metadata, 'parent', 3)

    # Assert
    mock_get_nan_fk_idx.assert_called_once_with(data, metadata.relationships, 'parent')


def test__get_primary_keys_referenced():
"""Test the ``_get_primary_keys_referenced`` method."""
data = {
Expand Down Expand Up @@ -1821,7 +1890,11 @@ def test__subsample_ancestors_schema_diamond_shape():
@patch('sdv.multi_table.utils._subsample_table_and_descendants')
@patch('sdv.multi_table.utils._subsample_ancestors')
@patch('sdv.multi_table.utils._get_primary_keys_referenced')
@patch('sdv.multi_table.utils._drop_rows')
@patch('sdv.multi_table.utils._validate_foreign_keys_not_null')
def test__subsample_data(
mock_validate_foreign_keys_not_null,
mock_drop_rows,
mock_get_primary_keys_referenced,
mock_subsample_ancestors,
mock_subsample_table_and_descendants,
Expand All @@ -1844,6 +1917,7 @@ def test__subsample_data(
result = _subsample_data(data, metadata, main_table, num_rows)

# Assert
mock_validate_foreign_keys_not_null.assert_called_once_with(metadata, data)
mock_get_primary_keys_referenced.assert_called_once_with(data, metadata)
mock_subsample_disconnected_roots.assert_called_once_with(data, metadata, main_table, 0.5)
mock_subsample_table_and_descendants.assert_called_once_with(
Expand All @@ -1852,13 +1926,18 @@ def test__subsample_data(
mock_subsample_ancestors.assert_called_once_with(
data, metadata, main_table, primary_key_reference
)
mock_drop_rows.assert_called_once_with(data, metadata, drop_missing_values=True)
assert result == data


@patch('sdv.multi_table.utils._subsample_disconnected_roots')
@patch('sdv.multi_table.utils._get_primary_keys_referenced')
def test__subsample_data_empty_dataset(mock_get_primary_keys_referenced,
mock_subsample_disconnected_roots):
@patch('sdv.multi_table.utils._validate_foreign_keys_not_null')
def test__subsample_data_empty_dataset(
mock_validate_foreign_keys_not_null,
mock_get_primary_keys_referenced,
mock_subsample_disconnected_roots
):
"""Test the ``subsample_data`` method when a dataset is empty."""
# Setup
data = {
Expand Down