From b93e71de01680ce151d1f14395299e926dd734b7 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 26 Apr 2024 11:38:54 +0100 Subject: [PATCH] add warning null foreign keys --- sdv/multi_table/utils.py | 14 ++++++++++++-- tests/integration/utils/test_poc.py | 8 +++++++- tests/unit/multi_table/test_utils.py | 11 +++++++++-- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index 994f108c6..8a35bf526 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -1,13 +1,14 @@ """Utility functions for the MultiTable models.""" import math +import warnings from collections import defaultdict from copy import deepcopy import numpy as np import pandas as pd -from sdv._utils import _get_root_tables -from sdv.errors import InvalidDataError, SamplingError +from sdv._utils import _get_root_tables, _validate_foreign_keys_not_null +from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError from sdv.multi_table import HMASynthesizer from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS @@ -632,6 +633,15 @@ def _subsample_data(data, metadata, main_table_name, num_rows): result = deepcopy(data) primary_keys_referenced = _get_primary_keys_referenced(result, metadata) ratio_to_keep = num_rows / len(result[main_table_name]) + try: + _validate_foreign_keys_not_null(metadata, result) + except SynthesizerInputError: + warnings.warn( + 'The data contains null values in foreign key columns. ' + 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc' + ' to drop these rows before using ``get_random_subset``.' + ) + try: _subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep) _subsample_table_and_descendants(result, metadata, main_table_name, num_rows) diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 91164ff5d..7aec94243 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -335,9 +335,15 @@ def test_get_random_subset_with_missing_values(metadata, data): # Setup data = deepcopy(data) data['child'].loc[4, 'parent_id'] = np.nan + expected_warning = re.escape( + 'The data contains null values in foreign key columns. ' + 'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc' + ' to drop these rows before using ``get_random_subset``.' + ) # Run - cleaned_data = get_random_subset(data, metadata, 'child', 3) + with pytest.warns(UserWarning, match=expected_warning): + cleaned_data = get_random_subset(data, metadata, 'child', 3) # Assert assert len(cleaned_data['child']) == 3 diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index 1557b61d2..9437a2c31 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -1891,7 +1891,9 @@ def test__subsample_ancestors_schema_diamond_shape(): @patch('sdv.multi_table.utils._subsample_ancestors') @patch('sdv.multi_table.utils._get_primary_keys_referenced') @patch('sdv.multi_table.utils._drop_rows') +@patch('sdv.multi_table.utils._validate_foreign_keys_not_null') def test__subsample_data( + mock_validate_foreign_keys_not_null, mock_drop_rows, mock_get_primary_keys_referenced, mock_subsample_ancestors, @@ -1915,6 +1917,7 @@ def test__subsample_data( result = _subsample_data(data, metadata, main_table, num_rows) # Assert + mock_validate_foreign_keys_not_null.assert_called_once_with(metadata, data) mock_get_primary_keys_referenced.assert_called_once_with(data, metadata) mock_subsample_disconnected_roots.assert_called_once_with(data, metadata, main_table, 0.5) mock_subsample_table_and_descendants.assert_called_once_with( @@ -1929,8 +1932,12 @@ def test__subsample_data( @patch('sdv.multi_table.utils._subsample_disconnected_roots') @patch('sdv.multi_table.utils._get_primary_keys_referenced') -def test__subsample_data_empty_dataset(mock_get_primary_keys_referenced, - mock_subsample_disconnected_roots): +@patch('sdv.multi_table.utils._validate_foreign_keys_not_null') +def test__subsample_data_empty_dataset( + mock_validate_foreign_keys_not_null, + mock_get_primary_keys_referenced, + mock_subsample_disconnected_roots +): """Test the ``subsample_data`` method when a dataset is empty.""" # Setup data = {