Commit d8c8101

make release-tag: Merge branch 'main' into stable

amontanez24 committed Nov 7, 2023
2 parents 808fc94 + 55668dc
Showing 42 changed files with 1,435 additions and 430 deletions.
31 changes: 30 additions & 1 deletion HISTORY.md
@@ -1,8 +1,37 @@
# Release Notes

## 1.6.0 - 2023-11-07

This release improves user messaging in multiple ways. Most notably, users will now see an alert if the `HMASynthesizer` is likely to be slow for their data's schema. Additionally, the logger messaging for constraints and the error messaging when setting distributions on non-parametric models were made more detailed.

The visualization plots in the `sdv.evaluation` sub-package all got a new parameter called `plot_type`, allowing users to specify the plot type to use if the inferred one is not useful. The `sdv.datasets.local.load_csvs` method now has a parameter called `read_csv_parameters` that allows users to specify how the CSVs should be read during loading. The same change was also made to the `sdv.metadata.multi_table.detect_table_from_csv`, `sdv.metadata.multi_table.detect_from_csvs` and `sdv.metadata.single_table.detect_from_csv` methods.

Multiple bugs were resolved including one that caused new categories to be created during the sample step of `CTGANSynthesizer`.
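
For illustration, here is a minimal sketch of the new `plot_type` override described above. The table and column names are hypothetical, and `real_data`, `synthetic_data` and `metadata` are assumed to already exist:

```python
from sdv.evaluation.multi_table import get_column_plot

# Force a bar plot instead of letting the plot type be inferred from the data.
# 'users' and 'age' are hypothetical names; real_data and synthetic_data are
# dicts of DataFrames keyed by table name, metadata is a MultiTableMetadata.
fig = get_column_plot(
    real_data,
    synthetic_data,
    metadata,
    table_name='users',
    column_name='age',
    plot_type='bar',
)
fig.show()
```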

### New Features

* Improve debug messages when a constraint falls back to reject sampling approach - Issue [#1478](https://github.com/sdv-dev/SDV/issues/1478) by @amontanez24
* Constraints should work with timezone-aware datetime columns - Issue [#1576](https://github.com/sdv-dev/SDV/issues/1576) by @fealho
* Better error message when trying to get distributions from non-parametric models - PR [#1633](https://github.com/sdv-dev/SDV/pull/1633) by @frances-h
* Add options to read CSV files - Issue [#1644](https://github.com/sdv-dev/SDV/issues/1644) by @lajohn4747
* Print alert if HMASynthesizer is likely to be slow - Issue [#1646](https://github.com/sdv-dev/SDV/issues/1646) by @lajohn4747
* Make SDV compatible with SDMetrics 0.12.1 - Issue [#1650](https://github.com/sdv-dev/SDV/issues/1650) by @pvk-developer
* Make function to estimate number of columns CTGAN produces - Issue [#1657](https://github.com/sdv-dev/SDV/issues/1657) by @fealho

### Bugs Fixed

* In get_available_demos, the num_tables column should be an int - Issue [#1420](https://github.com/sdv-dev/SDV/issues/1420) by @lajohn4747
* AttributeError when using specific locale strings (es_AR, fr_BE) - Issue [#1439](https://github.com/sdv-dev/SDV/issues/1439) by @lajohn4747
* Confusing error when passing in an empty dataframe (with constraints) - Issue [#1455](https://github.com/sdv-dev/SDV/issues/1455) by @lajohn4747
* HMASynthesizer: Better error message for learned distributions (misleading fit error) - Issue [#1579](https://github.com/sdv-dev/SDV/issues/1579) by @fealho
* Fix tests in SDV after update in RDT 1.7.1 - Issue [#1638](https://github.com/sdv-dev/SDV/issues/1638) by @lajohn4747
* CTGAN sometimes creates new categories (int data) - Issue [#1647](https://github.com/sdv-dev/SDV/issues/1647) by @pvk-developer
* CTGAN sometimes creates new categories (object data) - Issue [#1648](https://github.com/sdv-dev/SDV/issues/1648) by @pvk-developer
* Better error message if I provide an incompatible sdtype/locale combo - Issue [#1653](https://github.com/sdv-dev/SDV/issues/1653) by @pvk-developer

## 1.5.0 - 2023-10-13

Several improvements and bug fixes were made in this release. Most notably, the metadata detection was substantially improved. Support for the 'unknown' sdtype was added, providing more flexibility in data representation. The software now attempts to intelligently detect primary keys and identify parent-child relationships in the metadata, streamlining the metadata creation process.

Additionally, issues related to conditional sampling with negative float values, the inability to update transformers for columns created by constraints, and compatibility with numpy version 1.25 and higher were addressed. The default branch was also switched from 'master' to 'main' for better development practices. Various bugs and errors, including those involving HMA and datetime format detection, were also resolved.

2 changes: 1 addition & 1 deletion sdv/__init__.py
@@ -6,7 +6,7 @@

__author__ = 'DataCebo, Inc.'
__email__ = 'info@sdv.dev'
__version__ = '1.5.0'
__version__ = '1.6.0.dev2'


import sys
8 changes: 8 additions & 0 deletions sdv/constraints/errors.py
@@ -1,5 +1,11 @@
"""Constraint Exceptions."""

import logging

from sdv.errors import log_exc_stacktrace

LOGGER = logging.getLogger(__name__)


class MissingConstraintColumnError(Exception):
"""Error used when constraint is provided a table with missing columns."""
@@ -13,6 +19,8 @@ class AggregateConstraintsError(Exception):

def __init__(self, errors):
self.errors = errors
for error in self.errors:
log_exc_stacktrace(LOGGER, error)

def __str__(self):
return '\n' + '\n\n'.join(map(str, self.errors))
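
With this change, the stack trace of every wrapped error is logged at DEBUG level as soon as the aggregate error is constructed. A minimal sketch of observing that behavior, assuming a default logging setup:

```python
import logging

from sdv.constraints.errors import AggregateConstraintsError

# Surface the DEBUG-level stack traces that AggregateConstraintsError now emits.
logging.basicConfig(level=logging.DEBUG)

try:
    raise ValueError('constraint column is out of range')
except ValueError as error:
    caught = error

# Constructing the aggregate error logs each wrapped error's stack trace;
# str() still produces the combined user-facing message.
aggregate = AggregateConstraintsError(errors=[caught])
print(aggregate)
```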
14 changes: 7 additions & 7 deletions sdv/constraints/tabular.py
@@ -184,8 +184,8 @@ def transform(self, data):
except InvalidFunctionError as e:
raise e

except Exception:
raise FunctionError
except Exception as e:
raise FunctionError(str(e))

def reverse_transform(self, data):
"""Reverse transform the table data.
@@ -455,7 +455,7 @@ def is_valid(self, table_data):
low = cast_to_datetime64(low)
high = cast_to_datetime64(high)

valid = np.isnan(low) | np.isnan(high) | self._operator(high, low)
valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)
return valid

def _transform(self, table_data):
@@ -660,7 +660,7 @@ def is_valid(self, table_data):
if self._is_datetime and self._dtype == 'O':
column = cast_to_datetime64(column)

valid = np.isnan(column) | self._operator(column, self._value)
valid = pd.isna(column) | self._operator(column, self._value)
return valid

def _transform(self, table_data):
@@ -1124,16 +1124,16 @@ def is_valid(self, table_data):

satisfy_low_bound = np.logical_or(
self._operator(self._low_value, data),
np.isnan(self._low_value),
pd.isna(self._low_value),
)
satisfy_high_bound = np.logical_or(
self._operator(data, self._high_value),
np.isnan(self._high_value),
pd.isna(self._high_value),
)

return np.logical_or(
np.logical_and(satisfy_low_bound, satisfy_high_bound),
np.isnan(data),
pd.isna(data),
)

def _transform(self, table_data):
2 changes: 1 addition & 1 deletion sdv/constraints/utils.py
@@ -191,7 +191,7 @@ def get_datetime_diff(high, low, dtype='O'):
high = cast_to_datetime64(high)

diff_column = high - low
nan_mask = np.isnan(diff_column)
nan_mask = pd.isna(diff_column)
diff_column = diff_column.astype(np.float64)
diff_column[nan_mask] = np.nan
return diff_column
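
The recurring `np.isnan` → `pd.isna` swap in these files matters because `np.isnan` is only defined for numeric dtypes and raises a `TypeError` on `datetime64` or object inputs, while `pd.isna` also recognizes `NaT` and `None`. A minimal sketch of the difference:

```python
import numpy as np
import pandas as pd

dates = np.array(['2023-11-07', 'NaT'], dtype='datetime64[ns]')

# pd.isna understands datetime64 arrays and flags NaT as missing.
print(pd.isna(dates))  # [False  True]

# np.isnan has no loop for datetime64 and fails instead.
try:
    np.isnan(dates)
except TypeError as error:
    print(f'np.isnan failed: {error}')
```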
40 changes: 30 additions & 10 deletions sdv/data_processing/data_processor.py
@@ -10,6 +10,7 @@
import rdt
from pandas.api.types import is_float_dtype, is_integer_dtype
from rdt.transformers import AnonymizedFaker, IDGenerator, RegexGenerator, get_default_transformers
from rdt.transformers.pii.anonymization import get_anonymized_transformer

from sdv.constraints import Constraint
from sdv.constraints.base import get_subclasses
@@ -19,8 +20,7 @@
from sdv.data_processing.errors import InvalidConstraintsError, NotFittedError
from sdv.data_processing.numerical_formatter import NumericalFormatter
from sdv.data_processing.utils import load_module_from_path
from sdv.errors import SynthesizerInputError
from sdv.metadata.anonymization import get_anonymized_transformer
from sdv.errors import SynthesizerInputError, log_exc_stacktrace
from sdv.metadata.single_table import SingleTableMetadata

LOGGER = logging.getLogger(__name__)
@@ -316,15 +316,23 @@ def _transform_constraints(self, data, is_condition=False):
except (MissingConstraintColumnError, FunctionError) as error:
if isinstance(error, MissingConstraintColumnError):
LOGGER.info(
f'{constraint.__class__.__name__} cannot be transformed because columns: '
f'{error.missing_columns} were not found. Using the reject sampling '
'approach instead.'
'Unable to transform %s with columns %s because they are not all available'
' in the data. This happens due to multiple, overlapping constraints.',
constraint.__class__.__name__,
error.missing_columns
)
log_exc_stacktrace(LOGGER, error)
else:
# Error came from custom constraint. We don't want to crash but we do
# want to log it.
LOGGER.info(
f'Error transforming {constraint.__class__.__name__}. '
'Using the reject sampling approach instead.'
'Unable to transform %s with columns %s due to an error in transform: \n'
'%s\nUsing the reject sampling approach instead.',
constraint.__class__.__name__,
constraint.column_names,
str(error)
)
log_exc_stacktrace(LOGGER, error)
if is_condition:
indices_to_drop = data.columns.isin(constraint.constraint_columns)
columns_to_drop = data.columns.where(indices_to_drop).dropna()
@@ -371,7 +379,15 @@ def create_anonymized_transformer(sdtype, column_metadata, enforce_uniqueness, l
if enforce_uniqueness:
kwargs['enforce_uniqueness'] = True

return get_anonymized_transformer(sdtype, kwargs)
try:
transformer = get_anonymized_transformer(sdtype, kwargs)
except AttributeError as error:
raise SynthesizerInputError(
f"The sdtype '{sdtype}' is not compatible with any of the locales. To "
"continue, try changing the locales or adding 'en_US' as a possible option."
) from error

return transformer

def create_regex_generator(self, column_name, sdtype, column_metadata, is_numeric):
"""Create a ``RegexGenerator`` for the ``id`` columns.
@@ -554,8 +570,7 @@ def _fit_hyper_transformer(self, data):
Returns:
rdt.HyperTransformer
"""
if not data.empty:
self._hyper_transformer.fit(data)
self._hyper_transformer.fit(data)

def _fit_formatters(self, data):
"""Fit ``NumericalFormatter`` and ``DatetimeFormatter`` for each column in the data."""
@@ -631,9 +646,14 @@ def fit(self, data):
data (pandas.DataFrame):
Table to be analyzed.
"""
if data.empty:
raise ValueError('The fit dataframe is empty, synthesizer will not be fitted.')
self._prepared_for_fitting = False
self.prepare_for_fitting(data)
constrained = self._transform_constraints(data)
if constrained.empty:
raise ValueError(
'The constrained fit dataframe is empty, synthesizer will not be fitted.')
LOGGER.info(f'Fitting HyperTransformer for table {self.table_name}')
self._fit_hyper_transformer(constrained)
self.fitted = True
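
Two of the guards above share a pattern worth noting: the user-facing `SynthesizerInputError` is raised with `raise ... from error`, so the original `AttributeError` traceback stays attached. A self-contained sketch of the same pattern with stand-in names (not SDV's actual internals):

```python
class SynthesizerInputError(ValueError):
    """Stand-in for sdv.errors.SynthesizerInputError."""


def create_anonymized_transformer(sdtype, locales):
    """Hypothetical wrapper mirroring the guard added in data_processor.py."""
    try:
        # Stand-in for rdt's transformer lookup, which raises AttributeError
        # when no faker provider supports the sdtype/locale combination.
        raise AttributeError(f'no faker method for {sdtype!r} in {locales!r}')
    except AttributeError as error:
        # 'from error' chains the low-level AttributeError onto the new
        # user-facing error, so both tracebacks appear when it is unhandled.
        raise SynthesizerInputError(
            f"The sdtype '{sdtype}' is not compatible with any of the locales."
        ) from error


try:
    create_anonymized_transformer('ssn', ['fr_BE'])
except SynthesizerInputError as error:
    print(error)
    print('caused by:', repr(error.__cause__))
```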
4 changes: 3 additions & 1 deletion sdv/datasets/demo.py
@@ -173,4 +173,6 @@ def get_available_demos(modality):
tables_info['size_MB'].append(round(float(size_mb), 2))
tables_info['num_tables'].append(headers.get('num-tables', np.nan))

return pd.DataFrame(tables_info)
df = pd.DataFrame(tables_info)
df['num_tables'] = pd.to_numeric(df['num_tables'])
return df
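
The `get_available_demos` fix relies on `pd.to_numeric`, which converts a string/object column to the best-fitting numeric dtype, so `num_tables` comes back as a number rather than text. A minimal sketch with hypothetical values:

```python
import pandas as pd

# Values parsed from dataset headers arrive as strings.
df = pd.DataFrame({'num_tables': ['1', '13', '2']})
print(df['num_tables'].dtype)  # object

# pd.to_numeric picks an integer dtype when every value parses cleanly
# (a float dtype if NaNs are present, as with missing 'num-tables' headers).
df['num_tables'] = pd.to_numeric(df['num_tables'])
print(df['num_tables'].dtype)  # int64
```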
7 changes: 5 additions & 2 deletions sdv/datasets/local.py
@@ -6,12 +6,15 @@
from sdv.utils import load_data_from_csv


def load_csvs(folder_name):
def load_csvs(folder_name, read_csv_parameters=None):
"""Load csv files from specified folder.
Args:
folder_name (str):
The full path of the folder with the data to be loaded.
read_csv_parameters (dict):
A python dictionary of parameter names and values accepted by the
``pandas.read_csv`` function. Defaults to ``None``.
"""
if not path.exists(folder_name):
raise ValueError(f"The folder '{folder_name}' cannot be found.")
@@ -23,7 +26,7 @@ def load_csvs(folder_name):
base_name, ext = path.splitext(filename)
if ext == '.csv':
filepath = path.join(dirpath, filename)
csvs[base_name] = load_data_from_csv(filepath)
csvs[base_name] = load_data_from_csv(filepath, read_csv_parameters)
else:
other_files.append(filename)

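
A short usage sketch of the new parameter, assuming a hypothetical `data/` folder of semicolon-delimited CSV files; per the docstring, the dictionary is forwarded to `pandas.read_csv` for every file:

```python
from sdv.datasets.local import load_csvs

# Hypothetical folder containing e.g. users.csv and sessions.csv.
csvs = load_csvs(
    'data/',
    read_csv_parameters={'sep': ';', 'encoding': 'utf-8'},
)
print(list(csvs))  # table names keyed by file base name
```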
22 changes: 22 additions & 0 deletions sdv/errors.py
@@ -1,5 +1,23 @@
"""SDV Exceptions."""

import logging
import traceback

LOGGER = logging.getLogger(__name__)


def log_exc_stacktrace(logger, error):
"""Log the stack trace of an exception.
Args:
logger (logging.Logger):
A logger object to use for the logging.
error (Exception):
The error to log.
"""
message = ''.join(traceback.format_exception(type(error), error, error.__traceback__))
logger.debug(message)


class NotFittedError(Exception):
"""Error to raise when sample is called and the model is not fitted."""
@@ -36,3 +54,7 @@ def __str__(self):
'The provided data does not match the metadata:\n' +
'\n\n'.join(map(str, self.errors))
)


class VisualizationUnavailableError(Exception):
"""Exception to indicate that a visualization is unavailable."""
59 changes: 46 additions & 13 deletions sdv/evaluation/multi_table.py
@@ -1,8 +1,10 @@
"""Methods to compare the real and synthetic data for multi-table."""
import sdmetrics.reports.utils as report
from sdmetrics import visualization
from sdmetrics.reports.multi_table.diagnostic_report import DiagnosticReport
from sdmetrics.reports.multi_table.quality_report import QualityReport

import sdv.evaluation.single_table as single_table_visualization


def evaluate_quality(real_data, synthetic_data, metadata, verbose=True):
"""Evaluate the quality of the synthetic data.
@@ -50,7 +52,7 @@ def run_diagnostic(real_data, synthetic_data, metadata, verbose=True):
return diagnostic_report


def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name):
def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name, plot_type=None):
"""Get a plot of the real and synthetic data for a given column.
Args:
Expand All @@ -64,18 +66,29 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name
The name of the table.
column_name (str):
The name of the column.
plot_type (str or None):
The plot type to use to plot the column. Must be either 'bar' or 'distplot'. If
``None``, selects between 'bar' and 'distplot' depending on the data.
Defaults to ``None``.
Returns:
plotly.graph_objects._figure.Figure:
1D marginal distribution plot (i.e. a histogram) of the columns.
"""
metadata = metadata.to_dict()['tables'][table_name]
metadata = metadata.tables[table_name]
real_data = real_data[table_name]
synthetic_data = synthetic_data[table_name]
return report.get_column_plot(real_data, synthetic_data, column_name, metadata)
return single_table_visualization.get_column_plot(
real_data,
synthetic_data,
metadata,
column_name,
plot_type,
)


def get_column_pair_plot(real_data, synthetic_data, metadata, table_name, column_names):
def get_column_pair_plot(real_data, synthetic_data, metadata,
table_name, column_names, plot_type=None):
"""Get a plot of the real and synthetic data for a given column pair.
Args:
@@ -89,19 +102,30 @@ def get_column_pair_plot(real_data, synthetic_data, metadata, table_name, column
The name of the table.
column_names (list[string]):
The names of the two columns to plot.
plot_type (str or None):
The plot to be used. Can choose between ``box``, ``heatmap``, ``scatter`` or ``None``.
If ``None``, selects between ``box``, ``heatmap`` or ``scatter`` depending on the data
that the columns contain: ``scatter`` is used for datetime and numerical values,
``heatmap`` for categorical and ``box`` for a mix of both. Defaults to ``None``.
Returns:
plotly.graph_objects._figure.Figure:
2D bivariate distribution plot (i.e. a scatterplot) of the columns.
"""
metadata = metadata.to_dict()['tables'][table_name]
metadata = metadata.tables[table_name]
real_data = real_data[table_name]
synthetic_data = synthetic_data[table_name]
return report.get_column_pair_plot(real_data, synthetic_data, column_names, metadata)
return single_table_visualization.get_column_pair_plot(
real_data,
synthetic_data,
metadata,
column_names,
plot_type
)


def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_table_name,
child_foreign_key, metadata):
child_foreign_key, metadata, plot_type='bar'):
"""Get a plot of the cardinality of the parent-child relationship.
Args:
@@ -116,12 +140,21 @@ def get_cardinality_plot(real_data, synthetic_data, child_table_name, parent_tab
child_foreign_key (string):
The name of the foreign key column in the child table.
metadata (MultiTableMetadata):
Metadata describing the data
Metadata describing the data.
plot_type (str):
The plot type to use to plot the cardinality. Must be either 'bar' or 'distplot'.
Defaults to 'bar'.
Returns:
plotly.graph_objects._figure.Figure
"""
metadata = metadata.to_dict()
return report.get_cardinality_plot(
real_data, synthetic_data, child_table_name, parent_table_name,
child_foreign_key, metadata)
parent_primary_key = metadata.tables[parent_table_name].primary_key
return visualization.get_cardinality_plot(
real_data,
synthetic_data,
child_table_name,
parent_table_name,
child_foreign_key,
parent_primary_key,
plot_type
)
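
Finally, a sketch of the reworked cardinality plot, which now resolves the parent primary key from the metadata itself and accepts a `plot_type`. Table and key names are hypothetical, and `real_data`, `synthetic_data` and `metadata` are assumed to already exist:

```python
from sdv.evaluation.multi_table import get_cardinality_plot

# 'sessions', 'users' and 'user_id' are hypothetical; the parent primary key
# is looked up internally via metadata.tables[parent_table_name].primary_key.
fig = get_cardinality_plot(
    real_data,
    synthetic_data,
    child_table_name='sessions',
    parent_table_name='users',
    child_foreign_key='user_id',
    metadata=metadata,
    plot_type='distplot',  # override the default 'bar'
)
fig.show()
```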
