From 395308bff19af5866906187273591c8d6e39e7dc Mon Sep 17 00:00:00 2001 From: Samuel Martin Date: Mon, 21 Aug 2023 15:33:47 +0200 Subject: [PATCH 1/4] Added testcase exposing bug --- tests/trivial_map_elimination_test.py | 61 +++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/trivial_map_elimination_test.py b/tests/trivial_map_elimination_test.py index 44b1f77652..bdb73bc00a 100644 --- a/tests/trivial_map_elimination_test.py +++ b/tests/trivial_map_elimination_test.py @@ -25,6 +25,33 @@ def trivial_map_sdfg(): return sdfg +def trivial_map_init_sdfg(): + sdfg = dace.SDFG('trivial_map_range_expanded') + sdfg.add_array('B', [5, 1], dace.float64) + state = sdfg.add_state() + + # Nodes + map_entry_outer, map_exit_outer = state.add_map('map_outer', dict(j='0:5')) + map_entry_inner, map_exit_inner = state.add_map('map_inner', dict(i='0:1')) + + tasklet = state.add_tasklet('tasklet', {}, {'b'}, 'b = 1') + write = state.add_write('B') + + # Edges + state.add_memlet_path(map_entry_outer, map_entry_inner, memlet=dace.Memlet()) + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet()) + + state.add_memlet_path(tasklet, map_exit_inner, memlet=dace.Memlet.simple('B', 'j, i'), src_conn='b', + dst_conn='IN_B') + state.add_memlet_path(map_exit_inner, map_exit_outer, memlet=dace.Memlet.simple('B', 'j, 0'), src_conn='OUT_B', + dst_conn='IN_B') + state.add_memlet_path(map_exit_outer, write, memlet=dace.Memlet.simple('B', '0:5, 0'), + src_conn='OUT_B') + + sdfg.validate() + return sdfg + + class TrivialMapEliminationTest(unittest.TestCase): def test_can_be_applied(self): graph = trivial_map_sdfg() @@ -56,5 +83,39 @@ def test_raplaces_map_params_in_scope(self): self.assertEqual(out_memlet.data.subset, dace.subsets.Range([(0, 0, 1)])) +class TrivialMapInitEliminationTest(unittest.TestCase): + def test_can_be_applied(self): + graph = trivial_map_init_sdfg() + graph.save('trivial_map_expanded_sdfg.sdfg') + + count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False) + graph.save('trivial_map_expanded_sdfg_applied.sdfg') + graph.validate() + + self.assertGreater(count, 0) + + # def test_removes_map(self): + # graph = trivial_map_sdfg() + + # graph.apply_transformations(TrivialMapElimination) + + # state = graph.nodes()[0] + # map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + # self.assertEqual(len(map_entries), 0) + + # def test_raplaces_map_params_in_scope(self): + # # Tests if the 'i' in the range of the memlet to B gets replaced + # # with the value 'i' obtains in the map, namely '0'. + + # graph = trivial_map_sdfg() + + # graph.apply_transformations(TrivialMapElimination) + + # state = graph.nodes()[0] + # B = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.AccessNode) and n.data == 'B'][0] + # out_memlet = state.in_edges(B)[0] + # self.assertEqual(out_memlet.data.subset, dace.subsets.Range([(0, 0, 1)])) + + if __name__ == '__main__': unittest.main() From cc670417985e19ba67259f44955805dcb769862d Mon Sep 17 00:00:00 2001 From: Samuel Martin Date: Mon, 21 Aug 2023 16:10:34 +0200 Subject: [PATCH 2/4] Added a fix, not sure how good he is --- .../dataflow/trivial_map_elimination.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/dace/transformation/dataflow/trivial_map_elimination.py b/dace/transformation/dataflow/trivial_map_elimination.py index 327d5d8c9a..fe9529c197 100644 --- a/dace/transformation/dataflow/trivial_map_elimination.py +++ b/dace/transformation/dataflow/trivial_map_elimination.py @@ -5,6 +5,7 @@ from dace.sdfg import utils as sdutil from dace.transformation import transformation from dace.properties import make_properties +from dace.memlet import Memlet @make_properties @@ -48,12 +49,16 @@ def apply(self, graph, sdfg): if len(remaining_ranges) == 0: # Redirect map entry's out edges + write_only_map = False for edge in graph.out_edges(map_entry): path = graph.memlet_path(edge) index = path.index(edge) - # Add an edge directly from the previous source connector to the destination - graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) + if len(path) > 1: + # Add an edge directly from the previous source connector to the destination + graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) + else: + write_only_map = True # Redirect map exit's in edges. for edge in graph.in_edges(map_exit): @@ -63,6 +68,11 @@ def apply(self, graph, sdfg): # Add an edge directly from the source to the next destination connector if len(path) > index + 1: graph.add_edge(edge.src, edge.src_conn, path[index + 1].dst, path[index + 1].dst_conn, edge.data) + if write_only_map: + outer_exit = path[index+1].dst + outer_entry = graph.entry_node(outer_exit) + if outer_entry is not None: + graph.add_edge(outer_entry, None, edge.src, None, Memlet()) # Remove map graph.remove_nodes_from([map_entry, map_exit]) From 9b967af9b0a5ff8b65f7c4366c55adb984c73476 Mon Sep 17 00:00:00 2001 From: Samuel Martin Date: Mon, 21 Aug 2023 17:42:23 +0200 Subject: [PATCH 3/4] Adapt other testcases --- tests/trivial_map_elimination_test.py | 35 +++++++++++++-------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/tests/trivial_map_elimination_test.py b/tests/trivial_map_elimination_test.py index bdb73bc00a..2331f76977 100644 --- a/tests/trivial_map_elimination_test.py +++ b/tests/trivial_map_elimination_test.py @@ -86,35 +86,34 @@ def test_raplaces_map_params_in_scope(self): class TrivialMapInitEliminationTest(unittest.TestCase): def test_can_be_applied(self): graph = trivial_map_init_sdfg() - graph.save('trivial_map_expanded_sdfg.sdfg') count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False) - graph.save('trivial_map_expanded_sdfg_applied.sdfg') graph.validate() self.assertGreater(count, 0) - # def test_removes_map(self): - # graph = trivial_map_sdfg() - - # graph.apply_transformations(TrivialMapElimination) + def test_removes_map(self): + graph = trivial_map_init_sdfg() - # state = graph.nodes()[0] - # map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] - # self.assertEqual(len(map_entries), 0) + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 2) - # def test_raplaces_map_params_in_scope(self): - # # Tests if the 'i' in the range of the memlet to B gets replaced - # # with the value 'i' obtains in the map, namely '0'. + graph.apply_transformations(TrivialMapElimination) - # graph = trivial_map_sdfg() + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) - # graph.apply_transformations(TrivialMapElimination) + def test_reconnects_edges(self): + graph = trivial_map_init_sdfg() - # state = graph.nodes()[0] - # B = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.AccessNode) and n.data == 'B'][0] - # out_memlet = state.in_edges(B)[0] - # self.assertEqual(out_memlet.data.subset, dace.subsets.Range([(0, 0, 1)])) + graph.apply_transformations(TrivialMapElimination) + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + # Check that there is an outgoing edge from the map entry + self.assertEqual(len(state.out_edges(map_entries[0])), 1) if __name__ == '__main__': From cdba4e45ce95fcaf2c515f00e191096432a8cd31 Mon Sep 17 00:00:00 2001 From: Samuel Martin Date: Tue, 10 Oct 2023 15:33:46 +0200 Subject: [PATCH 4/4] Review changes --- .../dataflow/trivial_map_elimination.py | 7 +- tests/trivial_map_elimination_test.py | 72 +++++++++++++++++++ 2 files changed, 75 insertions(+), 4 deletions(-) diff --git a/dace/transformation/dataflow/trivial_map_elimination.py b/dace/transformation/dataflow/trivial_map_elimination.py index fe9529c197..9387cfce23 100644 --- a/dace/transformation/dataflow/trivial_map_elimination.py +++ b/dace/transformation/dataflow/trivial_map_elimination.py @@ -49,16 +49,15 @@ def apply(self, graph, sdfg): if len(remaining_ranges) == 0: # Redirect map entry's out edges - write_only_map = False + write_only_map = True for edge in graph.out_edges(map_entry): path = graph.memlet_path(edge) index = path.index(edge) - if len(path) > 1: + if not edge.data.is_empty(): # Add an edge directly from the previous source connector to the destination graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) - else: - write_only_map = True + write_only_map = False # Redirect map exit's in edges. for edge in graph.in_edges(map_exit): diff --git a/tests/trivial_map_elimination_test.py b/tests/trivial_map_elimination_test.py index 2331f76977..9600dad640 100644 --- a/tests/trivial_map_elimination_test.py +++ b/tests/trivial_map_elimination_test.py @@ -52,7 +52,42 @@ def trivial_map_init_sdfg(): return sdfg +def trivial_map_pseudo_init_sdfg(): + sdfg = dace.SDFG('trivial_map_range_expanded') + sdfg.add_array('A', [5, 1], dace.float64) + sdfg.add_array('B', [5, 1], dace.float64) + state = sdfg.add_state() + + # Nodes + map_entry_outer, map_exit_outer = state.add_map('map_outer', dict(j='0:5')) + map_entry_inner, map_exit_inner = state.add_map('map_inner', dict(i='0:1')) + + read = state.add_read('A') + tasklet = state.add_tasklet('tasklet', {'a'}, {'b'}, 'b = a') + write = state.add_write('B') + + # Edges + state.add_memlet_path(map_entry_outer, map_entry_inner, memlet=dace.Memlet()) + state.add_memlet_path(read, map_entry_outer, map_entry_inner, memlet=dace.Memlet.simple('A', '0:5, 0'), + dst_conn='IN_A') + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet()) + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet.simple('A', 'j, 0'), src_conn='OUT_A', dst_conn='a') + + state.add_memlet_path(tasklet, map_exit_inner, memlet=dace.Memlet.simple('B', 'j, i'), src_conn='b', + dst_conn='IN_B') + state.add_memlet_path(map_exit_inner, map_exit_outer, memlet=dace.Memlet.simple('B', 'j, 0'), src_conn='OUT_B', + dst_conn='IN_B') + state.add_memlet_path(map_exit_outer, write, memlet=dace.Memlet.simple('B', '0:5, 0'), + src_conn='OUT_B') + + sdfg.validate() + return sdfg + + class TrivialMapEliminationTest(unittest.TestCase): + """ + Tests the case where the map has an empty input edge + """ def test_can_be_applied(self): graph = trivial_map_sdfg() @@ -116,5 +151,42 @@ def test_reconnects_edges(self): self.assertEqual(len(state.out_edges(map_entries[0])), 1) +class TrivialMapPseudoInitEliminationTest(unittest.TestCase): + """ + Test cases where the map has an empty input and a non empty input + """ + def test_can_be_applied(self): + graph = trivial_map_pseudo_init_sdfg() + + count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False) + graph.validate() + graph.view() + + self.assertGreater(count, 0) + + def test_removes_map(self): + graph = trivial_map_pseudo_init_sdfg() + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 2) + + graph.apply_transformations(TrivialMapElimination) + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + + def test_reconnects_edges(self): + graph = trivial_map_pseudo_init_sdfg() + + graph.apply_transformations(TrivialMapElimination) + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + # Check that there is an outgoing edge from the map entry + self.assertEqual(len(state.out_edges(map_entries[0])), 1) + + if __name__ == '__main__': unittest.main()