Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NDArraysRegressionFixture: regression on arrays with arbitrary shape. #72

Merged
merged 4 commits into from Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/pytest_regressions/dataframe_regression.py
Expand Up @@ -42,7 +42,7 @@ def _check_data_types(self, key, obtained_column, expected_column):
try:
import numpy as np
except ModuleNotFoundError:
raise ModuleNotFoundError(import_error_message("Numpy"))
raise ModuleNotFoundError(import_error_message("NumPy"))

__tracebackhide__ = True
obtained_data_type = obtained_column.values.dtype
Expand Down Expand Up @@ -89,7 +89,7 @@ def _check_fn(self, obtained_filename, expected_filename):
try:
import numpy as np
except ModuleNotFoundError:
raise ModuleNotFoundError(import_error_message("Numpy"))
raise ModuleNotFoundError(import_error_message("NumPy"))
try:
import pandas as pd
except ModuleNotFoundError:
Expand Down Expand Up @@ -123,8 +123,7 @@ def _check_fn(self, obtained_filename, expected_filename):
self._check_data_types(k, obtained_column, expected_column)
self._check_data_shapes(obtained_column, expected_column)

data_type = obtained_column.values.dtype
if data_type in [float, np.float16, np.float32, np.float64]:
if np.issubdtype(obtained_column.values.dtype, np.inexact):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please keep in mind that this will change behavior a bit, because complex numbers are also np.inexact:

>>> a = np.array([], dtype=np.complex128)
>>> np.issubdtype(a.dtype, np.inexact)
True

Perhaps add a test for this case (ignore me if you already did, by the time I'm writing this, I didn't finished the review yet), just to make sure it doesn't crash or anything...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct. I wanted to support complex numbers and there is a unit test for it.

not_close_mask = ~np.isclose(
obtained_column.values,
expected_column.values,
Expand All @@ -138,7 +137,7 @@ def _check_fn(self, obtained_filename, expected_filename):
diff_ids = np.where(not_close_mask)[0]
diff_obtained_data = obtained_column[diff_ids]
diff_expected_data = expected_column[diff_ids]
if data_type == bool:
if obtained_column.values.dtype == bool:
diffs = np.logical_xor(obtained_column, expected_column)[diff_ids]
else:
diffs = np.abs(obtained_column - expected_column)[diff_ids]
Expand Down Expand Up @@ -199,7 +198,7 @@ def check(
will ignore embed_data completely, being useful if a reference file is located
in the session data dir for example.

:param dict tolerances: dict mapping keys from the data_dict to tolerance settings for the
:param dict tolerances: dict mapping keys from the data_frame to tolerance settings for the
given data. Example::

tolerances={'U': Tolerance(atol=1e-2)}
Expand All @@ -223,7 +222,7 @@ def check(
__tracebackhide__ = True

assert type(data_frame) is pd.DataFrame, (
"Only pandas DataFrames are supported on on dataframe_regression fixture.\n"
"Only pandas DataFrames are supported on dataframe_regression fixture.\n"
"Object with type '%s' was given." % (str(type(data_frame)),)
)

Expand All @@ -235,7 +234,7 @@ def check(
# Rejected: timedelta, datetime, objects, zero-terminated bytes, unicode strings and raw data
assert array.dtype not in ["m", "M", "O", "S", "a", "U", "V"], (
"Only numeric data is supported on dataframe_regression fixture.\n"
"Array with type '%s' was given.\n" % (str(array.dtype),)
"Array with type '%s' was given." % (str(array.dtype),)
)

if tolerances is None:
Expand Down