Skip to content

Commit

Permalink
Merge pull request #2052 from devitocodes/fix-floor
Browse files Browse the repository at this point in the history
symbolics: use devito floor instead of Undefined Function
  • Loading branch information
mloubout committed Feb 14, 2023
2 parents 405b113 + fb1197b commit 657857a
Show file tree
Hide file tree
Showing 14 changed files with 184 additions and 108 deletions.
12 changes: 10 additions & 2 deletions devito/finite_differences/elementary.py
Expand Up @@ -91,11 +91,19 @@ def root(x):


class Min(sympy.Min, Evaluable):
    """
    `sympy.Min` mixed with `Evaluable`; evaluation expects exactly two
    operands.
    """

    def _evaluate(self, **kwargs):
        """
        Rebuild this node from its evaluated operands.

        The keyword arguments are forwarded untouched to `_evaluate_args`.
        The result is constructed with ``evaluate=False``, i.e. SymPy is asked
        not to simplify the rebuilt expression.
        """
        args = self._evaluate_args(**kwargs)
        # Exactly two operands are expected at this point
        assert len(args) == 2
        return self.func(args[0], args[1], evaluate=False)


class Max(sympy.Max, Evaluable):
    """
    `sympy.Max` mixed with `Evaluable`; evaluation expects exactly two
    operands.
    """

    def _evaluate(self, **kwargs):
        """
        Rebuild this node from its evaluated operands.

        The keyword arguments are forwarded untouched to `_evaluate_args`.
        The result is constructed with ``evaluate=False``, i.e. SymPy is asked
        not to simplify the rebuilt expression.
        """
        args = self._evaluate_args(**kwargs)
        # Exactly two operands are expected at this point
        assert len(args) == 2
        return self.func(args[0], args[1], evaluate=False)


def Id(x):
Expand Down
15 changes: 15 additions & 0 deletions devito/mpi/routines.py
Expand Up @@ -58,6 +58,13 @@ def msgs(self):
def regions(self):
    """
    Return the values stored in `self._regions` that are not None, in the
    iteration order of the mapping.
    """
    return [region for region in self._regions.values() if region is not None]

@property
def headers(self):
    """
    Additional metadata required by this builder; none by default.
    """
    # A new dict is created on every access: `make_mpi` calls `.update` on
    # the returned object, so it must never be shared between accesses
    return dict()

def make(self, hs):
"""
Construct Callables and Calls implementing distributed-memory halo
Expand Down Expand Up @@ -510,6 +517,14 @@ class OverlapHaloExchangeBuilder(DiagHaloExchangeBuilder):
remainder()
"""

@property
def headers(self):
    """
    Definitions of the MIN/MAX helpers that the overlap mode relies upon.

    Returns a dict with a single 'headers' entry: a list of
    (signature, replacement text) pairs, one for MIN and one for MAX.
    """
    min_definition = ('MIN(a,b)', '(((a) < (b)) ? (a) : (b))')
    max_definition = ('MAX(a,b)', '(((a) > (b)) ? (a) : (b))')
    return {'headers': [min_definition, max_definition]}

def _make_msg(self, f, hse, key):
# Only retain the halos required by the Diag scheme
halos = sorted(i for i in hse.halos if isinstance(i.dim, tuple))
Expand Down
14 changes: 7 additions & 7 deletions devito/passes/iet/misc.py
Expand Up @@ -3,7 +3,7 @@
from devito.ir import (Any, Forward, List, Prodder, FindNodes, Transformer,
filter_iterations, retrieve_iteration_tree)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import MIN, MAX, evalrel
from devito.symbolics import evalrel
from devito.tools import split

__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions']
Expand Down Expand Up @@ -81,7 +81,7 @@ def relax_incr_dimensions(iet, options=None, **kwargs):
to:
<Iteration x0_blk0; (x_m, x_M, x0_blk0_size)>
<Iteration x; (x0_blk0, MIN(x_M, x0_blk0 + x0_blk0_size - 1)), 1)>
<Iteration x; (x0_blk0, Min(x_M, x0_blk0 + x0_blk0_size - 1)), 1)>
"""
mapper = {}
Expand All @@ -107,15 +107,15 @@ def relax_incr_dimensions(iet, options=None, **kwargs):
if i.is_Inbound:
continue

# The Iteration's maximum is the MIN of (a) the `symbolic_max` of current
# The Iteration's maximum is the Min of (a) the `symbolic_max` of current
# Iteration e.g. `x0_blk0 + x0_blk0_size - 1` and (b) the `symbolic_max`
# of the current Iteration's root Dimension e.g. `x_M`. The generated
# maximum will be `MIN(x0_blk0 + x0_blk0_size - 1, x_M)
# maximum will be `Min(x0_blk0 + x0_blk0_size - 1, x_M)

# In some corner cases an offset may be added (e.g. after CIRE passes)
# E.g. assume `i.symbolic_max = x0_blk0 + x0_blk0_size + 1` and
# `i.dim.symbolic_max = x0_blk0 + x0_blk0_size - 1` then the generated
# maximum will be `MIN(x0_blk0 + x0_blk0_size + 1, x_M + 2)`
# maximum will be `Min(x0_blk0 + x0_blk0_size + 1, x_M + 2)`

root_max = roots_max[i.dim.root] + i.symbolic_max - i.dim.symbolic_max
iter_max = evalrel(min, [i.symbolic_max, root_max])
Expand All @@ -124,8 +124,8 @@ def relax_incr_dimensions(iet, options=None, **kwargs):
if mapper:
iet = Transformer(mapper, nested=True).visit(iet)

headers = [('%s(a,b)' % MIN.name, ('(((a) < (b)) ? (a) : (b))')),
('%s(a,b)' % MAX.name, ('(((a) > (b)) ? (a) : (b))'))]
headers = [('MIN(a,b)', ('(((a) < (b)) ? (a) : (b))')),
('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))'))]
else:
headers = []

Expand Down
4 changes: 3 additions & 1 deletion devito/passes/iet/mpi.py
Expand Up @@ -298,6 +298,7 @@ def make_mpi(iet, mpimode=None, **kwargs):

efuncs = sync_heb.efuncs + user_heb.efuncs
iet = Transformer(mapper, nested=True).visit(iet)
headers = user_heb.headers

# Must drop the PARALLEL tag from the Iterations within which halo
# exchanges are performed
Expand All @@ -312,8 +313,9 @@ def make_mpi(iet, mpimode=None, **kwargs):
for n in tree[:tree.index(i)+1]})
break
iet = Transformer(mapper, nested=True).visit(iet)
headers.update({'includes': ['mpi.h'], 'efuncs': efuncs})

return iet, {'includes': ['mpi.h'], 'efuncs': efuncs}
return iet, headers


def mpiize(graph, **kwargs):
Expand Down
22 changes: 8 additions & 14 deletions devito/symbolics/extended_sympy.py
Expand Up @@ -9,14 +9,14 @@

from devito.tools import (Pickable, as_tuple, is_integer, float2, float3, float4, # noqa
double2, double3, double4, int2, int3, int4)
from devito.finite_differences.elementary import Min, Max
from devito.types import Symbol

__all__ = ['CondEq', 'CondNe', 'IntDiv', 'CallFromPointer', 'FieldFromPointer', # noqa
'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast',
'DefFunction', 'InlineIf', 'Keyword', 'String', 'Macro', 'MacroArgument',
'CustomType', 'Deref', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'CEIL',
'FLOOR', 'MAX', 'MIN', 'Null', 'SizeOf', 'rfunc', 'cast_mapper',
'BasicWrapperMixin']
'CustomType', 'Deref', 'INT', 'FLOAT', 'DOUBLE', 'VOID',
'Null', 'SizeOf', 'rfunc', 'cast_mapper', 'BasicWrapperMixin']


class CondEq(sympy.Eq):
Expand Down Expand Up @@ -672,12 +672,6 @@ class CHARP(CastStar):


# Some other utility objects

CEIL = Function('ceil')
FLOOR = Function('floor')
MAX = Function('MAX')
MIN = Function('MIN')

Null = Macro('NULL')

# DefFunction, unlike sympy.Function, generates e.g. `sizeof(float)`, not `sizeof(float_)`
Expand All @@ -691,10 +685,10 @@ def rfunc(func, item, *args):
Examples
----------
>> rfunc(min, [a, b, c, d])
MIN(a, MIN(b, MIN(c, d)))
Min(a, Min(b, Min(c, d)))
>> rfunc(max, [a, b, c, d])
MAX(a, MAX(b, MAX(c, d)))
Max(a, Max(b, Max(c, d)))
"""

assert func in rfunc_mapper
Expand All @@ -703,10 +697,10 @@ def rfunc(func, item, *args):
if len(args) == 0:
return item
else:
return rf(item, rfunc(func, *args))
return rf(item, rfunc(func, *args), evaluate=False)


# Maps the Python builtins accepted by `rfunc` to the symbolic class used to
# build the nested expression. Each key appears once: repeated keys in a dict
# literal are silently overwritten by the last occurrence.
rfunc_mapper = {
    min: Min,
    max: Max,
}
8 changes: 4 additions & 4 deletions devito/symbolics/manipulation.py
Expand Up @@ -283,8 +283,8 @@ def reuse_if_untouched(expr, args, evaluate=False):
def evalrel(func=min, input=None, assumptions=None):
"""
The purpose of this function is two-fold: (i) to reduce the `input` candidates
for a MIN/MAX expression based on the given `assumptions` and (ii) return the nested
MIN/MAX expression of the reduced-size input.
for a Min/Max expression based on the given `assumptions` and (ii) return the nested
Min/Max expression of the reduced-size input.
Parameters
----------
Expand All @@ -307,7 +307,7 @@ def evalrel(func=min, input=None, assumptions=None):
>>> c = Symbol('c')
>>> d = Symbol('d')
>>> evalrel(max, [a, b, c, d], [Le(d, a), Ge(c, b)])
MAX(a, c)
Max(a, c)
"""
sfunc = (Min if func is min else Max) # Choose SymPy's Min/Max

Expand Down Expand Up @@ -353,7 +353,7 @@ def evalrel(func=min, input=None, assumptions=None):
mapper = transitive_closure(mapper)
input = [i.subs(mapper) for i in input]

# Explore simplification opportunities that may have emerged and generate MIN/MAX
# Explore simplification opportunities that may have emerged and generate Min/Max
# expression
try:
exp = sfunc(*input) # Can it be evaluated or simplified?
Expand Down
46 changes: 38 additions & 8 deletions devito/symbolics/printer.py
Expand Up @@ -101,20 +101,23 @@ def _print_Mod(self, expr):
args = ['(%s)' % self._print(a) for a in expr.args]
return '%'.join(args)

def _print_Min(self, expr):
    """
    Print a Min node as a two-style call: ``MIN(...)`` when
    `has_integer_args` holds for its operands, ``fmin(...)`` otherwise.
    """
    name = 'fmin'
    if has_integer_args(*expr.args):
        name = 'MIN'
    # The args tuple is printed as a whole; the first and last characters
    # (presumably the enclosing parentheses) are then stripped
    operands = self._print(expr.args)[1:-1]
    return "%s(%s)" % (name, operands)

def _print_Max(self, expr):
    """
    Print a Max node as a two-style call: ``MAX(...)`` when
    `has_integer_args` holds for its operands, ``fmax(...)`` otherwise.
    """
    name = 'fmax'
    if has_integer_args(*expr.args):
        name = 'MAX'
    # The args tuple is printed as a whole; the first and last characters
    # (presumably the enclosing parentheses) are then stripped
    operands = self._print(expr.args)[1:-1]
    return "%s(%s)" % (name, operands)

def _print_Abs(self, expr):
    """
    Print an absolute value.

    ``abs`` is used when `has_integer_args` reports that the operands of the
    argument are integers, ``fabs`` in every other case.
    """
    # AOMPCC errors with abs, always use fabs
    if isinstance(self.compiler, AOMPCompiler):
        return "fabs(%s)" % self._print(expr.args[0])

    # The decision is taken on the operands of the argument, not on the
    # argument itself.
    # NOTE(review): `has_integer_args()` returns False when called without
    # arguments, so an argument with empty `.args` selects `fabs` -- confirm
    # this is intended.
    func = "abs" if has_integer_args(*expr.args[0].args) else "fabs"
    return "%s(%s)" % (func, self._print(expr.args[0]))

def _print_Add(self, expr, order=None):
Expand Down Expand Up @@ -255,3 +258,30 @@ def ccode(expr, **settings):
# to always use the correct one from our printer
if Version(sympy.__version__) >= Version("1.11"):
setattr(sympy.printing.str.StrPrinter, '_print_Add', CodePrinter._print_Add)


# Check arguments type
def has_integer_args(*args):
    """
    Tell whether all of the given arguments are integers.
    Used to choose the function printed in the c-code.

    With no arguments the answer is False. A single argument is checked
    through its `dtype` (against `np.integer`) or, when it has no `dtype`,
    through its `is_integer` attribute, whose value is returned as is (so it
    is not necessarily a bool). With several arguments, each one carrying a
    non-empty `args` is expanded recursively, the others are checked directly.
    """
    if not args:
        return False

    if len(args) == 1:
        (single,) = args
        try:
            return np.issubdtype(single.dtype, np.integer)
        except AttributeError:
            return single.is_integer

    outcome = True
    for item in args:
        try:
            operands = item.args
            if len(operands) > 0:
                outcome = outcome and has_integer_args(*operands)
                continue
        except AttributeError:
            # Either `item` has no `args`, or something down the recursion
            # lacked the probed attributes: fall back to checking `item`
            pass
        outcome = outcome and has_integer_args(item)
    return outcome
5 changes: 3 additions & 2 deletions devito/types/sparse.py
Expand Up @@ -6,9 +6,10 @@
from cached_property import cached_property

from devito.finite_differences import generate_fd_shortcuts
from devito.finite_differences.elementary import floor
from devito.mpi import MPI, SparseDistributor
from devito.operations import LinearInterpolator, PrecomputedInterpolator
from devito.symbolics import (INT, FLOOR, cast_mapper, indexify,
from devito.symbolics import (INT, cast_mapper, indexify,
retrieve_function_carriers)
from devito.tools import (ReducerMap, as_tuple, flatten, prod, filter_ordered,
memoized_meth, is_integer)
Expand Down Expand Up @@ -532,7 +533,7 @@ def _coordinate_symbols(self):
@cached_property
def _coordinate_indices(self):
    """
    Symbol for each grid index according to the coordinates.

    One entry is built for each of the first `grid.dim` grid dimensions, as
    ``INT(floor((coordinate - origin) / spacing))``.
    """
    return tuple([INT(floor((c - o) / i.spacing))
                  for c, o, i in zip(self._coordinate_symbols,
                                     self.grid.origin_symbols,
                                     self.grid.dimensions[:self.grid.dim])])
Expand Down
36 changes: 18 additions & 18 deletions examples/performance/00_overview.ipynb
Expand Up @@ -340,9 +340,9 @@
" {\n",
" for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n",
" {\n",
" for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n",
" for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1)\n",
" for (int y = y0_blk0; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1)\n",
" {\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
" {\n",
Expand Down Expand Up @@ -426,17 +426,17 @@
" {\n",
" for (int z0_blk0 = z_m; z0_blk0 <= z_M; z0_blk0 += z0_blk0_size)\n",
" {\n",
" for (int x0_blk1 = x0_blk0; x0_blk1 <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x0_blk1 += x0_blk1_size)\n",
" for (int x0_blk1 = x0_blk0; x0_blk1 <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x0_blk1 += x0_blk1_size)\n",
" {\n",
" for (int y0_blk1 = y0_blk0; y0_blk1 <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y0_blk1 += y0_blk1_size)\n",
" for (int y0_blk1 = y0_blk0; y0_blk1 <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y0_blk1 += y0_blk1_size)\n",
" {\n",
" for (int z0_blk1 = z0_blk0; z0_blk1 <= MIN(z0_blk0 + z0_blk0_size - 1, z_M); z0_blk1 += z0_blk1_size)\n",
" for (int z0_blk1 = z0_blk0; z0_blk1 <= MIN(z_M, z0_blk0 + z0_blk0_size - 1); z0_blk1 += z0_blk1_size)\n",
" {\n",
" for (int x = x0_blk1; x <= MIN(x0_blk1 + x0_blk1_size - 1, x_M); x += 1)\n",
" for (int x = x0_blk1; x <= MIN(x_M, x0_blk1 + x0_blk1_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk1; y <= MIN(y0_blk1 + y0_blk1_size - 1, y_M); y += 1)\n",
" for (int y = y0_blk1; y <= MIN(y_M, y0_blk1 + y0_blk1_size - 1); y += 1)\n",
" {\n",
" for (int z = z0_blk1; z <= MIN(z0_blk1 + z0_blk1_size - 1, z_M); z += 1)\n",
" for (int z = z0_blk1; z <= MIN(z_M, z0_blk1 + z0_blk1_size - 1); z += 1)\n",
" {\n",
" u[t1][x + 4][y + 4][z + 4] = ((-6.66666667e-1F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 1][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 2][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 4][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 5][z + 4]/h_y) + (-8.33333333e-2F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 4][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 5][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 7][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 8][z + 4]/h_y) + (8.33333333e-2F/h_y)*(8.33333333e-2F*u[t0][x + 4][y][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 1][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 3][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 4][z + 4]/h_y) + (6.66666667e-1F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 3][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 4][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 6][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 7][z + 4]/h_y))*sin(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n",
" }\n",
Expand Down Expand Up @@ -1272,17 +1272,17 @@
" {\n",
" for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n",
" {\n",
" for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n",
" for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0 - 2, ys = 0; y <= MIN(y0_blk0 + y0_blk0_size + 1, y_M + 2); y += 1, ys += 1)\n",
" for (int y = y0_blk0 - 2, ys = 0; y <= MIN(y_M + 2, y0_blk0 + y0_blk0_size + 1); y += 1, ys += 1)\n",
" {\n",
" #pragma omp simd aligned(u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
" {\n",
" r2[ys][z] = r1*(8.33333333e-2F*(u[t0][x + 4][y + 2][z + 4] - u[t0][x + 4][y + 6][z + 4]) + 6.66666667e-1F*(-u[t0][x + 4][y + 3][z + 4] + u[t0][x + 4][y + 5][z + 4]));\n",
" }\n",
" }\n",
" for (int y = y0_blk0, ys = 0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1, ys += 1)\n",
" for (int y = y0_blk0, ys = 0; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1)\n",
" {\n",
" #pragma omp simd aligned(f,u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
Expand Down Expand Up @@ -1390,9 +1390,9 @@
" {\n",
" for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n",
" {\n",
" for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n",
" for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n",
" for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n",
" {\n",
" for (int yy = yii, ysi = (yy + ys + 2)%(5); yy <= 2; yy += 1, ysi = (yy + ys + 2)%(5))\n",
" {\n",
Expand Down Expand Up @@ -1706,9 +1706,9 @@
" {\n",
" for (int y0_blk0 = y_m - 2; y0_blk0 <= y_M + 2; y0_blk0 += y0_blk0_size)\n",
" {\n",
" for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M + 2); x += 1)\n",
" for (int x = x0_blk0; x <= MIN(x_M + 2, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M + 2); y += 1)\n",
" for (int y = y0_blk0; y <= MIN(y_M + 2, y0_blk0 + y0_blk0_size - 1); y += 1)\n",
" {\n",
" #pragma omp simd aligned(u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
Expand All @@ -1728,9 +1728,9 @@
" {\n",
" for (int y1_blk0 = y_m; y1_blk0 <= y_M; y1_blk0 += y1_blk0_size)\n",
" {\n",
" for (int x = x1_blk0; x <= MIN(x1_blk0 + x1_blk0_size - 1, x_M); x += 1)\n",
" for (int x = x1_blk0; x <= MIN(x_M, x1_blk0 + x1_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y1_blk0; y <= MIN(y1_blk0 + y1_blk0_size - 1, y_M); y += 1)\n",
" for (int y = y1_blk0; y <= MIN(y_M, y1_blk0 + y1_blk0_size - 1); y += 1)\n",
" {\n",
" #pragma omp simd aligned(f,u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
Expand Down Expand Up @@ -1788,7 +1788,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.9.16"
},
"varInspector": {
"cols": {
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
pip>=9.0.1
numpy>1.16
sympy>=1.7,<1.12
sympy>=1.9,<1.12
scipy
flake8>=2.1.0
nbval
Expand Down

0 comments on commit 657857a

Please sign in to comment.