Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_random_subset poc utility function #1928

Merged
merged 14 commits into from
Apr 30, 2024
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ dependencies = [
"numpy>=1.21.0;python_version<'3.10'",
"numpy>=1.23.3,<2;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0,<2;python_version>='3.12'",
"pandas>=1.4.0;python_version<'3.10'",
"pandas>=1.4.0;python_version>='3.10' and python_version<'3.11'",
"pandas>=1.4.0;python_version<'3.11'",
"pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
"pandas>=2.1.1;python_version>='3.12'",
'tqdm>=4.29',
Expand Down
7 changes: 0 additions & 7 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,6 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self._synthesizer_id
)

def _get_root_parents(self):
"""Get the set of root parents in the graph."""
non_root_tables = set(self.metadata._get_parent_map().keys())
root_parents = set(self.metadata.tables.keys()) - non_root_tables

return root_parents

def set_address_columns(self, table_name, column_names, anonymization_level='full'):
"""Set the address multi-column transformer.

Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def get_learned_distributions(self, table_name):
Dictionary containing the distributions used or detected for each column and the
learned parameters for those.
"""
if table_name not in self._get_root_parents():
if table_name not in _get_root_tables(self.metadata.relationships):
raise SynthesizerInputError(
f"Learned distributions are not available for the '{table_name}' table. "
'Please choose a table that does not have any parents.'
Expand Down
285 changes: 283 additions & 2 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
"""Utility functions for the MultiTable models."""
import math
import warnings
from collections import defaultdict
from copy import deepcopy

import numpy as np
import pandas as pd

from sdv._utils import _get_root_tables
from sdv._utils import _get_root_tables, _validate_foreign_keys_not_null
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
from sdv.multi_table import HMASynthesizer
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS

MODELABLE_SDTYPE = ['categorical', 'numerical', 'datetime', 'boolean']


def _get_child_tables(relationships):
gsheni marked this conversation as resolved.
Show resolved Hide resolved
parent_tables = {rel['parent_table_name'] for rel in relationships}
child_tables = {rel['child_table_name'] for rel in relationships}
return child_tables - parent_tables


def _get_relationships_for_child(relationships, child_table):
return [rel for rel in relationships if rel['child_table_name'] == child_table]

Expand Down Expand Up @@ -79,6 +87,34 @@ def _get_all_descendant_per_root_at_order_n(relationships, order):
return all_descendants


def _get_ancestors(relationships, child_table):
    """Get the ancestors of the child table."""
    found = set()
    for relationship in _get_relationships_for_child(relationships, child_table):
        parent = relationship['parent_table_name']
        found.add(parent)
        # Recurse upwards: the parent's ancestors are also ours.
        found |= _get_ancestors(relationships, parent)

    return found


def _get_disconnected_roots_from_table(relationships, table):
    """Get the disconnected roots table from the given table."""
    root_tables = _get_root_tables(relationships)
    leaf_tables = _get_child_tables(relationships)

    # A leaf table is connected exactly to its own ancestors.
    if table in leaf_tables:
        return root_tables - _get_ancestors(relationships, table)

    # Otherwise, a root is connected if some leaf descends from both
    # that root and the given table.
    reachable_roots = set()
    for leaf in leaf_tables:
        leaf_ancestors = _get_ancestors(relationships, leaf)
        if table in leaf_ancestors:
            reachable_roots |= root_tables & leaf_ancestors

    return root_tables - reachable_roots


def _simplify_relationships_and_tables(metadata, tables_to_drop):
"""Simplify the relationships and tables of the metadata.

Expand Down Expand Up @@ -339,7 +375,7 @@ def _print_simplified_schema_summary(data_before, data_after):
print('\n'.join(message)) # noqa: T001


def _get_rows_to_drop(metadata, data):
def _get_rows_to_drop(data, metadata):
"""Get the rows to drop to ensure referential integrity.

The logic of this function is to start at the root tables, look at invalid references
Expand Down Expand Up @@ -392,3 +428,248 @@ def _get_rows_to_drop(metadata, data):
relationships = [rel for rel in relationships if rel not in relationships_parent]

return table_to_idx_to_drop


def _get_nan_fk_indices_table(data, relationships, table):
    """Get the indexes of the rows to drop that have NaN foreign keys."""
    nan_indices = set()
    table_data = data[table]
    for relationship in _get_relationships_for_child(relationships, table):
        foreign_key = relationship['child_foreign_key']
        nan_mask = table_data[foreign_key].isna()
        nan_indices.update(table_data[nan_mask].index)

    return nan_indices


def _drop_rows(data, metadata, drop_missing_values):
    """Drop rows from every table to restore referential integrity.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame). Modified in place.
        metadata (MultiTableMetadata):
            Metadata of the datasets.
        drop_missing_values (bool):
            Whether to also drop rows whose foreign keys are NaN.

    Raises:
        InvalidDataError:
            If dropping rows leaves a table empty.
    """
    table_to_idx_to_drop = _get_rows_to_drop(data, metadata)
    for table in sorted(metadata.tables):
        idx_to_drop = table_to_idx_to_drop[table]
        data[table] = data[table].drop(idx_to_drop)
        if drop_missing_values:
            idx_with_nan_fk = _get_nan_fk_indices_table(
                data, metadata.relationships, table
            )
            data[table] = data[table].drop(idx_with_nan_fk)

        if data[table].empty:
            # Bug fix: the implicit string concatenation was missing a space,
            # producing "...must be dropped.Try providing...". The
            # 'All references in table' prefix is preserved because
            # _subsample_data matches on it.
            raise InvalidDataError([
                f"All references in table '{table}' are unknown and must be dropped. "
                'Try providing different data for this table.'
            ])


def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep):
    """Subsample the disconnected roots tables and their descendants."""
    disconnected_roots = _get_disconnected_roots_from_table(
        metadata.relationships, table
    )
    for root_table in disconnected_roots:
        data[root_table] = data[root_table].sample(frac=ratio_to_keep)

    _drop_rows(data, metadata, drop_missing_values=True)


def _subsample_table_and_descendants(data, metadata, table, num_rows):
    """Subsample the table and its descendants.

    The logic is to first subsample all the NaN foreign keys of the table.
    We raise an error if we cannot reach referential integrity while keeping the number of rows.
    Then, we drop rows of the descendants to ensure referential integrity.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame).
        metadata (MultiTableMetadata):
            Metadata of the datasets.
        table (str):
            Name of the table.
        num_rows (int):
            Number of rows to keep in the table.

    Raises:
        SamplingError:
            If dropping the NaN-foreign-key rows alone would leave fewer than
            ``num_rows`` rows in the table.
    """
    idx_nan_fk = _get_nan_fk_indices_table(data, metadata.relationships, table)
    num_rows_to_drop = len(data[table]) - num_rows
    if len(idx_nan_fk) > num_rows_to_drop:
        raise SamplingError(
            f"Referential integrity cannot be reached for table '{table}' while keeping "
            f'{num_rows} rows. Please try again with a bigger number of rows.'
        )

    # Guard clause above makes the original `else` unnecessary.
    data[table] = data[table].drop(idx_nan_fk)
    data[table] = data[table].sample(num_rows)
    _drop_rows(data, metadata, drop_missing_values=True)


def _get_primary_keys_referenced(data, metadata):
"""Get the primary keys referenced by the relationships.

Args:
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
Metadata of the datasets.

Returns:
dict:
Dictionary that maps the table name to a set of their primary keys referenced.
"""
relationships = metadata.relationships
primary_keys_referenced = defaultdict(set)
for relationship in relationships:
parent_table = relationship['parent_table_name']
child_table = relationship['child_table_name']
foreign_key = relationship['child_foreign_key']
primary_keys_referenced[parent_table].update(set(data[child_table][foreign_key].unique()))

return primary_keys_referenced


def _subsample_parent(parent_table, parent_primary_key, parent_pk_referenced_before,
dereferenced_pk_parent):
"""Subsample the parent table.

The strategy here is to:
- Drop the rows that are no longer referenced by the descendants.
- Drop a proportional amount of never-referenced rows.

Args:
parent_table (pandas.DataFrame):
Parent table to subsample.
parent_primary_key (str):
Name of the primary key of the parent table.
parent_pk_referenced_before (set):
Set of the primary keys referenced before any subsampling.
dereferenced_pk_parent (set):
Set of the primary keys that are no longer referenced by the descendants.

Returns:
pandas.DataFrame:
Subsampled parent table.
"""
total_referenced = len(parent_pk_referenced_before)
total_dropped = len(dereferenced_pk_parent)
drop_proportion = total_dropped / total_referenced

parent_table = parent_table[~parent_table[parent_primary_key].isin(dereferenced_pk_parent)]
unreferenced_data = parent_table[
~parent_table[parent_primary_key].isin(parent_pk_referenced_before)
]

# Randomly drop a proportional amount of never-referenced rows
unreferenced_data_to_drop = unreferenced_data.sample(frac=drop_proportion)
parent_table = parent_table.drop(unreferenced_data_to_drop.index)
if parent_table.empty:
raise InvalidDataError([
f"All references in table '{parent_primary_key}' are unknown and must be dropped."
'Try providing different data for this table.'
])

return parent_table


def _subsample_ancestors(data, metadata, table, primary_keys_referenced):
    """Subsample the ancestors of the table.

    The strategy here is to recursively subsample the direct parents of the table until the
    root tables are reached.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame).
        metadata (MultiTableMetadata):
            Metadata of the datasets.
        table (str):
            Name of the table.
        primary_keys_referenced (dict):
            Dictionary that maps the table name to a set of their primary keys referenced
            before any subsampling.
    """
    relationships = metadata.relationships
    currently_referenced = _get_primary_keys_referenced(data, metadata)
    parent_names = {
        relationship['parent_table_name']
        for relationship in _get_relationships_for_child(relationships, table)
    }
    # Deterministic order so results are reproducible across runs.
    for parent in sorted(parent_names):
        primary_key = metadata.tables[parent].primary_key
        referenced_before = primary_keys_referenced[parent]
        dereferenced = referenced_before - currently_referenced[parent]
        data[parent] = _subsample_parent(
            data[parent], primary_key, referenced_before,
            dereferenced
        )
        if dereferenced:
            primary_keys_referenced[parent] = currently_referenced[parent]

        _subsample_ancestors(data, metadata, parent, primary_keys_referenced)


def _subsample_data(data, metadata, main_table_name, num_rows):
    """Subsample multi-table table based on a table and a number of rows.

    The strategy is to:
    - Subsample the disconnected roots tables by keeping a similar proportion of data
      than the main table. Ensure referential integrity.
    - Subsample the main table and its descendants to ensure referential integrity.
    - Subsample the ancestors of the main table by removing primary key rows that are no longer
      referenced by the descendants and some unreferenced rows.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame).
        metadata (MultiTableMetadata):
            Metadata of the datasets.
        main_table_name (str):
            Name of the main table.
        num_rows (int):
            Number of rows to keep in the main table.

    Returns:
        dict:
            Dictionary with the subsampled dataframes.

    Raises:
        SamplingError:
            If subsampling leads to one or more empty tables.
    """
    result = deepcopy(data)
    primary_keys_referenced = _get_primary_keys_referenced(result, metadata)
    ratio_to_keep = num_rows / len(result[main_table_name])
    try:
        _validate_foreign_keys_not_null(metadata, result)
    except SynthesizerInputError:
        warnings.warn(
            'The data contains null values in foreign key columns. '
            'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc'
            ' to drop these rows before using ``get_random_subset``.'
        )

    try:
        _subsample_disconnected_roots(result, metadata, main_table_name, ratio_to_keep)
        _subsample_table_and_descendants(result, metadata, main_table_name, num_rows)
        _subsample_ancestors(result, metadata, main_table_name, primary_keys_referenced)
        _drop_rows(result, metadata, drop_missing_values=True)  # Drop remaining NaN foreign keys
    except InvalidDataError as error:
        if 'All references in table' not in str(error.args[0]):
            raise  # Unexpected data error: propagate unchanged with full traceback.

        # Chain explicitly so the original InvalidDataError stays visible.
        raise SamplingError(
            f'Subsampling {main_table_name} with {num_rows} rows leads to some empty tables. '
            'Please try again with a bigger number of rows.'
        ) from error

    return result


def _print_subsample_summary(data_before, data_after):
"""Print the summary of the subsampled data."""
tables = sorted(data_before.keys())
summary = pd.DataFrame({
'Table Name': tables,
'# Rows (Before)': [len(data_before[table]) for table in tables],
'# Rows (After)': [
len(data_after[table]) if table in data_after else 0 for table in tables
]
})
subsample_rows = 100 * (1 - summary['# Rows (After)'].sum() / summary['# Rows (Before)'].sum())
message = [f'Success! Your subset has {round(subsample_rows)}% less rows than the original.\n']
message.append(summary.to_string(index=False))
print('\n'.join(message)) # noqa: T001