Skip to content

Commit

Permalink
Fix stage_mem logic for which expressions should use the staged memory (
Browse files Browse the repository at this point in the history
#614)

- Only `stage_mem` accesses that always contained with the window
expression. Don't `stage_mem` accesses that are always disjoint from the
window expression. If there are accesses which are sometimes contained
and sometimes disjoint, raise a SchedulingError. This allows a more
precise analysis of read/write/reduced buffers.
- If all the replace of existing accesses fail, `stage_mem` should fail
instead of allocating a new buffer unnecessarily.
- Deprecate Check_BufferRW

Closes #446

---------

Co-authored-by: Kevin Qian <keqian@mit.edu>
Co-authored-by: Kevin Qian <52479696+skeqiqevian@users.noreply.github.com>
  • Loading branch information
3 people committed May 6, 2024
1 parent 81c30ba commit f892fd5
Show file tree
Hide file tree
Showing 11 changed files with 455 additions and 72 deletions.
157 changes: 109 additions & 48 deletions src/exo/LoopIR_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
Check_DeleteConfigWrite,
Check_ExtendEqv,
Check_ExprEqvInContext,
Check_BufferRW,
Check_BufferReduceOnly,
Check_Bounds,
Check_Access_In_Window,
Check_IsDeadAfter,
Check_IsIdempotent,
Check_ExprBound,
Expand All @@ -42,6 +42,8 @@
from .pattern_match import match_pattern
from .memory import DRAM

from functools import partial

# --------------------------------------------------------------------------- #
# --------------------------------------------------------------------------- #
# Wrapper for LoopIR_Rewrite for scheduling directives which takes procedure cursor
Expand Down Expand Up @@ -196,45 +198,61 @@ def _replace_pats(ir, fwd, c, pat, repl, only_replace_attrs=True, use_sym_id=Tru
# TODO: consider the implications of composing O(n) forwarding functions.
# will we need a special data structure? A chunkier operation for
# multi-way replacement?
cur_fwd = lambda x: x
c = fwd(c)
todos = []
for rd in match_pattern(c, pat, use_sym_id=use_sym_id):
if c_repl := repl(rd):
todos.append((rd, c_repl))

cur_fwd = lambda x: x
for (rd, c_repl) in todos:
rd = cur_fwd(rd)
if not (c_repl := repl(rd)):
continue
ir, fwd_rd = _replace_helper(rd, c_repl, only_replace_attrs)
cur_fwd = _compose(fwd_rd, cur_fwd)
return ir, _compose(cur_fwd, fwd)


def _replace_reads(ir, fwd, c, sym, repl, only_replace_attrs=True):
cur_fwd = lambda x: x
c = fwd(c)
todos = []
for rd in match_pattern(c, f"{repr(sym)}[_]", use_sym_id=True):
# Need [_] to pattern match against window expressions
if c_repl := repl(rd):
todos.append((rd, c_repl))

cur_fwd = lambda x: x
for (rd, c_repl) in todos:
rd = cur_fwd(rd)
if not (c_repl := repl(rd)):
continue
ir, fwd_rd = _replace_helper(rd, c_repl, only_replace_attrs)
cur_fwd = _compose(fwd_rd, cur_fwd)
return ir, _compose(cur_fwd, fwd)


def _replace_writes(ir, fwd, c, sym, repl, only_replace_attrs=True):
cur_fwd = lambda x: x
def _replace_writes(
ir, fwd, c, sym, repl, only_replace_attrs=True, match_assign=True, match_reduce=True
):
c = fwd(c)

# TODO: Consider optimizing to just one call of [match_pattern]
matches = match_pattern(c, f"{repr(sym)} = _", use_sym_id=True) + match_pattern(
c, f"{repr(sym)} += _", use_sym_id=True
)
matches = []
if match_assign:
matches = match_pattern(c, f"{repr(sym)} = _", use_sym_id=True)
if match_reduce:
matches = matches + match_pattern(c, f"{repr(sym)} += _", use_sym_id=True)

todos = []
for block in matches:
assert len(block) == 1 # match_pattern on stmts return blocks
s = cur_fwd(block[0])
if not (c_repl := repl(s)):
continue
s = block[0]
if c_repl := repl(s):
todos.append((s, c_repl))

cur_fwd = lambda x: x
for (s, c_repl) in todos:
s = cur_fwd(s)
ir, fwd_s = _replace_helper(s, c_repl, only_replace_attrs)
cur_fwd = _compose(fwd_s, cur_fwd)

return ir, _compose(cur_fwd, fwd)


Expand Down Expand Up @@ -3818,7 +3836,6 @@ def mk_write(c):


def DoStageMem(block_cursor, buf_name, w_exprs, new_name, use_accum_zero=False):
proc = block_cursor.get_root()
new_name = Sym(new_name)

def get_typ_mem():
Expand Down Expand Up @@ -3869,7 +3886,9 @@ def off_w(w, off):
pt = LoopIR.BinOp("-", w.pt, off, T.index, w.srcinfo)
return LoopIR.Point(pt, w.srcinfo)

return [off_w(w_i, w_e[0]) for w_i, w_e in zip(w_idx, w_exprs)]
w_los = [w_e[0] if isinstance(w_e, tuple) else w_e for w_e in w_exprs]

return [off_w(w_i, w_e) for w_i, w_e in zip(w_idx, w_los)]

ir = block_cursor.get_root()
block = [s._node for s in block_cursor]
Expand Down Expand Up @@ -3938,8 +3957,68 @@ def guard_wrapper(body):

return ir, fwd

isR, isW = Check_BufferRW(ir, block, buf_name, n_dims)
if isR:
def idx_contained_by_window(idx, block_cursor):
"""
Returns True if idx always lies in staged window range.
Returns False if idx never lies in staged window range.
Otherwise, will raise a SchedulingError.
"""
p = idx.get_root()
return Check_Access_In_Window(p, idx, w_exprs, block_cursor)

actualR = actualW = False
WShadow = False
# Conservatively, shadowing logic only works for single element staging windows.
w_is_pt = all(not isinstance(w, tuple) for w in w_exprs)

def mk_read(c, block_cursor):
nonlocal actualR
rd = c._node

if isinstance(rd, LoopIR.Read):
if idx_contained_by_window(c, block_cursor):
_idx = rewrite_idx(rd.idx)
actualR = True
return {"name": new_name, "idx": _idx}
elif isinstance(rd, LoopIR.WindowExpr):
if any(
isinstance(w, LoopIR.Interval) and not isinstance(w_e, tuple)
for w, w_e in zip(rd.idx, w_exprs)
):
raise SchedulingError(
f"Existing WindowExpr {rd} has a widnowed dimension which is not windowed in the new staged window."
)

if idx_contained_by_window(c, block_cursor):
_idx = rewrite_win(rd.idx)
_typ = T.Window(new_typ, rd.type.as_tensor, new_name, _idx)
actualR = True
return {"name": new_name, "idx": _idx, "type": _typ}

def mk_write(c, block_cursor):
nonlocal actualR
nonlocal actualW
nonlocal WShadow
s = c._node
if isinstance(s, (LoopIR.Assign, LoopIR.Reduce)):
if idx_contained_by_window(c, block_cursor):
actualW = True
if isinstance(s, LoopIR.Reduce):
actualR = True
if not actualR and w_is_pt:
WShadow = True
return {"name": new_name, "idx": rewrite_idx(s.idx)}

for c in block_cursor:
ir, fwd = _replace_reads(
ir, fwd, c, buf_name, partial(mk_read, block_cursor=fwd(block_cursor))
)

ir, fwd = _replace_writes(
ir, fwd, c, buf_name, partial(mk_write, block_cursor=fwd(block_cursor))
)

if actualR and not WShadow:
load_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)]
load_widx = [LoopIR.Read(s, [], T.index, srcinfo) for s in load_iter]
if use_accum_zero:
Expand Down Expand Up @@ -3980,7 +4059,8 @@ def guard_wrapper(body):
ir, fwd = insert_safety_guards(
ir, fwd, get_inner_stmt(load_nest_c), load_rhs, buf_typ
)
if isW:

if actualW:
store_iter = [Sym(f"i{i}") for i, _ in enumerate(shape)]
store_ridx = [LoopIR.Read(s, [], T.index, srcinfo) for s in store_iter]
cp_store_ridx = store_ridx.copy()
Expand Down Expand Up @@ -4020,38 +4100,19 @@ def guard_wrapper(body):
ir, fwd, store_stmt_c, store_stmt_c._node, buf_typ
)

def mk_read(c):
rd = c._node
if isinstance(rd, LoopIR.Read):
return {
"name": new_name,
"idx": rewrite_idx(rd.idx),
"type": rd.type, # non-ideal, but easiest for now
}
elif isinstance(rd, LoopIR.WindowExpr):
w_idx = rewrite_win(rd.idx)
return {
"name": new_name,
"idx": w_idx,
"type": T.Window(new_typ, rd.type.as_tensor, new_name, w_idx),
}

def mk_write(c):
s = c._node
return {"name": new_name, "idx": rewrite_idx(s.idx)}

for c in block_cursor:
ir, fwd = _replace_reads(ir, fwd, c, buf_name, mk_read)
ir, fwd = _replace_writes(ir, fwd, c, buf_name, mk_write)

# new alloc, load_nest + new_body + store_nest
new_block_c = fwd(block_cursor[0]).as_block().expand(0, len(block_cursor) - 1)
if isR:
if actualR and not WShadow:
new_block_c = new_block_c.expand(1, 0)
if isW:
if actualW:
new_block_c = new_block_c.expand(0, 1)
alloc_c = new_block_c[0].prev()
Check_Bounds(ir, alloc_c._node, [c._node for c in new_block_c])
if not actualR and not actualW:
raise SchedulingError(
f"Cannot stage '{buf_name}' with the given window shape. Wrong window shape, or '{buf_name}' not accessed in the given scope?"
)

Check_Bounds(ir, new_alloc[0], [c._node for c in new_block_c])

return ir, fwd


Expand Down
9 changes: 9 additions & 0 deletions src/exo/internal_cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ class Block(Cursor):
def parent(self) -> Node:
return self._anchor

def depth(self) -> int:
return self._anchor.depth()

def before(self) -> Gap:
return self[0].before()

Expand Down Expand Up @@ -611,6 +614,9 @@ def parent(self) -> Node:
raise InvalidCursorError("cursor does not have a parent")
return Node(self._root, self._path[:-1])

def depth(self) -> int:
return len(self._path)

def before(self) -> Gap:
return Gap(self._root, self, GapType.Before)

Expand Down Expand Up @@ -772,6 +778,9 @@ def parent(self) -> Node:
return self._anchor
return self._anchor.parent()

def depth(self) -> int:
return self._anchor.depth()

def anchor(self) -> Node:
return self._anchor

Expand Down

0 comments on commit f892fd5

Please sign in to comment.