Skip to content

Commit

Permalink
Merge pull request #72 from tovrstra/npz
Browse files Browse the repository at this point in the history
NDArraysRegressionFixture: regression on arrays with arbitrary shape.
  • Loading branch information
tarcisiofischer committed Sep 15, 2021
2 parents 4c3fd76 + f03cdce commit 983a6db
Show file tree
Hide file tree
Showing 21 changed files with 996 additions and 12 deletions.
15 changes: 7 additions & 8 deletions src/pytest_regressions/dataframe_regression.py
Expand Up @@ -42,7 +42,7 @@ def _check_data_types(self, key, obtained_column, expected_column):
try:
import numpy as np
except ModuleNotFoundError:
raise ModuleNotFoundError(import_error_message("Numpy"))
raise ModuleNotFoundError(import_error_message("NumPy"))

__tracebackhide__ = True
obtained_data_type = obtained_column.values.dtype
Expand Down Expand Up @@ -89,7 +89,7 @@ def _check_fn(self, obtained_filename, expected_filename):
try:
import numpy as np
except ModuleNotFoundError:
raise ModuleNotFoundError(import_error_message("Numpy"))
raise ModuleNotFoundError(import_error_message("NumPy"))
try:
import pandas as pd
except ModuleNotFoundError:
Expand Down Expand Up @@ -123,8 +123,7 @@ def _check_fn(self, obtained_filename, expected_filename):
self._check_data_types(k, obtained_column, expected_column)
self._check_data_shapes(obtained_column, expected_column)

data_type = obtained_column.values.dtype
if data_type in [float, np.float16, np.float32, np.float64]:
if np.issubdtype(obtained_column.values.dtype, np.inexact):
not_close_mask = ~np.isclose(
obtained_column.values,
expected_column.values,
Expand All @@ -138,7 +137,7 @@ def _check_fn(self, obtained_filename, expected_filename):
diff_ids = np.where(not_close_mask)[0]
diff_obtained_data = obtained_column[diff_ids]
diff_expected_data = expected_column[diff_ids]
if data_type == bool:
if obtained_column.values.dtype == bool:
diffs = np.logical_xor(obtained_column, expected_column)[diff_ids]
else:
diffs = np.abs(obtained_column - expected_column)[diff_ids]
Expand Down Expand Up @@ -199,7 +198,7 @@ def check(
will ignore embed_data completely, being useful if a reference file is located
in the session data dir for example.
:param dict tolerances: dict mapping keys from the data_dict to tolerance settings for the
:param dict tolerances: dict mapping keys from the data_frame to tolerance settings for the
given data. Example::
tolerances={'U': Tolerance(atol=1e-2)}
Expand All @@ -223,7 +222,7 @@ def check(
__tracebackhide__ = True

assert type(data_frame) is pd.DataFrame, (
"Only pandas DataFrames are supported on on dataframe_regression fixture.\n"
"Only pandas DataFrames are supported on dataframe_regression fixture.\n"
"Object with type '%s' was given." % (str(type(data_frame)),)
)

Expand All @@ -235,7 +234,7 @@ def check(
# Rejected: timedelta, datetime, objects, zero-terminated bytes, unicode strings and raw data
assert array.dtype not in ["m", "M", "O", "S", "a", "U", "V"], (
"Only numeric data is supported on dataframe_regression fixture.\n"
"Array with type '%s' was given.\n" % (str(array.dtype),)
"Array with type '%s' was given." % (str(array.dtype),)
)

if tolerances is None:
Expand Down

0 comments on commit 983a6db

Please sign in to comment.