forked from ESSS/pytest-regressions
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataframe_regression.py
262 lines (217 loc) · 9.91 KB
/
dataframe_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from pytest_regressions.common import perform_regression_check, import_error_message
class DataFrameRegressionFixture:
    """
    Pandas DataFrame Regression fixture implementation used on dataframe_regression fixture.
    """

    DISPLAY_PRECISION = 17  # Decimal places
    DISPLAY_WIDTH = 1000  # Max. Chars on outputs
    DISPLAY_MAX_COLUMNS = 1000  # Max. Number of columns (see #3)

    def __init__(self, datadir, original_datadir, request):
        """
        :type datadir: Path
        :type original_datadir: Path
        :type request: FixtureRequest
        """
        # Per-column and default keyword arguments forwarded to ``np.isclose``;
        # (re)assigned on every ``check`` call.
        self._tolerances_dict = {}
        self._default_tolerance = {}

        self.request = request
        self.datadir = datadir
        self.original_datadir = original_datadir
        self._force_regen = False
        self._with_test_class_names = False
        # Flat ``(option, value, option, value, ...)`` sequence, the form that
        # ``pd.option_context(*options)`` expects in ``check``.
        self._pandas_display_options = (
            "display.precision",
            DataFrameRegressionFixture.DISPLAY_PRECISION,
            "display.width",
            DataFrameRegressionFixture.DISPLAY_WIDTH,
            "display.max_columns",
            DataFrameRegressionFixture.DISPLAY_MAX_COLUMNS,
        )

    def _check_data_types(self, key, obtained_column, expected_column):
        """
        Check if data type of obtained and expected columns are the same. Fail if not.
        Two different numeric dtypes (e.g. int64 vs float64) are accepted as comparable.
        Helper method used in _check_fn method.

        :raises AssertionError: if the dtypes differ and are not both numeric.
        """
        try:
            import numpy as np
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("NumPy"))

        __tracebackhide__ = True
        obtained_data_type = obtained_column.values.dtype
        expected_data_type = expected_column.values.dtype
        if obtained_data_type != expected_data_type:
            # Check if both data types are comparable as numbers (float, int, short, bytes, etc...)
            if np.issubdtype(obtained_data_type, np.number) and np.issubdtype(
                expected_data_type, np.number
            ):
                return

            # In case they are not, assume they are not comparable
            error_msg = (
                "Data type for data %s of obtained and expected are not the same.\n"
                "Obtained: %s\n"
                "Expected: %s\n" % (key, obtained_data_type, expected_data_type)
            )
            raise AssertionError(error_msg)

    def _check_data_shapes(self, obtained_column, expected_column):
        """
        Check if obtained and expected columns have the same size.
        Helper method used in _check_fn method.

        :raises AssertionError: if the shapes of the underlying arrays differ.
        """
        __tracebackhide__ = True

        obtained_data_shape = obtained_column.values.shape
        expected_data_shape = expected_column.values.shape
        if obtained_data_shape != expected_data_shape:
            error_msg = (
                "Obtained and expected data shape are not the same.\n"
                "Obtained: %s\n"
                "Expected: %s\n" % (obtained_data_shape, expected_data_shape)
            )
            raise AssertionError(error_msg)

    def _check_fn(self, obtained_filename, expected_filename):
        """
        Check if the data frame dumped to ``obtained_filename`` matches the contents
        of ``expected_filename``.

        Both files are read back with ``pd.read_csv``. Floating point columns are
        compared with ``np.isclose`` (``equal_nan=True``) using the column's tolerance
        from ``self._tolerances_dict`` or ``self._default_tolerance``. All other
        columns are compared for exact equality, where a missing value (NaN) on both
        sides counts as equal.

        :param str obtained_filename:
        :param str expected_filename:
        :raises AssertionError: if an obtained column is missing from the expected
            data, if dtypes or shapes differ, or if any values differ.
        """
        try:
            import numpy as np
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("NumPy"))
        try:
            import pandas as pd
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("Pandas"))

        __tracebackhide__ = True

        obtained_data = pd.read_csv(str(obtained_filename))
        expected_data = pd.read_csv(str(expected_filename))

        comparison_tables_dict = {}
        for k in obtained_data.keys():
            obtained_column = obtained_data[k]
            expected_column = expected_data.get(k)

            if expected_column is None:
                # Use a distinct name so the outer loop variable ``k`` is not shadowed.
                obtained_keys = "".join(f"'{key}', " for key in obtained_data.keys())
                expected_keys = "".join(f"'{key}', " for key in expected_data.keys())
                error_msg = f"Could not find key '{k}' in the expected results.\n"
                error_msg += f"Keys in the obtained data table: [{obtained_keys}]\n"
                error_msg += f"Keys in the expected data table: [{expected_keys}]\n"
                error_msg += "To update values, use --force-regen option.\n\n"
                raise AssertionError(error_msg)

            tolerance_args = self._tolerances_dict.get(k, self._default_tolerance)

            self._check_data_types(k, obtained_column, expected_column)
            self._check_data_shapes(obtained_column, expected_column)

            data_type = obtained_column.values.dtype
            if np.issubdtype(data_type, np.inexact):
                not_close_mask = ~np.isclose(
                    obtained_column.values,
                    expected_column.values,
                    equal_nan=True,
                    **tolerance_args,
                )
            else:
                not_close_mask = obtained_column.values != expected_column.values
                # NaN != NaN, so empty cells (e.g. in a string column read back by
                # read_csv) would otherwise always be reported as differences.
                both_missing = np.logical_and(
                    pd.isna(obtained_column.values), pd.isna(expected_column.values)
                )
                not_close_mask[both_missing] = False

            if np.any(not_close_mask):
                # ``diff_ids`` are positions, so index positionally.
                diff_ids = np.where(not_close_mask)[0]
                diff_obtained_data = obtained_column.iloc[diff_ids]
                diff_expected_data = expected_column.iloc[diff_ids]
                if data_type == bool:
                    diffs = np.logical_xor(obtained_column, expected_column).iloc[
                        diff_ids
                    ]
                elif np.issubdtype(data_type, np.number):
                    diffs = np.abs(obtained_column - expected_column).iloc[diff_ids]
                else:
                    # Non-numeric (e.g. string) values cannot be subtracted; there
                    # is no meaningful magnitude of difference to report.
                    diffs = pd.Series("?", index=diff_obtained_data.index)

                comparison_table = pd.concat(
                    [diff_obtained_data, diff_expected_data, diffs], axis=1
                )
                comparison_table.columns = [f"obtained_{k}", f"expected_{k}", "diff"]
                comparison_tables_dict[k] = comparison_table

        if len(comparison_tables_dict) > 0:
            error_msg = "Values are not sufficiently close.\n"
            error_msg += "To update values, use --force-regen option.\n\n"
            for k, comparison_table in comparison_tables_dict.items():
                error_msg += f"{k}:\n{comparison_table}\n\n"
            raise AssertionError(error_msg)

    def _dump_fn(self, data_object, filename):
        """
        Dump the data frame contents to the given filename as CSV (the index is
        written too, as ``DataFrame.to_csv`` does by default). Floats are written
        with ``DISPLAY_PRECISION`` significant digits.

        :param pd.DataFrame data_object:
        :param str filename:
        """
        data_object.to_csv(
            str(filename),
            float_format=f"%.{DataFrameRegressionFixture.DISPLAY_PRECISION}g",
        )

    def check(
        self,
        data_frame,
        basename=None,
        fullpath=None,
        tolerances=None,
        default_tolerance=None,
    ):
        """
        Checks a pandas dataframe, containing only numeric (or string) data, against a
        previously recorded version, or generate a new file.

        Example::

            data_frame = pandas.DataFrame.from_dict({
                'U_gas': U[0][positions],
                'U_liquid': U[1][positions],
                'gas_vol_frac [-]': vol_frac[0][positions],
                'liquid_vol_frac [-]': vol_frac[1][positions],
                'P': Pa_to_bar(P)[positions],
            })
            dataframe_regression.check(data_frame)

        :param pandas.DataFrame data_frame: pandas DataFrame containing data for regression check.

        :param str basename: basename of the file to test/record. If not given the name
            of the test is used.

        :param str fullpath: complete path to use as a reference file, instead of a
            file named after ``basename``/the test. Useful if a reference file is
            located in the session data dir for example.

        :param dict tolerances: dict mapping keys from the data_frame to tolerance settings for the
            given data (keyword arguments forwarded to ``numpy.isclose``). Example::

                tolerances={'U': dict(atol=1e-2)}

        :param dict default_tolerance: dict mapping the default tolerance for the current check
            call. Example::

                default_tolerance=dict(atol=1e-7, rtol=1e-18).

            If not provided, will use defaults from numpy's ``isclose`` function.

        ``basename`` and ``fullpath`` are exclusive.
        """
        try:
            import pandas as pd
        except ModuleNotFoundError:
            raise ModuleNotFoundError(import_error_message("Pandas"))

        import functools

        __tracebackhide__ = True

        assert isinstance(data_frame, pd.DataFrame), (
            "Only pandas DataFrames are supported on dataframe_regression fixture.\n"
            "Object with type '%s' was given." % (str(type(data_frame)),)
        )

        for column in data_frame.columns:
            array = data_frame[column]
            if array.dtype == "O":
                # Skip assertion if an array of strings. Inspect the first
                # non-missing value positionally so that any index (not only one
                # containing the label 0) and leading missing values are handled.
                present_values = array.dropna()
                if present_values.empty or type(present_values.iloc[0]) is str:
                    continue
            # Rejected: timedelta, datetime, objects, zero-terminated bytes, unicode strings and raw data
            assert array.dtype not in ["m", "M", "O", "S", "a", "U", "V"], (
                "Only numeric data is supported on dataframe_regression fixture.\n"
                "Array with type '%s' was given." % (str(array.dtype),)
            )

        if tolerances is None:
            tolerances = {}
        self._tolerances_dict = tolerances

        if default_tolerance is None:
            default_tolerance = {}
        self._default_tolerance = default_tolerance

        dump_fn = functools.partial(self._dump_fn, data_frame)

        with pd.option_context(*self._pandas_display_options):
            perform_regression_check(
                datadir=self.datadir,
                original_datadir=self.original_datadir,
                request=self.request,
                check_fn=self._check_fn,
                dump_fn=dump_fn,
                extension=".csv",
                basename=basename,
                fullpath=fullpath,
                force_regen=self._force_regen,
                with_test_class_names=self._with_test_class_names,
            )