diff --git a/devito/finite_differences/elementary.py b/devito/finite_differences/elementary.py index 41f7a20bda..da0a10cfec 100644 --- a/devito/finite_differences/elementary.py +++ b/devito/finite_differences/elementary.py @@ -91,11 +91,19 @@ def root(x): class Min(sympy.Min, Evaluable): - pass + + def _evaluate(self, **kwargs): + args = self._evaluate_args(**kwargs) + assert len(args) == 2 + return self.func(args[0], args[1], evaluate=False) class Max(sympy.Max, Evaluable): - pass + + def _evaluate(self, **kwargs): + args = self._evaluate_args(**kwargs) + assert len(args) == 2 + return self.func(args[0], args[1], evaluate=False) def Id(x): diff --git a/devito/mpi/routines.py b/devito/mpi/routines.py index 51f418619a..b1da09e96e 100644 --- a/devito/mpi/routines.py +++ b/devito/mpi/routines.py @@ -58,6 +58,13 @@ def msgs(self): def regions(self): return [i for i in self._regions.values() if i is not None] + @property + def headers(self): + """ + No headers needed by default + """ + return {} + def make(self, hs): """ Construct Callables and Calls implementing distributed-memory halo @@ -510,6 +517,14 @@ class OverlapHaloExchangeBuilder(DiagHaloExchangeBuilder): remainder() """ + @property + def headers(self): + """ + Overlap Mode uses MIN/MAX that need to be defined + """ + return {'headers': [('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')), + ('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))'))]} + def _make_msg(self, f, hse, key): # Only retain the halos required by the Diag scheme halos = sorted(i for i in hse.halos if isinstance(i.dim, tuple)) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 047acd2c02..c20c13a42c 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -3,7 +3,7 @@ from devito.ir import (Any, Forward, List, Prodder, FindNodes, Transformer, filter_iterations, retrieve_iteration_tree) from devito.passes.iet.engine import iet_pass -from devito.symbolics import MIN, MAX, evalrel +from devito.symbolics import evalrel from devito.tools import split __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions'] @@ -81,7 +81,7 @@ def relax_incr_dimensions(iet, options=None, **kwargs): to: - + """ mapper = {} @@ -107,15 +107,15 @@ def relax_incr_dimensions(iet, options=None, **kwargs): if i.is_Inbound: continue - # The Iteration's maximum is the MIN of (a) the `symbolic_max` of current + # The Iteration's maximum is the Min of (a) the `symbolic_max` of current # Iteration e.g. `x0_blk0 + x0_blk0_size - 1` and (b) the `symbolic_max` # of the current Iteration's root Dimension e.g. `x_M`. The generated - # maximum will be `MIN(x0_blk0 + x0_blk0_size - 1, x_M) + # maximum will be `Min(x0_blk0 + x0_blk0_size - 1, x_M) # In some corner cases an offset may be added (e.g. after CIRE passes) # E.g. assume `i.symbolic_max = x0_blk0 + x0_blk0_size + 1` and # `i.dim.symbolic_max = x0_blk0 + x0_blk0_size - 1` then the generated - # maximum will be `MIN(x0_blk0 + x0_blk0_size + 1, x_M + 2)` + # maximum will be `Min(x0_blk0 + x0_blk0_size + 1, x_M + 2)` root_max = roots_max[i.dim.root] + i.symbolic_max - i.dim.symbolic_max iter_max = evalrel(min, [i.symbolic_max, root_max]) @@ -124,8 +124,8 @@ def relax_incr_dimensions(iet, options=None, **kwargs): if mapper: iet = Transformer(mapper, nested=True).visit(iet) - headers = [('%s(a,b)' % MIN.name, ('(((a) < (b)) ? (a) : (b))')), - ('%s(a,b)' % MAX.name, ('(((a) > (b)) ? (a) : (b))'))] + headers = [('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')), + ('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))'))] else: headers = [] diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index b5d9d2c93d..b3db9dc2f5 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -298,6 +298,7 @@ def make_mpi(iet, mpimode=None, **kwargs): efuncs = sync_heb.efuncs + user_heb.efuncs iet = Transformer(mapper, nested=True).visit(iet) + headers = user_heb.headers # Must drop the PARALLEL tag from the Iterations within which halo # exchanges are performed @@ -312,8 +313,9 @@ def make_mpi(iet, mpimode=None, **kwargs): for n in tree[:tree.index(i)+1]}) break iet = Transformer(mapper, nested=True).visit(iet) + headers.update({'includes': ['mpi.h'], 'efuncs': efuncs}) - return iet, {'includes': ['mpi.h'], 'efuncs': efuncs} + return iet, headers def mpiize(graph, **kwargs): diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 6afe1d0460..a3cbb09e46 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -9,14 +9,14 @@ from devito.tools import (Pickable, as_tuple, is_integer, float2, float3, float4, # noqa double2, double3, double4, int2, int3, int4) +from devito.finite_differences.elementary import Min, Max from devito.types import Symbol __all__ = ['CondEq', 'CondNe', 'IntDiv', 'CallFromPointer', 'FieldFromPointer', # noqa 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'InlineIf', 'Keyword', 'String', 'Macro', 'MacroArgument', - 'CustomType', 'Deref', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'CEIL', - 'FLOOR', 'MAX', 'MIN', 'Null', 'SizeOf', 'rfunc', 'cast_mapper', - 'BasicWrapperMixin'] + 'CustomType', 'Deref', 'INT', 'FLOAT', 'DOUBLE', 'VOID', + 'Null', 'SizeOf', 'rfunc', 'cast_mapper', 'BasicWrapperMixin'] class CondEq(sympy.Eq): @@ -672,12 +672,6 @@ class CHARP(CastStar): # Some other utility objects - -CEIL = Function('ceil') -FLOOR = Function('floor') -MAX = Function('MAX') -MIN = Function('MIN') - Null = Macro('NULL') # DefFunction, unlike sympy.Function, generates e.g. `sizeof(float)`, not `sizeof(float_)` @@ -691,10 +685,10 @@ def rfunc(func, item, *args): Examples ---------- >> rfunc(min, [a, b, c, d]) - MIN(a, MIN(b, MIN(c, d))) + Min(a, Min(b, Min(c, d))) >> rfunc(max, [a, b, c, d]) - MAX(a, MAX(b, MAX(c, d))) + Max(a, Max(b, Max(c, d))) """ assert func in rfunc_mapper @@ -703,10 +697,10 @@ def rfunc(func, item, *args): if len(args) == 0: return item else: - return rf(item, rfunc(func, *args)) + return rf(item, rfunc(func, *args), evaluate=False) rfunc_mapper = { - min: MIN, - max: MAX, + min: Min, + max: Max, } diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 99f04d5198..694f85b69e 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -283,8 +283,8 @@ def reuse_if_untouched(expr, args, evaluate=False): def evalrel(func=min, input=None, assumptions=None): """ The purpose of this function is two-fold: (i) to reduce the `input` candidates of a - for a MIN/MAX expression based on the given `assumptions` and (ii) return the nested - MIN/MAX expression of the reduced-size input. + for a Min/Max expression based on the given `assumptions` and (ii) return the nested + Min/Max expression of the reduced-size input. Parameters ---------- @@ -307,7 +307,7 @@ def evalrel(func=min, input=None, assumptions=None): >>> c = Symbol('c') >>> d = Symbol('d') >>> evalrel(max, [a, b, c, d], [Le(d, a), Ge(c, b)]) - MAX(a, c) + Max(a, c) """ sfunc = (Min if func is min else Max) # Choose SymPy's Min/Max @@ -353,7 +353,7 @@ def evalrel(func=min, input=None, assumptions=None): mapper = transitive_closure(mapper) input = [i.subs(mapper) for i in input] - # Explore simplification opportunities that may have emerged and generate MIN/MAX + # Explore simplification opportunities that may have emerged and generate Min/Max # expression try: exp = sfunc(*input) # Can it be evaluated or simplified? diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 31b0f3db01..9080c26242 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -101,20 +101,23 @@ def _print_Mod(self, expr): args = ['(%s)' % self._print(a) for a in expr.args] return '%'.join(args) + def _print_Min(self, expr): + """Print Min using devito defined header Min""" + func = 'MIN' if has_integer_args(*expr.args) else 'fmin' + return "%s(%s)" % (func, self._print(expr.args)[1:-1]) + + def _print_Max(self, expr): + """Print Max using devito defined header Max""" + func = 'MAX' if has_integer_args(*expr.args) else 'fmax' + return "%s(%s)" % (func, self._print(expr.args)[1:-1]) + def _print_Abs(self, expr): """Print an absolute value. Use `abs` if can infer it is an Integer""" # AOMPCC errors with abs, always use fabs if isinstance(self.compiler, AOMPCompiler): return "fabs(%s)" % self._print(expr.args[0]) # Check if argument is an integer - is_integer = True - for a in expr.args[0].args: - try: - is_integer = is_integer and np.issubdtype(a.dtype, np.integer) - except AttributeError: - is_integer = is_integer and a.is_Integer - - func = "abs" if is_integer else "fabs" + func = "abs" if has_integer_args(*expr.args[0].args) else "fabs" return "%s(%s)" % (func, self._print(expr.args[0])) def _print_Add(self, expr, order=None): @@ -255,3 +258,30 @@ def ccode(expr, **settings): # to always use the correct one from our printer if Version(sympy.__version__) >= Version("1.11"): setattr(sympy.printing.str.StrPrinter, '_print_Add', CodePrinter._print_Add) + + +# Check arguements type +def has_integer_args(*args): + """ + Check if expression is Integer. + Used to choose the function printed in the c-code + """ + if len(args) == 0: + return False + + if len(args) == 1: + try: + return np.issubdtype(args[0].dtype, np.integer) + except AttributeError: + return args[0].is_integer + + res = True + for a in args: + try: + if len(a.args) > 0: + res = res and has_integer_args(*a.args) + else: + res = res and has_integer_args(a) + except AttributeError: + res = res and has_integer_args(a) + return res diff --git a/devito/types/sparse.py b/devito/types/sparse.py index 895586b74f..82ed1489a7 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -6,9 +6,10 @@ from cached_property import cached_property from devito.finite_differences import generate_fd_shortcuts +from devito.finite_differences.elementary import floor from devito.mpi import MPI, SparseDistributor from devito.operations import LinearInterpolator, PrecomputedInterpolator -from devito.symbolics import (INT, FLOOR, cast_mapper, indexify, +from devito.symbolics import (INT, cast_mapper, indexify, retrieve_function_carriers) from devito.tools import (ReducerMap, as_tuple, flatten, prod, filter_ordered, memoized_meth, is_integer) @@ -532,7 +533,7 @@ def _coordinate_symbols(self): @cached_property def _coordinate_indices(self): """Symbol for each grid index according to the coordinates.""" - return tuple([INT(FLOOR((c - o) / i.spacing)) + return tuple([INT(floor((c - o) / i.spacing)) for c, o, i in zip(self._coordinate_symbols, self.grid.origin_symbols, self.grid.dimensions[:self.grid.dim])]) diff --git a/examples/performance/00_overview.ipynb b/examples/performance/00_overview.ipynb index 4b49b5af3e..55e0725b74 100644 --- a/examples/performance/00_overview.ipynb +++ b/examples/performance/00_overview.ipynb @@ -340,9 +340,9 @@ " {\n", " for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n", " {\n", - " for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n", + " for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n", " {\n", - " for (int y = y0_blk0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1)\n", + " for (int y = y0_blk0; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1)\n", " {\n", " for (int z = z_m; z <= z_M; z += 1)\n", " {\n", @@ -426,17 +426,17 @@ " {\n", " for (int z0_blk0 = z_m; z0_blk0 <= z_M; z0_blk0 += z0_blk0_size)\n", " {\n", - " for (int x0_blk1 = x0_blk0; x0_blk1 <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x0_blk1 += x0_blk1_size)\n", + " for (int x0_blk1 = x0_blk0; x0_blk1 <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x0_blk1 += x0_blk1_size)\n", " {\n", - " for (int y0_blk1 = y0_blk0; y0_blk1 <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y0_blk1 += y0_blk1_size)\n", + " for (int y0_blk1 = y0_blk0; y0_blk1 <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y0_blk1 += y0_blk1_size)\n", " {\n", - " for (int z0_blk1 = z0_blk0; z0_blk1 <= MIN(z0_blk0 + z0_blk0_size - 1, z_M); z0_blk1 += z0_blk1_size)\n", + " for (int z0_blk1 = z0_blk0; z0_blk1 <= MIN(z_M, z0_blk0 + z0_blk0_size - 1); z0_blk1 += z0_blk1_size)\n", " {\n", - " for (int x = x0_blk1; x <= MIN(x0_blk1 + x0_blk1_size - 1, x_M); x += 1)\n", + " for (int x = x0_blk1; x <= MIN(x_M, x0_blk1 + x0_blk1_size - 1); x += 1)\n", " {\n", - " for (int y = y0_blk1; y <= MIN(y0_blk1 + y0_blk1_size - 1, y_M); y += 1)\n", + " for (int y = y0_blk1; y <= MIN(y_M, y0_blk1 + y0_blk1_size - 1); y += 1)\n", " {\n", - " for (int z = z0_blk1; z <= MIN(z0_blk1 + z0_blk1_size - 1, z_M); z += 1)\n", + " for (int z = z0_blk1; z <= MIN(z_M, z0_blk1 + z0_blk1_size - 1); z += 1)\n", " {\n", " u[t1][x + 4][y + 4][z + 4] = ((-6.66666667e-1F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 1][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 2][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 4][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 5][z + 4]/h_y) + (-8.33333333e-2F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 4][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 5][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 7][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 8][z + 4]/h_y) + (8.33333333e-2F/h_y)*(8.33333333e-2F*u[t0][x + 4][y][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 1][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 3][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 4][z + 4]/h_y) + (6.66666667e-1F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 3][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 4][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 6][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 7][z + 4]/h_y))*sin(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n", " }\n", @@ -1272,9 +1272,9 @@ " {\n", " for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n", " {\n", - " for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n", + " for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n", " {\n", - " for (int y = y0_blk0 - 2, ys = 0; y <= MIN(y0_blk0 + y0_blk0_size + 1, y_M + 2); y += 1, ys += 1)\n", + " for (int y = y0_blk0 - 2, ys = 0; y <= MIN(y_M + 2, y0_blk0 + y0_blk0_size + 1); y += 1, ys += 1)\n", " {\n", " #pragma omp simd aligned(u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", @@ -1282,7 +1282,7 @@ " r2[ys][z] = r1*(8.33333333e-2F*(u[t0][x + 4][y + 2][z + 4] - u[t0][x + 4][y + 6][z + 4]) + 6.66666667e-1F*(-u[t0][x + 4][y + 3][z + 4] + u[t0][x + 4][y + 5][z + 4]));\n", " }\n", " }\n", - " for (int y = y0_blk0, ys = 0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1, ys += 1)\n", + " for (int y = y0_blk0, ys = 0; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1)\n", " {\n", " #pragma omp simd aligned(f,u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", @@ -1390,9 +1390,9 @@ " {\n", " for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n", " {\n", - " for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n", + " for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n", " {\n", - " for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n", + " for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n", " {\n", " for (int yy = yii, ysi = (yy + ys + 2)%(5); yy <= 2; yy += 1, ysi = (yy + ys + 2)%(5))\n", " {\n", @@ -1706,9 +1706,9 @@ " {\n", " for (int y0_blk0 = y_m - 2; y0_blk0 <= y_M + 2; y0_blk0 += y0_blk0_size)\n", " {\n", - " for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M + 2); x += 1)\n", + " for (int x = x0_blk0; x <= MIN(x_M + 2, x0_blk0 + x0_blk0_size - 1); x += 1)\n", " {\n", - " for (int y = y0_blk0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M + 2); y += 1)\n", + " for (int y = y0_blk0; y <= MIN(y_M + 2, y0_blk0 + y0_blk0_size - 1); y += 1)\n", " {\n", " #pragma omp simd aligned(u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", @@ -1728,9 +1728,9 @@ " {\n", " for (int y1_blk0 = y_m; y1_blk0 <= y_M; y1_blk0 += y1_blk0_size)\n", " {\n", - " for (int x = x1_blk0; x <= MIN(x1_blk0 + x1_blk0_size - 1, x_M); x += 1)\n", + " for (int x = x1_blk0; x <= MIN(x_M, x1_blk0 + x1_blk0_size - 1); x += 1)\n", " {\n", - " for (int y = y1_blk0; y <= MIN(y1_blk0 + y1_blk0_size - 1, y_M); y += 1)\n", + " for (int y = y1_blk0; y <= MIN(y_M, y1_blk0 + y1_blk0_size - 1); y += 1)\n", " {\n", " #pragma omp simd aligned(f,u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", @@ -1788,7 +1788,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.9.16" }, "varInspector": { "cols": { diff --git a/requirements.txt b/requirements.txt index 05bccd6eec..238102d611 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ pip>=9.0.1 numpy>1.16 -sympy>=1.7,<1.12 +sympy>=1.9,<1.12 scipy flake8>=2.1.0 nbval diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 4ffcdb779f..9efa765c70 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -168,9 +168,10 @@ def test_codegen_quality0(): assert len(exprs) == 6 assert all('const long' in str(i) for i in exprs[:-2]) - # Only four access macros necessary, namely `uL0`, `bufL0`, `bufL1` (the - # other three obviously are _POSIX_C_SOURCE, START_TIMER, STOP_TIMER) - assert len(op._headers) == 6 + # Only four access macros necessary, namely `uL0`, `bufL0`, `bufL1` + # MIN/MAX for the efunc args + # (the other three obviously are _POSIX_C_SOURCE, START_TIMER, STOP_TIMER) + assert len(op._headers) == 8 def test_codegen_quality1(): diff --git a/tests/test_pickle.py b/tests/test_pickle.py index ffe27a6ec3..ab1b38184d 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -680,3 +680,30 @@ def test_full_model(): new_ricker = pickle.loads(pkl_ricker) assert np.isclose(np.linalg.norm(ricker.data), np.linalg.norm(new_ricker.data)) # FIXME: fails randomly when using data.flatten() AND numpy is using MKL + + +def test_elemental(): + """ + Tests that elemental function doesn't get reconstructed differently + """ + grid = Grid(shape=(101, 101)) + time_range = TimeAxis(start=0.0, stop=1000.0, num=12) + + nrec = 101 + rec = Receiver(name='rec', grid=grid, npoint=nrec, time_range=time_range) + + u = TimeFunction(name="u", grid=grid, time_order=2, space_order=2) + rec_term = rec.interpolate(expr=u) + + eq = rec_term.evaluate[2] + eq = eq.func(eq.lhs, eq.rhs.args[0]) + + op = Operator(eq) + + pkl_op = pickle.dumps(op) + new_op = pickle.loads(pkl_op) + + op.cfunction + new_op.cfunction + + assert str(op) == str(new_op) diff --git a/tests/test_skewing.py b/tests/test_skewing.py index c2f393cc88..d5e1711a3c 100644 --- a/tests/test_skewing.py +++ b/tests/test_skewing.py @@ -2,8 +2,7 @@ import numpy as np from conftest import assert_blocking -from devito.symbolics import MIN -from devito import Grid, Dimension, Eq, Function, TimeFunction, Operator, norm # noqa +from devito import Grid, Dimension, Eq, Function, TimeFunction, Operator, norm, Min # noqa from devito.ir import Expression, Iteration, FindNodes @@ -55,10 +54,10 @@ def test_skewed_bounds(self, expr, expected, norm_u, norm_v): assert iters[1].symbolic_max == (iters[1].dim.parent.symbolic_max + time) assert iters[2].symbolic_min == iters[2].dim.symbolic_min - assert iters[2].symbolic_max == MIN(iters[0].dim + iters[0].dim.symbolic_incr + assert iters[2].symbolic_max == Min(iters[0].dim + iters[0].dim.symbolic_incr - 1, iters[0].dim.symbolic_max + time) assert iters[3].symbolic_min == iters[3].dim.symbolic_min - assert iters[3].symbolic_max == MIN(iters[1].dim + iters[1].dim.symbolic_incr + assert iters[3].symbolic_max == Min(iters[1].dim + iters[1].dim.symbolic_incr - 1, iters[1].dim.symbolic_max + time) assert iters[4].symbolic_min == (iters[4].dim.symbolic_min) assert iters[4].symbolic_max == (iters[4].dim.symbolic_max) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 4ef01428ef..ae0f76f147 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -6,11 +6,11 @@ from sympy import Expr, Symbol from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa - Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos) + Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, Min, Max) from devito.ir import Expression, FindNodes from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa CallFromPointer, Cast, FieldFromPointer, - FieldFromComposite, IntDiv, MIN, MAX, ccode, uxreplace) + FieldFromComposite, IntDiv, ccode, uxreplace) from devito.types import Array, LocalObject, Object @@ -367,6 +367,7 @@ def test_multibounds_op(self): f.data[:] = 0.1 eqns = [Eq(f.forward, f.laplace + f * evalrel(min, [f, b, c, d]))] + op = Operator(eqns, opt=('advanced')) op.apply(time_M=5) fnorm = norm(f) @@ -383,25 +384,25 @@ def test_multibounds_op(self): assert fnorm == fnorm2 @pytest.mark.parametrize('op, expr, assumptions, expected', [ - ([min, '[a, b, c, d]', '[]', 'MIN(a, MIN(b, MIN(c, d)))']), - ([max, '[a, b, c, d]', '[]', 'MAX(a, MAX(b, MAX(c, d)))']), + ([min, '[a, b, c, d]', '[]', 'Min(a, Min(b, Min(c, d)))']), + ([max, '[a, b, c, d]', '[]', 'Max(a, Max(b, Max(c, d)))']), ([min, '[a]', '[]', 'a']), - ([min, '[a, b]', '[Le(d, a), Ge(c, b)]', 'MIN(a, b)']), - ([min, '[a, b, c]', '[]', 'MIN(a, MIN(b, c))']), - ([min, '[a, b, c, d]', '[Le(d, a), Ge(c, b)]', 'MIN(b, d)']), + ([min, '[a, b]', '[Le(d, a), Ge(c, b)]', 'Min(a, b)']), + ([min, '[a, b, c]', '[]', 'Min(a, Min(b, c))']), + ([min, '[a, b, c, d]', '[Le(d, a), Ge(c, b)]', 'Min(b, d)']), ([min, '[a, b, c, d]', '[Ge(a, b), Ge(d, a), Ge(b, c)]', 'c']), ([max, '[a]', '[Le(a, a)]', 'a']), ([max, '[a, b]', '[Le(a, b)]', 'b']), - ([max, '[a, b, c]', '[Le(c, b), Le(c, a)]', 'MAX(a, b)']), + ([max, '[a, b, c]', '[Le(c, b), Le(c, a)]', 'Max(a, b)']), ([max, '[a, b, c, d]', '[Ge(a, b), Ge(d, a), Ge(b, c)]', 'd']), - ([max, '[a, b, c, d]', '[Ge(a, b), Le(b, c)]', 'MAX(a, MAX(c, d))']), - ([max, '[a, b, c, d]', '[Ge(a, b), Le(c, b)]', 'MAX(a, d)']), - ([max, '[a, b, c, d]', '[Ge(b, a), Ge(a, b)]', 'MAX(a, MAX(c, d))']), - ([min, '[a, b, c, d]', '[Ge(b, a), Ge(a, b), Le(c, b), Le(b, a)]', 'MIN(c, d)']), + ([max, '[a, b, c, d]', '[Ge(a, b), Le(b, c)]', 'Max(a, Max(c, d))']), + ([max, '[a, b, c, d]', '[Ge(a, b), Le(c, b)]', 'Max(a, d)']), + ([max, '[a, b, c, d]', '[Ge(b, a), Ge(a, b)]', 'Max(a, Max(c, d))']), + ([min, '[a, b, c, d]', '[Ge(b, a), Ge(a, b), Le(c, b), Le(b, a)]', 'Min(c, d)']), ([min, '[a, b, c, d]', '[Ge(b, a), Ge(a, b), Le(c, b), Le(b, d)]', 'c']), - ([min, '[a, b, c, d]', '[Ge(b, a + d)]', 'MIN(a, MIN(c, d))']), - ([min, '[a, b, c, d]', '[Lt(b + a, d)]', 'MIN(a, MIN(b, c))']), - ([max, '[a, b, c, d]', '[Lt(b + a, d)]', 'MAX(c, d)']), + ([min, '[a, b, c, d]', '[Ge(b, a + d)]', 'Min(a, Min(c, d))']), + ([min, '[a, b, c, d]', '[Lt(b + a, d)]', 'Min(a, Min(b, c))']), + ([max, '[a, b, c, d]', '[Lt(b + a, d)]', 'Max(c, d)']), ([max, '[a, b, c, d]', '[Gt(a, b + c + d)]', 'a']), ]) def test_relations_w_complex_assumptions(self, op, expr, assumptions, expected): @@ -414,37 +415,36 @@ def test_relations_w_complex_assumptions(self, op, expr, assumptions, expected): eqn = eval(expr) assumptions = eval(assumptions) - expected = eval(expected) - assert evalrel(op, eqn, assumptions) == expected + assert str(evalrel(op, eqn, assumptions)) == expected @pytest.mark.parametrize('op, expr, assumptions, expected', [ ([min, '[a, b, c, d]', '[Ge(b, a), Ge(a, b), Le(c, b), Le(b, d)]', 'c']), - ([min, '[a, b, c, d]', '[Ge(b, a + d)]', 'MIN(a, MIN(b, MIN(c, d)))']), - ([min, '[a, b, c, d]', '[Ge(c, a + d)]', 'MIN(a, b)']), - ([max, '[a, b, c, d]', '[Ge(c, a + d), Gt(b, a + d)]', 'MAX(b, d)']), + ([min, '[a, b, c, d]', '[Ge(b, a + d)]', 'Min(a, Min(b, Min(c, d)))']), + ([min, '[a, b, c, d]', '[Ge(c, a + d)]', 'Min(a, b)']), + ([max, '[a, b, c, d]', '[Ge(c, a + d), Gt(b, a + d)]', 'Max(b, d)']), ([max, '[a, b, c, d]', '[Ge(a + d, b), Gt(b, a + d)]', - 'MAX(a, MAX(b, MAX(c, d)))']), - ([max, '[a, b, c, d]', '[Le(c, a + d)]', 'MAX(a, MAX(b, MAX(c, d)))']), - ([max, '[a, b, c, d]', '[Le(c, d), Le(a, b)]', 'MAX(b, d)']), - ([max, '[a, b, c, d]', '[Le(c, d), Le(d, c)]', 'MAX(a, MAX(b, c))']), - ([min, '[a, b, c, d]', '[Le(c, d).negated, Le(a, b).negated]', 'MIN(b, d)']), - ([min, '[a, b, c, d]', '[Le(c, d).negated, Le(a, b).negated]', 'MIN(b, d)']), - ([min, '[a, b, c, d]', '[Gt(c, d).negated, Ge(a, b).negated]', 'MIN(a, c)']), - ([min, '[a, b, c, d]', '[Gt(c, d).reversed, Ge(a, b).reversed]', 'MIN(b, d)']), - ([min, '[a, b, c, d]', '[Le(c, d).negated, Le(a, b).negated]', 'MIN(b, d)']), - ([max, '[a, b, c, d]', '[Le(c, d).negated, Le(a, b).negated]', 'MAX(a, c)']), - ([max, '[a, b, c, d]', '[Gt(c, d).negated, Ge(a, b).negated]', 'MAX(b, d)']), - ([max, '[a, b, c, d]', '[Gt(c, d).reversed, Ge(a, b).reversed]', 'MAX(a, c)']), - ([max, '[a, b, c, d]', '[Lt(c, d).reversed, Le(a, b).reversed]', 'MAX(b, d)']), - ([max, '[a, b, c, d]', '[Gt(c, d + a).negated]', 'MAX(a, MAX(b, MAX(c, d)))']), - ([max, '[a, b, c, d]', '[Lt(c, d + a).negated]', 'MAX(b, d)']), - ([max, '[a, b, c, d]', '[Le(c, d + a).negated]', 'MAX(b, d)']), + 'Max(a, Max(b, Max(c, d)))']), + ([max, '[a, b, c, d]', '[Le(c, a + d)]', 'Max(a, Max(b, Max(c, d)))']), + ([max, '[a, b, c, d]', '[Le(c, d), Le(a, b)]', 'Max(b, d)']), + ([max, '[a, b, c, d]', '[Le(c, d), Le(d, c)]', 'Max(a, Max(b, c))']), + ([min, '[a, b, c, d]', '[Le(c, d).negated, Le(a, b).negated]', 'Min(b, d)']), + ([min, '[a, b, c, d]', '[Le(c, d).negated, Le(a, b).negated]', 'Min(b, d)']), + ([min, '[a, b, c, d]', '[Gt(c, d).negated, Ge(a, b).negated]', 'Min(a, c)']), + ([min, '[a, b, c, d]', '[Gt(c, d).reversed, Ge(a, b).reversed]', 'Min(b, d)']), + ([min, '[a, b, c, d]', '[Le(c, d).negated, Le(a, b).negated]', 'Min(b, d)']), + ([max, '[a, b, c, d]', '[Le(c, d).negated, Le(a, b).negated]', 'Max(a, c)']), + ([max, '[a, b, c, d]', '[Gt(c, d).negated, Ge(a, b).negated]', 'Max(b, d)']), + ([max, '[a, b, c, d]', '[Gt(c, d).reversed, Ge(a, b).reversed]', 'Max(a, c)']), + ([max, '[a, b, c, d]', '[Lt(c, d).reversed, Le(a, b).reversed]', 'Max(b, d)']), + ([max, '[a, b, c, d]', '[Gt(c, d + a).negated]', 'Max(a, Max(b, Max(c, d)))']), + ([max, '[a, b, c, d]', '[Lt(c, d + a).negated]', 'Max(b, d)']), + ([max, '[a, b, c, d]', '[Le(c, d + a).negated]', 'Max(b, d)']), ([max, '[a, b, c, d]', '[Le(c + b, d + a).negated]', - 'MAX(a, MAX(b, MAX(c, d)))']), + 'Max(a, Max(b, Max(c, d)))']), ([max, '[a, b, c, d, e]', '[Gt(a, b + c + e)]', - 'MAX(a, MAX(b, MAX(c, MAX(d, e))))']), + 'Max(a, Max(b, Max(c, Max(d, e))))']), ([max, '[a, b, c, d, e]', '[Ge(c, a), Ge(b, a), Ge(a, c), Ge(e, c), Ge(d, e)]', - 'MAX(b, d)']), + 'Max(b, d)']), ]) def test_relations_w_complex_assumptions_II(self, op, expr, assumptions, expected): """ @@ -457,14 +457,13 @@ def test_relations_w_complex_assumptions_II(self, op, expr, assumptions, expecte eqn = eval(expr) assumptions = eval(assumptions) - expected = eval(expected) - assert evalrel(op, eqn, assumptions) == expected + assert str(evalrel(op, eqn, assumptions)) == expected @pytest.mark.parametrize('op, expr, assumptions, expected', [ ([min, '[a, b, c, d]', '[Ge(b, a)]', 'a']), - ([min, '[a, b, c, d]', '[Ge(b, d)]', 'MIN(a, d)']), - ([min, '[a, b, c, d]', '[Ge(c, a + d)]', 'MIN(a, b)']), - ([max, '[a, b, c, d, e]', 'None', 'MAX(e, d)']), + ([min, '[a, b, c, d]', '[Ge(b, d)]', 'Min(a, d)']), + ([min, '[a, b, c, d]', '[Ge(c, a + d)]', 'Min(a, b)']), + ([max, '[a, b, c, d, e]', 'None', 'Max(e, d)']), ]) def test_assumptions(self, op, expr, assumptions, expected): """