Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix infinite loops in memlet path when a scope cycle is added #1559

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto

# Prepend incoming edges until reaching the source node
curedge = edge
visited = set()
while not isinstance(curedge.src, (nd.CodeNode, nd.AccessNode)):
visited.add(curedge)
# Trace through scopes using OUT_# -> IN_#
if isinstance(curedge.src, (nd.EntryNode, nd.ExitNode)):
if curedge.src_conn is None:
Expand All @@ -398,10 +400,14 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto
next_edge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == "IN_" + curedge.src_conn[4:])
result.insert(0, next_edge)
curedge = next_edge
if curedge in visited:
raise ValueError('Cycle encountered while reading memlet path')

# Append outgoing edges until reaching the sink node
curedge = edge
visited.clear()
while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)):
visited.add(curedge)
# Trace through scope entry using IN_# -> OUT_#
if isinstance(curedge.dst, (nd.EntryNode, nd.ExitNode)):
if curedge.dst_conn is None:
Expand All @@ -411,6 +417,8 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto
next_edge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == "OUT_" + curedge.dst_conn[3:])
result.append(next_edge)
curedge = next_edge
if curedge in visited:
raise ValueError('Cycle encountered while reading memlet path')

return result

Expand All @@ -434,16 +442,23 @@ def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree:

# Find tree root
curedge = edge
visited = set()
if propagate_forward:
while (isinstance(curedge.src, nd.EntryNode) and curedge.src_conn is not None):
visited.add(curedge)
assert curedge.src_conn.startswith('OUT_')
cname = curedge.src_conn[4:]
curedge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == 'IN_%s' % cname)
if curedge in visited:
raise ValueError('Cycle encountered while reading memlet path')
elif propagate_backward:
while (isinstance(curedge.dst, nd.ExitNode) and curedge.dst_conn is not None):
visited.add(curedge)
assert curedge.dst_conn.startswith('IN_')
cname = curedge.dst_conn[3:]
curedge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == 'OUT_%s' % cname)
if curedge in visited:
raise ValueError('Cycle encountered while reading memlet path')
tree_root = mm.MemletTree(curedge, downwards=propagate_forward)

# Collect children (recursively)
Expand Down
19 changes: 19 additions & 0 deletions tests/sdfg/cycles_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import pytest

import dace
Expand All @@ -13,3 +14,21 @@ def test_cycles():

state.add_edge(access, None, access, None, dace.Memlet.simple("A", "0"))
sdfg.validate()


def test_cycles_memlet_path():
with pytest.raises(ValueError, match="Found cycles.*"):
sdfg = dace.SDFG("foo")
state = sdfg.add_state()
sdfg.add_array("bla", shape=(10, ), dtype=dace.float32)
mentry_3, _ = state.add_map("map_3", dict(i="0:9"))
mentry_3.add_in_connector("IN_0")
mentry_3.add_out_connector("OUT_0")
state.add_edge(mentry_3, "OUT_0", mentry_3, "IN_0", dace.Memlet(data="bla", subset='0:9'))

sdfg.validate()


if __name__ == '__main__':
test_cycles()
test_cycles_memlet_path()