From a97de7769a6fd614f7e5665567285b82bba2b1f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lukas=20Tr=C3=BCmper?= Date: Fri, 19 Aug 2022 08:23:12 +0200 Subject: [PATCH] Additional local storage tests --- tests/transformations/local_storage_test.py | 218 ++++++++++++++++++++ 1 file changed, 218 insertions(+) diff --git a/tests/transformations/local_storage_test.py b/tests/transformations/local_storage_test.py index 1965897263..c39fd9b807 100644 --- a/tests/transformations/local_storage_test.py +++ b/tests/transformations/local_storage_test.py @@ -3,10 +3,223 @@ import dace import numpy as np from dace.transformation.dataflow import MapTiling, OutLocalStorage +from dace.transformation.dataflow.local_storage import InLocalStorage + +import dace.transformation.helpers as xfh N = dace.symbol('N') +@dace.program +def copy_sdfg(A: dace.float32[N, N], B: dace.float32[N, N]): + for i, j in dace.map[0:N, 0:N]: + with dace.tasklet: + a << A[i, j] + b >> B[i, j] + b = a + + +def find_map_entries(sdfg): + outer_map_entry = None + inner_map_entry = None + for node in sdfg.start_state.nodes(): + if not isinstance(node, dace.nodes.MapEntry): + continue + + if xfh.get_parent_map(sdfg.start_state, node) is None: + assert outer_map_entry is None + outer_map_entry = node + else: + assert inner_map_entry is None + inner_map_entry = node + assert not outer_map_entry is None + assert not inner_map_entry is None + + return outer_map_entry, inner_map_entry + + +def test_in_local_storage_explicit(): + sdfg = copy_sdfg.to_sdfg() + sdfg.simplify() + + sdfg.apply_transformations([MapTiling], options=[{"tile_sizes": [8]}]) + + outer_map_entry, inner_map_entry = find_map_entries(sdfg) + + InLocalStorage.apply_to(sdfg=sdfg, + node_a=outer_map_entry, + node_b=inner_map_entry, + options={ + "array": "A", + "create_array": True, + "prefix": "loc_" + }, + save=True) + + # Finding relevant node + local_storage_node = None + for node in sdfg.start_state.nodes(): + if not isinstance(node, dace.nodes.AccessNode): + continue + + if node.data == "loc_A": + assert local_storage_node is None + local_storage_node = node + break + + assert not local_storage_node is None + + # Check transient array created + trans_array = local_storage_node.data + assert trans_array in sdfg.arrays + + # Check properties + desc = sdfg.arrays[local_storage_node.data] + assert desc.shape == (8, 8) + assert desc.transient == True + + # Check array was set correctly + serialized = sdfg.transformation_hist[0].to_json() + assert serialized["array"] == "A" + + +def test_in_local_storage_implicit(): + sdfg = copy_sdfg.to_sdfg() + sdfg.simplify() + + sdfg.apply_transformations([MapTiling], options=[{"tile_sizes": [8]}]) + + outer_map_entry, inner_map_entry = find_map_entries(sdfg) + + InLocalStorage.apply_to(sdfg=sdfg, + node_a=outer_map_entry, + node_b=inner_map_entry, + options={ + "create_array": True, + "prefix": "loc_" + }, + save=True) + + # Finding relevant node + local_storage_node = None + for node in sdfg.start_state.nodes(): + if not isinstance(node, dace.nodes.AccessNode): + continue + + if node.data == "loc_A": + assert local_storage_node is None + local_storage_node = node + break + + assert not local_storage_node is None + + # Check transient array created + trans_array = local_storage_node.data + assert trans_array in sdfg.arrays + + # Check properties + desc = sdfg.arrays[local_storage_node.data] + assert desc.shape == (8, 8) + assert desc.transient == True + + # Check array was set correctly + serialized = sdfg.transformation_hist[0].to_json() + assert serialized["array"] == None + + +def test_out_local_storage_explicit(): + sdfg = copy_sdfg.to_sdfg() + sdfg.simplify() + + sdfg.apply_transformations([MapTiling], options=[{"tile_sizes": [8]}]) + + outer_map_entry, inner_map_entry = find_map_entries(sdfg) + outer_map_exit = sdfg.start_state.exit_node(outer_map_entry) + inner_map_exit = sdfg.start_state.exit_node(inner_map_entry) + + OutLocalStorage.apply_to(sdfg=sdfg, + node_a=inner_map_exit, + node_b=outer_map_exit, + options={ + "array": "B", + "create_array": True, + "prefix": "loc_" + }, + save=True) + + # Finding relevant node + local_storage_node = None + for node in sdfg.start_state.nodes(): + if not isinstance(node, dace.nodes.AccessNode): + continue + + if node.data == "loc_B": + assert local_storage_node is None + local_storage_node = node + break + + assert not local_storage_node is None + + # Check transient array created + trans_array = local_storage_node.data + assert trans_array in sdfg.arrays + + # Check properties + desc = sdfg.arrays[local_storage_node.data] + assert desc.shape == (8, 8) + assert desc.transient == True + + # Check array was set correctly + serialized = sdfg.transformation_hist[0].to_json() + assert serialized["array"] == "B" + + +def test_out_local_storage_implicit(): + sdfg = copy_sdfg.to_sdfg() + sdfg.simplify() + + sdfg.apply_transformations([MapTiling], options=[{"tile_sizes": [8]}]) + + outer_map_entry, inner_map_entry = find_map_entries(sdfg) + outer_map_exit = sdfg.start_state.exit_node(outer_map_entry) + inner_map_exit = sdfg.start_state.exit_node(inner_map_entry) + + OutLocalStorage.apply_to(sdfg=sdfg, + node_a=inner_map_exit, + node_b=outer_map_exit, + options={ + "create_array": True, + "prefix": "loc_" + }, + save=True) + + # Finding relevant node + local_storage_node = None + for node in sdfg.start_state.nodes(): + if not isinstance(node, dace.nodes.AccessNode): + continue + + if node.data == "loc_B": + assert local_storage_node is None + local_storage_node = node + break + + assert not local_storage_node is None + + # Check transient array created + trans_array = local_storage_node.data + assert trans_array in sdfg.arrays + + # Check properties + desc = sdfg.arrays[local_storage_node.data] + assert desc.shape == (8, 8) + assert desc.transient == True + + # Check array was set correctly + serialized = sdfg.transformation_hist[0].to_json() + assert serialized["array"] == None + + @dace.program def arange(): out = np.ndarray([N], np.int32) @@ -18,6 +231,7 @@ def arange(): class LocalStorageTests(unittest.TestCase): + def test_even(self): sdfg = arange.to_sdfg() sdfg.apply_transformations([MapTiling, OutLocalStorage], options=[{'tile_sizes': [8]}, {}]) @@ -37,3 +251,7 @@ def test_uneven(self): if __name__ == '__main__': unittest.main() + test_in_local_storage_explicit() + test_in_local_storage_implicit() + test_out_local_storage_explicit() + test_out_local_storage_implicit()