Skip to content

Commit

Permalink
Trivial map elimination init (#1353)
Browse files Browse the repository at this point in the history
The EliminateTrivialMap transformation fails to properly connect the
content of the map if it is a write only map and the map to be removed
is nested inside another map. This is as the current solution uses
memlet_path which does not work on the empty memlets going away from the
map entry.

This adds a testcase and solution which seems to work for my specific
needs but are not polished and might not work in all cases.

---------

Co-authored-by: Samuel Martin <[email protected]>
Co-authored-by: alexnick83 <[email protected]>
Co-authored-by: acalotoiu <[email protected]>
  • Loading branch information
4 people authored Nov 8, 2023
1 parent f3781b1 commit 3e8c74c
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 2 deletions.
13 changes: 11 additions & 2 deletions dace/transformation/dataflow/trivial_map_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dace.sdfg import utils as sdutil
from dace.transformation import transformation
from dace.properties import make_properties
from dace.memlet import Memlet


@make_properties
Expand Down Expand Up @@ -48,12 +49,15 @@ def apply(self, graph, sdfg):

if len(remaining_ranges) == 0:
# Redirect map entry's out edges
write_only_map = True
for edge in graph.out_edges(map_entry):
path = graph.memlet_path(edge)
index = path.index(edge)

# Add an edge directly from the previous source connector to the destination
graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data)
if not edge.data.is_empty():
# Add an edge directly from the previous source connector to the destination
graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data)
write_only_map = False

# Redirect map exit's in edges.
for edge in graph.in_edges(map_exit):
Expand All @@ -63,6 +67,11 @@ def apply(self, graph, sdfg):
# Add an edge directly from the source to the next destination connector
if len(path) > index + 1:
graph.add_edge(edge.src, edge.src_conn, path[index + 1].dst, path[index + 1].dst_conn, edge.data)
if write_only_map:
outer_exit = path[index+1].dst
outer_entry = graph.entry_node(outer_exit)
if outer_entry is not None:
graph.add_edge(outer_entry, None, edge.src, None, Memlet())

# Remove map
graph.remove_nodes_from([map_entry, map_exit])
132 changes: 132 additions & 0 deletions tests/trivial_map_elimination_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,69 @@ def trivial_map_sdfg():
return sdfg


def trivial_map_init_sdfg():
sdfg = dace.SDFG('trivial_map_range_expanded')
sdfg.add_array('B', [5, 1], dace.float64)
state = sdfg.add_state()

# Nodes
map_entry_outer, map_exit_outer = state.add_map('map_outer', dict(j='0:5'))
map_entry_inner, map_exit_inner = state.add_map('map_inner', dict(i='0:1'))

tasklet = state.add_tasklet('tasklet', {}, {'b'}, 'b = 1')
write = state.add_write('B')

# Edges
state.add_memlet_path(map_entry_outer, map_entry_inner, memlet=dace.Memlet())
state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet())

state.add_memlet_path(tasklet, map_exit_inner, memlet=dace.Memlet.simple('B', 'j, i'), src_conn='b',
dst_conn='IN_B')
state.add_memlet_path(map_exit_inner, map_exit_outer, memlet=dace.Memlet.simple('B', 'j, 0'), src_conn='OUT_B',
dst_conn='IN_B')
state.add_memlet_path(map_exit_outer, write, memlet=dace.Memlet.simple('B', '0:5, 0'),
src_conn='OUT_B')

sdfg.validate()
return sdfg


def trivial_map_pseudo_init_sdfg():
sdfg = dace.SDFG('trivial_map_range_expanded')
sdfg.add_array('A', [5, 1], dace.float64)
sdfg.add_array('B', [5, 1], dace.float64)
state = sdfg.add_state()

# Nodes
map_entry_outer, map_exit_outer = state.add_map('map_outer', dict(j='0:5'))
map_entry_inner, map_exit_inner = state.add_map('map_inner', dict(i='0:1'))

read = state.add_read('A')
tasklet = state.add_tasklet('tasklet', {'a'}, {'b'}, 'b = a')
write = state.add_write('B')

# Edges
state.add_memlet_path(map_entry_outer, map_entry_inner, memlet=dace.Memlet())
state.add_memlet_path(read, map_entry_outer, map_entry_inner, memlet=dace.Memlet.simple('A', '0:5, 0'),
dst_conn='IN_A')
state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet())
state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet.simple('A', 'j, 0'), src_conn='OUT_A', dst_conn='a')

state.add_memlet_path(tasklet, map_exit_inner, memlet=dace.Memlet.simple('B', 'j, i'), src_conn='b',
dst_conn='IN_B')
state.add_memlet_path(map_exit_inner, map_exit_outer, memlet=dace.Memlet.simple('B', 'j, 0'), src_conn='OUT_B',
dst_conn='IN_B')
state.add_memlet_path(map_exit_outer, write, memlet=dace.Memlet.simple('B', '0:5, 0'),
src_conn='OUT_B')

sdfg.validate()
return sdfg


class TrivialMapEliminationTest(unittest.TestCase):
"""
Tests the case where the map has an empty input edge
"""
def test_can_be_applied(self):
graph = trivial_map_sdfg()

Expand Down Expand Up @@ -56,5 +118,75 @@ def test_raplaces_map_params_in_scope(self):
self.assertEqual(out_memlet.data.subset, dace.subsets.Range([(0, 0, 1)]))


class TrivialMapInitEliminationTest(unittest.TestCase):
def test_can_be_applied(self):
graph = trivial_map_init_sdfg()

count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False)
graph.validate()

self.assertGreater(count, 0)

def test_removes_map(self):
graph = trivial_map_init_sdfg()

state = graph.nodes()[0]
map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)]
self.assertEqual(len(map_entries), 2)

graph.apply_transformations(TrivialMapElimination)

state = graph.nodes()[0]
map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)]
self.assertEqual(len(map_entries), 1)

def test_reconnects_edges(self):
graph = trivial_map_init_sdfg()

graph.apply_transformations(TrivialMapElimination)
state = graph.nodes()[0]
map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)]
self.assertEqual(len(map_entries), 1)
# Check that there is an outgoing edge from the map entry
self.assertEqual(len(state.out_edges(map_entries[0])), 1)


class TrivialMapPseudoInitEliminationTest(unittest.TestCase):
"""
Test cases where the map has an empty input and a non empty input
"""
def test_can_be_applied(self):
graph = trivial_map_pseudo_init_sdfg()

count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False)
graph.validate()
graph.view()

self.assertGreater(count, 0)

def test_removes_map(self):
graph = trivial_map_pseudo_init_sdfg()

state = graph.nodes()[0]
map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)]
self.assertEqual(len(map_entries), 2)

graph.apply_transformations(TrivialMapElimination)

state = graph.nodes()[0]
map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)]
self.assertEqual(len(map_entries), 1)

def test_reconnects_edges(self):
graph = trivial_map_pseudo_init_sdfg()

graph.apply_transformations(TrivialMapElimination)
state = graph.nodes()[0]
map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)]
self.assertEqual(len(map_entries), 1)
# Check that there is an outgoing edge from the map entry
self.assertEqual(len(state.out_edges(map_entries[0])), 1)


if __name__ == '__main__':
unittest.main()

0 comments on commit 3e8c74c

Please sign in to comment.