Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include sdv_logger_config.yml with the package #1963

Merged: 6 commits were merged on Apr 29, 2024.
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
'botocore>=1.31',
'cloudpickle>=2.1.0',
'graphviz>=0.13.2',
"numpy>=1.20.0;python_version<'3.10'",
"numpy>=1.21.0;python_version<'3.10'",
"numpy>=1.23.3,<2;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0,<2;python_version>='3.12'",
"pandas>=1.1.3;python_version<'3.10'",
Expand Down Expand Up @@ -141,7 +141,8 @@ namespaces = false
'make.bat',
'*.jpg',
'*.png',
'*.gif'
'*.gif',
'sdv_logger_config.yml'
]

[tool.setuptools.exclude-package-data]
Expand Down
3 changes: 0 additions & 3 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def _initialize_models(self):
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata,
locales=self.locales,
table_name=table_name,
**synthesizer_parameters
)

Expand Down Expand Up @@ -200,8 +199,6 @@ def set_table_parameters(self, table_name, table_parameters):
A dictionary with the parameters as keys and the values to be used to instantiate
the table's synthesizer.
"""
# Ensure that we set the name of the table no matter what
table_parameters.update({'table_name': table_name})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=self.metadata.tables[table_name],
**table_parameters
Expand Down
12 changes: 3 additions & 9 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,9 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
row = pd.Series({'num_rows': len(child_rows)})
row.index = f'__{child_name}__{foreign_key}__' + row.index
else:
synthesizer_parameters = self._table_parameters[child_name]
synthesizer_parameters.update({'table_name': child_name})
synthesizer = self._synthesizer(
table_meta,
**synthesizer_parameters
**self._table_parameters[child_name]
)
synthesizer.fit_processed_data(child_rows.reset_index(drop=True))
row = synthesizer._get_parameters()
Expand Down Expand Up @@ -523,11 +521,9 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})

table_meta = self.metadata.tables[child_name]
synthesizer_parameters = self._table_parameters[child_name]
synthesizer_parameters.update({'table_name': child_name})
synthesizer = self._synthesizer(
table_meta,
**synthesizer_parameters
**self._table_parameters[child_name]
)
synthesizer._set_parameters(parameters, default_parameters)
synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor
Expand Down Expand Up @@ -622,11 +618,9 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
for parent_id, row in parent_rows.iterrows():
parameters = self._extract_parameters(row, table_name, foreign_key)
table_meta = self._table_synthesizers[table_name].get_metadata()
synthesizer_parameters = self._table_parameters[table_name]
synthesizer_parameters.update({'table_name': table_name})
synthesizer = self._synthesizer(
table_meta,
**synthesizer_parameters
**self._table_parameters[table_name]
)
synthesizer._set_parameters(parameters)
try:
Expand Down
2 changes: 1 addition & 1 deletion sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False
metadata=metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
locales=locales
locales=locales,
)

sequence_key = self.metadata.sequence_key
Expand Down
23 changes: 10 additions & 13 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,19 @@ def _check_metadata_updated(self):
)

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
locales=['en_US'], table_name=None):
locales=['en_US']):
self._validate_inputs(enforce_min_max_values, enforce_rounding)
self.metadata = metadata
self.metadata.validate()
self._check_metadata_updated()
self.enforce_min_max_values = enforce_min_max_values
self.enforce_rounding = enforce_rounding
self.locales = locales
self.table_name = table_name
self._data_processor = DataProcessor(
metadata=self.metadata,
enforce_rounding=self.enforce_rounding,
enforce_min_max_values=self.enforce_min_max_values,
locales=self.locales,
table_name=self.table_name
)
self._fitted = False
self._random_state_set = False
Expand Down Expand Up @@ -500,16 +498,15 @@ def load(cls, filepath):
if getattr(synthesizer, '_synthesizer_id', None) is None:
synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer)

if synthesizer.table_name is None:
SYNTHESIZER_LOGGER.info(
'\nLoad:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
synthesizer.__class__.__name__,
synthesizer._synthesizer_id,
)
SYNTHESIZER_LOGGER.info(
'\nLoad:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
synthesizer.__class__.__name__,
synthesizer._synthesizer_id,
)

return synthesizer

Expand Down
4 changes: 1 addition & 3 deletions sdv/single_table/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6,
discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500,
discriminator_steps=1, log_frequency=True, verbose=False, epochs=300,
pac=10, cuda=True, numerical_distributions=None, default_distribution=None,
table_name=None):
pac=10, cuda=True, numerical_distributions=None, default_distribution=None):

super().__init__(
metadata,
Expand All @@ -143,7 +142,6 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
epochs=epochs,
pac=pac,
cuda=cuda,
table_name=table_name
)

validate_numerical_distributions(numerical_distributions, self.metadata.columns)
Expand Down
4 changes: 1 addition & 3 deletions sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,12 @@ def get_distribution_class(cls, distribution):
return cls._DISTRIBUTIONS[distribution]

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
locales=['en_US'], numerical_distributions=None, default_distribution=None,
table_name=None):
locales=['en_US'], numerical_distributions=None, default_distribution=None):
super().__init__(
metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
locales=locales,
table_name=table_name
)
validate_numerical_distributions(numerical_distributions, self.metadata.columns)
self.numerical_distributions = numerical_distributions or {}
Expand Down
7 changes: 2 additions & 5 deletions sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,13 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6,
discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500,
discriminator_steps=1, log_frequency=True, verbose=False, epochs=300,
pac=10, cuda=True, table_name=None):
pac=10, cuda=True):

super().__init__(
metadata=metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
locales=locales,
table_name=table_name
)

self.embedding_dim = embedding_dim
Expand Down Expand Up @@ -339,14 +338,12 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer):

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128),
l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True,
table_name=None):
l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True):

super().__init__(
metadata=metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
table_name=table_name
)
self.embedding_dim = embedding_dim
self.compress_dims = compress_dims
Expand Down
9 changes: 3 additions & 6 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ def test_hma_set_table_parameters(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {},
'table_name': 'characters'
'numerical_distributions': {}
}
families_params = hmasynthesizer.get_table_parameters('families')
assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer'
Expand All @@ -161,8 +160,7 @@ def test_hma_set_table_parameters(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {},
'table_name': 'families'
'numerical_distributions': {}
}
char_families_params = hmasynthesizer.get_table_parameters('character_families')
assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer'
Expand All @@ -171,8 +169,7 @@ def test_hma_set_table_parameters(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {},
'table_name': 'character_families'
'numerical_distributions': {}
}

assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma'
Expand Down
15 changes: 4 additions & 11 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ def test__initialize_models(self):
}
instance._synthesizer.assert_has_calls([
call(metadata=instance.metadata.tables['nesreca'], default_distribution='gamma',
locales=locales, table_name='nesreca'),
call(metadata=instance.metadata.tables['oseba'], locales=locales, table_name='oseba'),
call(metadata=instance.metadata.tables['upravna_enota'], locales=locales,
table_name='upravna_enota')
locales=locales),
call(metadata=instance.metadata.tables['oseba'], locales=locales),
call(metadata=instance.metadata.tables['upravna_enota'], locales=locales)
])

def test__get_pbar_args(self):
Expand Down Expand Up @@ -280,7 +279,6 @@ def test_get_table_parameters_empty(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'table_name': 'oseba',
'numerical_distributions': {}
}
}
Expand All @@ -301,7 +299,6 @@ def test_get_table_parameters_has_parameters(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'table_name': 'oseba',
'numerical_distributions': {}
}

Expand Down Expand Up @@ -333,17 +330,13 @@ def test_set_table_parameters(self):

# Assert
table_parameters = instance.get_table_parameters('oseba')
assert instance._table_parameters['oseba'] == {
'default_distribution': 'gamma',
'table_name': 'oseba'
}
assert instance._table_parameters['oseba'] == {'default_distribution': 'gamma'}
assert table_parameters['synthesizer_name'] == 'GaussianCopulaSynthesizer'
assert table_parameters['synthesizer_parameters'] == {
'default_distribution': 'gamma',
'enforce_min_max_values': True,
'locales': ['en_US'],
'enforce_rounding': True,
'table_name': 'oseba',
'numerical_distributions': {}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def test__recreate_child_synthesizer(self):
# Assert
assert synthesizer == instance._synthesizer.return_value
assert synthesizer._data_processor == table_synthesizer._data_processor
instance._synthesizer.assert_called_once_with(table_meta, table_name='users', a=1)
instance._synthesizer.assert_called_once_with(table_meta, a=1)
synthesizer._set_parameters.assert_called_once_with(
instance._extract_parameters.return_value,
{'colA': 'default_param', 'colB': 'default_param'}
Expand Down
19 changes: 5 additions & 14 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def test___init__(self, mock_check_metadata_updated, mock_data_processor,
metadata=metadata,
enforce_rounding=instance.enforce_rounding,
enforce_min_max_values=instance.enforce_min_max_values,
locales=instance.locales,
table_name=None
locales=instance.locales
)
metadata.validate.assert_called_once_with()
mock_check_metadata_updated.assert_called_once()
Expand Down Expand Up @@ -124,8 +123,7 @@ def test___init__custom(self, mock_data_processor):
metadata=metadata,
enforce_rounding=instance.enforce_rounding,
enforce_min_max_values=instance.enforce_min_max_values,
locales=instance.locales,
table_name=None
locales=instance.locales
)
metadata.validate.assert_called_once_with()

Expand Down Expand Up @@ -184,8 +182,7 @@ def test_get_parameters(self, mock_data_processor):
assert parameters == {
'enforce_min_max_values': False,
'enforce_rounding': False,
'locales': 'en_CA',
'table_name': None
'locales': 'en_CA'
}

@patch('sdv.single_table.base.DataProcessor')
Expand Down Expand Up @@ -362,8 +359,7 @@ def test_fit_processed_data(self, mock_datetime, caplog):
instance = Mock(
_fitted_sdv_version=None,
_fitted_sdv_enterprise_version=None,
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
table_name=None
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
)
processed_data = pd.DataFrame({'column_a': [1, 2, 3]})

Expand Down Expand Up @@ -394,7 +390,6 @@ def test_fit_processed_data_raises_version_error(self):
instance = Mock(
_fitted_sdv_version='1.0.0',
_fitted_sdv_enterprise_version=None,
table_name=None
)
processed_data = pd.DataFrame({'column_a': [1, 2, 3]})
instance._random_state_set = True
Expand Down Expand Up @@ -422,7 +417,6 @@ def test_fit(self, mock_datetime, caplog):
_fitted_sdv_version=None,
_fitted_sdv_enterprise_version=None,
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
table_name=None
)
data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']})
instance._random_state_set = True
Expand Down Expand Up @@ -459,7 +453,6 @@ def test_fit_raises_version_error(self):
instance = Mock(
_fitted_sdv_version='1.0.0',
_fitted_sdv_enterprise_version=None,
table_name=None
)
data = pd.DataFrame({'column_a': [1, 2, 3]})
instance._random_state_set = True
Expand Down Expand Up @@ -1417,7 +1410,6 @@ def test_sample(self, mock_datetime, caplog):
output_file_path = 'temp.csv'
instance = Mock(
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
table_name=None
)
instance.get_metadata.return_value._constraints = False
instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]})
Expand Down Expand Up @@ -1810,7 +1802,6 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog):
# Setup
synthesizer = Mock(
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
table_name=None
)
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'

Expand Down Expand Up @@ -1839,7 +1830,7 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_war
mock_datetime, caplog):
"""Test that the ``load`` method loads a stored synthesizer."""
# Setup
synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None, table_name=None)
synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None)
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
mock_generate_synthesizer_id.return_value = synthesizer_id
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/single_table/test_copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def test_get_params(self):
'pac': 10,
'cuda': True,
'numerical_distributions': {},
'default_distribution': 'beta',
'table_name': None
'default_distribution': 'beta'
}

@patch('sdv.single_table.copulagan.rdt')
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/single_table/test_copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def test_get_parameters(self):
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {},
'default_distribution': 'beta',
'table_name': None
'default_distribution': 'beta'
}

@patch('sdv.single_table.copulas.LOGGER')
Expand Down