diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 0a93d54c2c..cafea3d754 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -389,7 +389,9 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto # Prepend incoming edges until reaching the source node curedge = edge + visited = set() while not isinstance(curedge.src, (nd.CodeNode, nd.AccessNode)): + visited.add(curedge) # Trace through scopes using OUT_# -> IN_# if isinstance(curedge.src, (nd.EntryNode, nd.ExitNode)): if curedge.src_conn is None: @@ -398,10 +400,14 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto next_edge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == "IN_" + curedge.src_conn[4:]) result.insert(0, next_edge) curedge = next_edge + if curedge in visited: + raise ValueError('Cycle encountered while reading memlet path') # Append outgoing edges until reaching the sink node curedge = edge + visited.clear() while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)): + visited.add(curedge) # Trace through scope entry using IN_# -> OUT_# if isinstance(curedge.dst, (nd.EntryNode, nd.ExitNode)): if curedge.dst_conn is None: @@ -411,6 +417,8 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto next_edge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == "OUT_" + curedge.dst_conn[3:]) result.append(next_edge) curedge = next_edge + if curedge in visited: + raise ValueError('Cycle encountered while reading memlet path') return result @@ -434,16 +442,23 @@ def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: # Find tree root curedge = edge + visited = set() if propagate_forward: while (isinstance(curedge.src, nd.EntryNode) and curedge.src_conn is not None): + visited.add(curedge) assert curedge.src_conn.startswith('OUT_') cname = curedge.src_conn[4:] curedge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == 'IN_%s' % cname) + if curedge in visited: + raise ValueError('Cycle encountered while reading memlet path') elif propagate_backward: while (isinstance(curedge.dst, nd.ExitNode) and curedge.dst_conn is not None): + visited.add(curedge) assert curedge.dst_conn.startswith('IN_') cname = curedge.dst_conn[3:] curedge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == 'OUT_%s' % cname) + if curedge in visited: + raise ValueError('Cycle encountered while reading memlet path') tree_root = mm.MemletTree(curedge, downwards=propagate_forward) # Collect children (recursively) diff --git a/tests/sdfg/cycles_test.py b/tests/sdfg/cycles_test.py index 5e94db2eb4..480392ab2d 100644 --- a/tests/sdfg/cycles_test.py +++ b/tests/sdfg/cycles_test.py @@ -1,3 +1,4 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import pytest import dace @@ -13,3 +14,21 @@ def test_cycles(): state.add_edge(access, None, access, None, dace.Memlet.simple("A", "0")) sdfg.validate() + + +def test_cycles_memlet_path(): + with pytest.raises(ValueError, match="Found cycles.*"): + sdfg = dace.SDFG("foo") + state = sdfg.add_state() + sdfg.add_array("bla", shape=(10, ), dtype=dace.float32) + mentry_3, _ = state.add_map("map_3", dict(i="0:9")) + mentry_3.add_in_connector("IN_0") + mentry_3.add_out_connector("OUT_0") + state.add_edge(mentry_3, "OUT_0", mentry_3, "IN_0", dace.Memlet(data="bla", subset='0:9')) + + sdfg.validate() + + +if __name__ == '__main__': + test_cycles() + test_cycles_memlet_path()