Skip to content

Commit

Permalink
Fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed May 14, 2024
1 parent 0840c7f commit e3e3d9e
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _estimate_columns_traversal(cls, metadata, table_name,
columns_per_table[table_name] += \
cls._get_num_extended_columns(
metadata, child_name, table_name, columns_per_table, distributions
)
)

visited.add(table_name)

Expand Down
4 changes: 2 additions & 2 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def _get_n_order_descendants(relationships, parent_table, order):
descendants = {}
order_1_descendants = _get_relationships_for_parent(relationships, parent_table)
descendants['order_1'] = [rel['child_table_name'] for rel in order_1_descendants]
for i in range(2, order+1):
for i in range(2, order + 1):
descendants[f'order_{i}'] = []
prov_descendants = []
for child_table in descendants[f'order_{i-1}']:
for child_table in descendants[f'order_{i - 1}']:
order_i_descendants = _get_relationships_for_parent(relationships, child_table)
prov_descendants.extend([rel['child_table_name'] for rel in order_i_descendants])

Expand Down
2 changes: 1 addition & 1 deletion sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _transform_sequence_index(self, data):
fill_value = min(sequence_index_sequence[self._sequence_index].dropna())
sequence_index_sequence = sequence_index_sequence.fillna(fill_value)

data[self._sequence_index] = sequence_index_sequence[self._sequence_index]
data[self._sequence_index] = sequence_index_sequence[self._sequence_index].to_numpy()
data = data.merge(
sequence_index_context,
left_on=self._sequence_key,
Expand Down
57 changes: 57 additions & 0 deletions tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime

import numpy as np
import pandas as pd
from deepecho import load_demo

Expand Down Expand Up @@ -192,3 +193,59 @@ def test_sythesize_sequences(tmp_path):
synthesizer.validate(loaded_sample)
loaded_synthesizer.validate(synthetic_data)
loaded_synthesizer.validate(loaded_sample)


def test_par_subset_of_data():
"""Test it when the data index is not continuous GH#1973."""
# download data
data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019',)

# modify the data by choosing a subset of it
data_subset = data.copy()
np.random.seed(1234)
symbols = data['Symbol'].unique()

# only select a subset of data in each sequence
for i, symbol in enumerate(symbols):
symbol_mask = data_subset['Symbol'] == symbol
data_subset = data_subset.drop(
data_subset[symbol_mask].sample(frac=i / (2 * len(symbols))).index)

# now run PAR
synthesizer = PARSynthesizer(metadata, epochs=5, verbose=True)
synthesizer.fit(data_subset)
synthetic_data = synthesizer.sample(num_sequences=5)

# assert that the synthetic data doesn't contain NaN values in sequence index column
assert not pd.isnull(synthetic_data['Date']).any()


def test_par_subset_of_data_simplified():
"""Test it when the data index is not continuous for a simple dataset GH#1973."""
# Setup
data = pd.DataFrame({
'id': [1, 2, 3],
'date': ['2020-01-01', '2020-01-02', '2020-01-03'],
})
data.index = [0, 1, 5]
metadata = SingleTableMetadata.load_from_dict({
'sequence_index': 'date',
'sequence_key': 'id',
'columns': {
'id': {
'sdtype': 'id',
},
'date': {
'sdtype': 'datetime',
},
},
'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1'
})
synthesizer = PARSynthesizer(metadata, epochs=0)

# Run
synthesizer.fit(data)
synthetic_data = synthesizer.sample(num_sequences=50)

# Assert
assert not pd.isnull(synthetic_data['date']).any()

0 comments on commit e3e3d9e

Please sign in to comment.