Skip to content

Commit

Permalink
Add warning for null foreign keys
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Apr 26, 2024
1 parent 142e807 commit b93e71d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
14 changes: 12 additions & 2 deletions sdv/multi_table/utils.py
@@ -1,13 +1,14 @@
"""Utility functions for the MultiTable models."""
import math
import warnings
from collections import defaultdict
from copy import deepcopy

import numpy as np
import pandas as pd

from sdv._utils import _get_root_tables
from sdv.errors import InvalidDataError, SamplingError
from sdv._utils import _get_root_tables, _validate_foreign_keys_not_null
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
from sdv.multi_table import HMASynthesizer
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS

Expand Down Expand Up @@ -632,6 +633,15 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
result = deepcopy(data)
primary_keys_referenced = _get_primary_keys_referenced(result, metadata)
ratio_to_keep = num_rows / len(result[main_table_name])
try:
_validate_foreign_keys_not_null(metadata, result)
except SynthesizerInputError:
warnings.warn(
'The data contains null values in foreign key columns. '
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc'
' to drop these rows before using ``get_random_subset``.'
)

try:
_subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep)
_subsample_table_and_descendants(result, metadata, main_table_name, num_rows)
Expand Down
8 changes: 7 additions & 1 deletion tests/integration/utils/test_poc.py
Expand Up @@ -335,9 +335,15 @@ def test_get_random_subset_with_missing_values(metadata, data):
# Setup
data = deepcopy(data)
data['child'].loc[4, 'parent_id'] = np.nan
expected_warning = re.escape(
'The data contains null values in foreign key columns. '
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc'
' to drop these rows before using ``get_random_subset``.'
)

# Run
cleaned_data = get_random_subset(data, metadata, 'child', 3)
with pytest.warns(UserWarning, match=expected_warning):
cleaned_data = get_random_subset(data, metadata, 'child', 3)

# Assert
assert len(cleaned_data['child']) == 3
Expand Down
11 changes: 9 additions & 2 deletions tests/unit/multi_table/test_utils.py
Expand Up @@ -1891,7 +1891,9 @@ def test__subsample_ancestors_schema_diamond_shape():
@patch('sdv.multi_table.utils._subsample_ancestors')
@patch('sdv.multi_table.utils._get_primary_keys_referenced')
@patch('sdv.multi_table.utils._drop_rows')
@patch('sdv.multi_table.utils._validate_foreign_keys_not_null')
def test__subsample_data(
mock_validate_foreign_keys_not_null,
mock_drop_rows,
mock_get_primary_keys_referenced,
mock_subsample_ancestors,
Expand All @@ -1915,6 +1917,7 @@ def test__subsample_data(
result = _subsample_data(data, metadata, main_table, num_rows)

# Assert
mock_validate_foreign_keys_not_null.assert_called_once_with(metadata, data)
mock_get_primary_keys_referenced.assert_called_once_with(data, metadata)
mock_subsample_disconnected_roots.assert_called_once_with(data, metadata, main_table, 0.5)
mock_subsample_table_and_descendants.assert_called_once_with(
Expand All @@ -1929,8 +1932,12 @@ def test__subsample_data(

@patch('sdv.multi_table.utils._subsample_disconnected_roots')
@patch('sdv.multi_table.utils._get_primary_keys_referenced')
def test__subsample_data_empty_dataset(mock_get_primary_keys_referenced,
mock_subsample_disconnected_roots):
@patch('sdv.multi_table.utils._validate_foreign_keys_not_null')
def test__subsample_data_empty_dataset(
mock_validate_foreign_keys_not_null,
mock_get_primary_keys_referenced,
mock_subsample_disconnected_roots
):
"""Test the ``subsample_data`` method when a dataset is empty."""
# Setup
data = {
Expand Down

0 comments on commit b93e71d

Please sign in to comment.