Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schedule Trees (2/3): Tree-to-SDFG conversion #1466

Draft
wants to merge 20 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 20 additions & 4 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from typing import (Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union)
import sympy as sp
import dace
from dace import dtypes
from dace import dtypes, symbolic
from dace.sdfg.state import SDFGState
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.graph import Edge
Expand Down Expand Up @@ -234,7 +234,7 @@ def as_cpp(self, codegen, symbols) -> str:
successor = self.elements[i + 1].first_state
elif i == len(self.elements) - 1:
# If last edge leads to first state in next block
next_block = _find_next_block(self)
next_block = _find_next_block(self)
if next_block is not None:
successor = next_block.first_state

Expand Down Expand Up @@ -372,8 +372,8 @@ def as_cpp(self, codegen, symbols) -> str:
init = self.itervar
else:
init = f'{symbols[self.itervar]} {self.itervar}'
init += ' = ' + unparse_interstate_edge(self.init_edges[0].data.assignments[self.itervar],
sdfg, codegen=codegen)
init += ' = ' + unparse_interstate_edge(
self.init_edges[0].data.assignments[self.itervar], sdfg, codegen=codegen)

preinit = ''
if self.init_edges:
Expand Down Expand Up @@ -405,6 +405,22 @@ def first_state(self) -> SDFGState:
def children(self) -> List[ControlFlow]:
return [self.body]

def loop_range(self) -> Optional[Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType]]:
    """
    Returns the iteration range of this loop, if it can be determined.

    The outgoing edge of the loop guard whose condition equals this scope's condition is located,
    and ``find_for_loop`` is queried with the guard, that edge's destination, and the iteration variable.

    :return: A tuple of (start, end, stride) for well-formed loops. None if the condition edge
             cannot be found or if ``find_for_loop`` does not detect a loop.
    """
    from dace.transformation.interstate.loop_detection import find_for_loop

    sdfg = self.guard.parent

    # First outgoing edge of the guard that carries the loop condition (None if there is no such edge)
    condition_edge = next((edge for edge in sdfg.out_edges(self.guard) if edge.data.condition == self.condition),
                          None)
    if condition_edge is None:
        return None  # Condition edge not found

    loop_info = find_for_loop(sdfg, self.guard, condition_edge.dst, self.itervar)
    return None if loop_info is None else loop_info[1]


@dataclass
class WhileScope(ControlFlow):
Expand Down
2 changes: 1 addition & 1 deletion dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ def __init__(self,
if sym.name not in self.sdfg.symbols:
self.sdfg.add_symbol(sym.name, sym.dtype)
self.sdfg._temp_transients = tmp_idx
self.last_state = self.sdfg.add_state('init', is_start_state=True)
self.last_state = self.sdfg.add_state('init', is_start_block=True)

self.inputs: DependencyType = {}
self.outputs: DependencyType = {}
Expand Down
41 changes: 37 additions & 4 deletions dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
from collections import defaultdict
import copy
from typing import Dict, List, Set
Expand Down Expand Up @@ -330,6 +330,29 @@ def remove_name_collisions(sdfg: SDFG):
nsdfg.replace_dict(replacements)


def create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot):
    """
    Creates a single descriptor repository from an SDFG and all nested SDFGs. This includes
    data containers, symbols, constants, etc.

    The repository is stored on the tree root in shallow copies of the top-level SDFG's mappings.
    Merging the nested SDFGs' entries therefore does not modify the descriptor repository of ``sdfg``
    itself. The descriptor objects inside the mappings are shared with the SDFG, not copied.

    :param sdfg: The top-level SDFG to create the repository from.
    :param stree: The tree root in which to make the unified descriptor repository.
    """
    # Copy the mappings (keeping their types) instead of aliasing them: the updates below would
    # otherwise add the nested SDFGs' transients, symbols, and constants to ``sdfg`` as a side effect.
    stree.containers = copy.copy(sdfg.arrays)
    stree.symbols = copy.copy(sdfg.symbols)
    stree.constants = copy.copy(sdfg.constants_prop)

    # Since the SDFG is assumed to be de-aliased and contain unique names, we union the contents of
    # the nested SDFGs' descriptor repositories
    for nsdfg in sdfg.all_sdfgs_recursive():
        # Only transient containers are merged from each visited SDFG
        transients = {k: v for k, v in nsdfg.arrays.items() if v.transient}
        # Symbols and constants that are already known keep their existing definition
        symbols = {k: v for k, v in nsdfg.symbols.items() if k not in stree.symbols}
        constants = {k: v for k, v in nsdfg.constants_prop.items() if k not in stree.constants}
        stree.containers.update(transients)
        stree.symbols.update(symbols)
        stree.constants.update(constants)


def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_name: str,
viewed_name: str) -> tn.ViewNode:
"""
Expand Down Expand Up @@ -619,7 +642,7 @@ def _generate_views_in_scope(edges: List[gr.MultiConnectorEdge[Memlet]],
return result


def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeScope:
def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeRoot:
"""
Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of
the SDFG.
Expand Down Expand Up @@ -653,7 +676,6 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True)
dealias_sdfg(sdfg)
# Handle name collisions (in arrays, state labels, symbols)
remove_name_collisions(sdfg)

#############################

# Create initial tree from CFG
Expand Down Expand Up @@ -737,7 +759,18 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche
return result

# Recursive traversal of the control flow tree
result = tn.ScheduleTreeScope(children=totree(cfg))
children = totree(cfg)

# Create the scope object
if toplevel:
# Create the root with the elements of the descriptor repository
result = tn.ScheduleTreeRoot(name=sdfg.name,
children=children,
arg_names=sdfg.arg_names,
callback_mapping=sdfg.callback_mapping)
create_unified_descriptor_repository(sdfg, result)
else:
result = tn.ScheduleTreeScope(children=children)

# Clean up tree
stpasses.remove_unused_and_duplicate_labels(result)
Expand Down
175 changes: 175 additions & 0 deletions dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import copy
from collections import defaultdict
from dace.memlet import Memlet
from dace.sdfg import nodes, memlet_utils as mmu
from dace.sdfg.sdfg import SDFG, ControlFlowRegion
from dace.sdfg.state import SDFGState
from dace.sdfg.analysis.schedule_tree import treenodes as tn
from enum import Enum, auto
from typing import Dict, List, Set, Union


class StateBoundaryBehavior(Enum):
    """
    Behavior to apply upon encountering a state boundary when converting a schedule tree to an SDFG.
    """
    #: Creates multiple states with a state transition
    STATE_TRANSITION = auto()
    #: Happens-before empty memlet edges in the same state
    EMPTY_MEMLET = auto()


def from_schedule_tree(stree: tn.ScheduleTreeRoot,
                       state_boundary_behavior: StateBoundaryBehavior = StateBoundaryBehavior.STATE_TRANSITION) -> SDFG:
    """
    Converts a schedule tree into an SDFG.

    Currently, only the descriptor repository of the resulting SDFG is populated and state boundaries
    are inserted into the tree; the SDFG contents are not generated yet (see the TODOs below).

    :param stree: The schedule tree root to convert.
    :param state_boundary_behavior: Sets the behavior upon encountering a state boundary (e.g., write-after-write).
                                    See the ``StateBoundaryBehavior`` enumeration for more details.
    :return: An SDFG representing the schedule tree.
    """
    sdfg = SDFG(stree.name, propagate=False)

    # Set SDFG descriptor repository from deep copies of the tree root's fields, given as pairs of
    # (SDFG attribute, schedule tree root attribute)
    repository_fields = (
        ('arg_names', 'arg_names'),
        ('_arrays', 'containers'),
        ('constants_prop', 'constants'),
        ('symbols', 'symbols'),
    )
    for sdfg_attr, stree_attr in repository_fields:
        setattr(sdfg, sdfg_attr, copy.deepcopy(getattr(stree, stree_attr)))

    # TODO: Fill SDFG contents
    stree = insert_state_boundaries_to_tree(stree)  # after WAW, before label, etc.

    # TODO: create_state_boundary
    # TODO: When creating a state boundary, include all inter-state assignments that precede it.
    # TODO: create_loop_block
    # TODO: create_conditional_block
    # TODO: create_dataflow_scope

    return sdfg


def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot:
    """
    Inserts StateBoundaryNode objects into a schedule tree where more than one SDFG state would be necessary.
    Operates in-place on the given schedule tree.

    This happens when there is a:
      * write-after-write dependency;
      * write-after-read dependency that cannot be fulfilled via memlets;
      * control flow block (for/if); or
      * otherwise before a state label (which means a state transition could occur, e.g., in a gblock)

    :param stree: The schedule tree to operate on.
    :return: The result of visiting ``stree`` with the boundary inserter, after memory-dependency
             boundaries have been added to it.
    """

    class SimpleStateBoundaryInserter(tn.ScheduleNodeTransformer):
        """
        Simple boundary node inserter that places a boundary in front of control flow blocks and state labels.
        """

        def visit_scope(self, scope: tn.ScheduleTreeScope):
            # Scopes that are not control flow are visited without adding a boundary
            if not isinstance(scope, tn.ControlFlowScope):
                return self.generic_visit(scope)
            boundary = tn.StateBoundaryNode(True)
            return [boundary, self.generic_visit(scope)]

        def visit_StateLabel(self, node: tn.StateLabel):
            boundary = tn.StateBoundaryNode(True)
            return [boundary, self.generic_visit(node)]

    # Pass 1: boundaries around labels and control flow
    transformed = SimpleStateBoundaryInserter().visit(stree)

    # Pass 2: boundaries after unmet memory dependencies or potential data races
    _insert_memory_dependency_state_boundaries(transformed)

    return transformed


def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope):
    """
    Helper function that inserts boundaries after unmet memory dependencies.

    Walks the direct children of ``scope`` in order while tracking which nodes read and wrote which
    memlets since the last state boundary. A ``StateBoundaryNode`` is inserted in front of every child
    that triggers one of the hazard checks below. Control flow scopes and dataflow scopes among the
    children are processed recursively. Operates in-place on ``scope.children``.

    :param scope: The schedule tree scope whose children are analyzed.
    """
    reads: mmu.MemletDict[List[tn.ScheduleTreeNode]] = mmu.MemletDict()
    writes: mmu.MemletDict[List[tn.ScheduleTreeNode]] = mmu.MemletDict()
    parents: Dict[int, Set[int]] = defaultdict(set)  # Node id -> ids of the nodes it (transitively) depends on
    boundaries_to_insert: List[int] = []

    for i, n in enumerate(scope.children):
        if isinstance(n, (tn.StateBoundaryNode, tn.ControlFlowScope)):  # Clear state
            reads.clear()
            writes.clear()
            parents.clear()
            if isinstance(n, tn.ControlFlowScope):  # Insert memory boundaries recursively
                _insert_memory_dependency_state_boundaries(n)
            continue

        # If dataflow scope, insert state boundaries recursively and as a node
        if isinstance(n, tn.DataflowScope):
            _insert_memory_dependency_state_boundaries(n)

        inputs = n.input_memlets()
        outputs = n.output_memlets()

        # Register reads
        for inp in inputs:
            if inp not in reads:
                reads[inp] = [n]
            else:
                reads[inp].append(n)

            # Transitively add parents
            if inp in writes:
                for parent in writes[inp]:
                    parents[id(n)].add(id(parent))
                    parents[id(n)].update(parents[id(parent)])

        # NOTE(review): ``parents`` is keyed by node id, so ``id(r) not in parents`` below tests whether
        # the reader ``r`` has an entry at all, not whether ``r`` is among the parents of ``n`` -- confirm
        # that this is the intended check.
        needs_boundary = (
            # Inter-state assignment nodes with reads necessitate a state transition if they were written to.
            (isinstance(n, tn.AssignNode) and any(inp in writes for inp in inputs))
            # Write after write or potential write/write data race
            or any(o in writes and (o not in reads or any(id(r) not in parents for r in reads[o])) for o in outputs)
            # Potential read/write data race: if any read is not in the parents of this node, it might
            # be performed in parallel
            or any(o in reads and any(id(r) not in parents for r in reads[o]) for o in outputs))

        if needs_boundary:
            boundaries_to_insert.append(i)
            reads.clear()
            writes.clear()
            parents.clear()

            # The node is the first one after the new boundary. Track its reads again (its writes are
            # registered below), so that subsequent nodes are checked against it. Without this, a later
            # node writing the same memlet would not be separated from this one.
            for inp in inputs:
                if inp not in reads:
                    reads[inp] = [n]
                else:
                    reads[inp].append(n)

        # Register writes after all hazards have been tested for
        for out in outputs:
            if out not in writes:
                writes[out] = [n]
            else:
                writes[out].append(n)

    # Insert memory dependency state boundaries in reverse in order to keep indices intact
    for i in reversed(boundaries_to_insert):
        scope.children.insert(i, tn.StateBoundaryNode())


#############################################################################
# SDFG content creation functions


def create_state_boundary(bnode: tn.StateBoundaryNode, sdfg_region: ControlFlowRegion, state: SDFGState,
                          behavior: StateBoundaryBehavior) -> SDFGState:
    """
    Creates a boundary between two states.

    Not implemented yet: at the moment, this function only checks that the boundary node has a parent
    scope and then implicitly returns None.

    :param bnode: The state boundary node to generate.
    :param sdfg_region: The control flow block in which to generate the boundary (e.g., SDFG).
    :param state: The last state prior to this boundary.
    :param behavior: The state boundary behavior with which to create the boundary.
    :return: The newly created state.
    """
    # TODO: Some boundaries (control flow, state labels with goto) could not be fulfilled with every
    #       behavior. Fall back to state transition in that case.
    parent_scope: tn.ControlFlowScope = bnode.parent
    assert parent_scope is not None