Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[venom]: common subexpression elimination pass #4241

Open
wants to merge 43 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
33ef4db
start cse
HodanPlodky Sep 4, 2024
52a3470
cse start
HodanPlodky Sep 4, 2024
c2d60ed
only one inst handling
HodanPlodky Sep 9, 2024
3087650
cleanup + fix
HodanPlodky Sep 10, 2024
bad3f69
effects start
HodanPlodky Sep 12, 2024
9798a39
correct compare for joins
HodanPlodky Sep 13, 2024
ea92aa0
fix for calls and instruction without output (i.e. stores)
HodanPlodky Sep 13, 2024
c754574
small cleanup
HodanPlodky Sep 16, 2024
e6a8551
small cleanup
HodanPlodky Sep 16, 2024
6136f83
basic test created and fix small fix
HodanPlodky Sep 16, 2024
2d7e1a4
clean up of the debug
HodanPlodky Sep 16, 2024
4b62958
Update vyper/venom/analysis/available_expression.py - better hash method
HodanPlodky Sep 16, 2024
fbf2964
Update vyper/venom/analysis/available_expression.py - Incorrect gramm…
HodanPlodky Sep 16, 2024
86977f5
fix for some error in cancun version of experimental codegen test
HodanPlodky Sep 17, 2024
49e6949
fix for log opcode and lint
HodanPlodky Sep 17, 2024
de3c602
handling the different size of the expressions
HodanPlodky Sep 18, 2024
ae98b46
Merge branch 'vyperlang:master' into feat/cse
HodanPlodky Sep 22, 2024
b286eee
fixes in handling effects + bigger expression
HodanPlodky Sep 23, 2024
14b8963
fixes in handling effects
HodanPlodky Sep 24, 2024
e957b04
fixes (to better version) and style changes
HodanPlodky Sep 25, 2024
63f27fe
perf fixes + better order
HodanPlodky Sep 26, 2024
9ce9dd8
quick fix just for dft but since it is being rework it is just hot fix
HodanPlodky Sep 26, 2024
a3585ed
Merge branch 'master' into feat/cse
HodanPlodky Sep 30, 2024
775f0a9
created possibility to change size of the expression at pass level
HodanPlodky Sep 30, 2024
ea944e6
created possibility to change size of the expression at pass level (l…
HodanPlodky Sep 30, 2024
ca23cde
Merge branch 'master' into feat/cse
HodanPlodky Oct 5, 2024
018b6a9
fixes after merge
HodanPlodky Oct 5, 2024
94a9a6e
better results this order
HodanPlodky Oct 6, 2024
7d5b6aa
used new effects data structure to reduce duplication of code
HodanPlodky Oct 6, 2024
2de7731
add test for commutative instructions
harkal Oct 10, 2024
9efaea1
simplification of `same()` and replacement by `__eq__`
harkal Oct 10, 2024
22e0627
Merge branch 'master' into feat/cse
harkal Oct 10, 2024
4339697
different branches test
HodanPlodky Oct 10, 2024
8f76730
cleanup tests
harkal Oct 10, 2024
30ed887
add `same()` to `IROperand`
harkal Oct 10, 2024
2d33dc7
fix equality of expressions
harkal Oct 10, 2024
866ee62
Merge branch 'master' into feat/cse
harkal Oct 10, 2024
145ee31
comments and fix of the hash function so it should hold that if x and…
HodanPlodky Oct 10, 2024
f3c66d0
Merge branch 'master' into feat/cse
harkal Oct 10, 2024
6f9800f
bit more comments
HodanPlodky Oct 10, 2024
c6a825a
lint
HodanPlodky Oct 10, 2024
674577a
comments and removed some unnecessery code
HodanPlodky Oct 11, 2024
479271e
lint
HodanPlodky Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions tests/unit/compiler/venom/test_common_subexpression_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pytest

from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.context import IRContext
from vyper.venom.passes.common_subexpression_elimination import CSE
from vyper.venom.passes.store_expansion import StoreExpansionPass


def test_common_subexpression_elimination():
ctx = IRContext()
fn = ctx.create_function("test")
bb = fn.get_basic_block()
op = bb.append_instruction("store", 10)
sum_1 = bb.append_instruction("add", op, 10)
bb.append_instruction("mul", sum_1, 10)
sum_2 = bb.append_instruction("add", op, 10)
bb.append_instruction("mul", sum_2, 10)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)

CSE(ac, fn).run_pass(1, 5)

assert sum(1 for inst in bb.instructions if inst.opcode == "add") == 1, "wrong number of adds"
assert sum(1 for inst in bb.instructions if inst.opcode == "mul") == 1, "wrong number of muls"


def test_common_subexpression_elimination_effects_1():
ctx = IRContext()
fn = ctx.create_function("test")
bb = fn.get_basic_block()
mload_1 = bb.append_instruction("mload", 0)
op = bb.append_instruction("store", 10)
bb.append_instruction("mstore", op, 0)
mload_2 = bb.append_instruction("mload", 0)
bb.append_instruction("add", mload_1, 10)
bb.append_instruction("add", mload_2, 10)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)

CSE(ac, fn).run_pass()

assert sum(1 for inst in bb.instructions if inst.opcode == "add") == 2, "wrong number of adds"


# This is a limitation of current implementation
@pytest.mark.xfail
def test_common_subexpression_elimination_effects_2():
ctx = IRContext()
fn = ctx.create_function("test")
bb = fn.get_basic_block()
mload_1 = bb.append_instruction("mload", 0)
bb.append_instruction("add", mload_1, 10)
op = bb.append_instruction("store", 10)
bb.append_instruction("mstore", op, 0)
mload_2 = bb.append_instruction("mload", 0)
bb.append_instruction("add", mload_1, 10)
bb.append_instruction("add", mload_2, 10)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)
CSE(ac, fn).run_pass()

assert sum(1 for inst in bb.instructions if inst.opcode == "add") == 2, "wrong number of adds"


def test_common_subexpression_elimination_effect_mstore():
ctx = IRContext()
fn = ctx.create_function("test")
bb = fn.get_basic_block()
op = bb.append_instruction("store", 10)
bb.append_instruction("mstore", op, 0)
mload_1 = bb.append_instruction("mload", 0)
op = bb.append_instruction("store", 10)
bb.append_instruction("mstore", op, 0)
mload_2 = bb.append_instruction("mload", 0)
bb.append_instruction("add", mload_1, mload_2)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)

StoreExpansionPass(ac, fn).run_pass()
CSE(ac, fn).run_pass(1, 5)

assert (
sum(1 for inst in bb.instructions if inst.opcode == "mstore") == 1
), "wrong number of mstores"
assert (
sum(1 for inst in bb.instructions if inst.opcode == "mload") == 1
), "wrong number of mloads"


def test_common_subexpression_elimination_effect_mstore_with_msize():
ctx = IRContext()
fn = ctx.create_function("test")
bb = fn.get_basic_block()
op = bb.append_instruction("store", 10)
bb.append_instruction("mstore", op, 0)
mload_1 = bb.append_instruction("mload", 0)
op = bb.append_instruction("store", 10)
bb.append_instruction("mstore", op, 0)
mload_2 = bb.append_instruction("mload", 0)
msize_read = bb.append_instruction("msize")
bb.append_instruction("add", mload_1, msize_read)
bb.append_instruction("add", mload_2, msize_read)
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)

StoreExpansionPass(ac, fn).run_pass()
CSE(ac, fn).run_pass(1, 5)

assert (
sum(1 for inst in bb.instructions if inst.opcode == "mstore") == 2
), "wrong number of mstores"
assert (
sum(1 for inst in bb.instructions if inst.opcode == "mload") == 2
), "wrong number of mloads"
4 changes: 4 additions & 0 deletions vyper/venom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vyper.venom.ir_node_to_venom import ir_node_to_venom
from vyper.venom.passes.algebraic_optimization import AlgebraicOptimizationPass
from vyper.venom.passes.branch_optimization import BranchOptimizationPass
from vyper.venom.passes.common_subexpression_elimination import CSE
from vyper.venom.passes.dft import DFTPass
from vyper.venom.passes.make_ssa import MakeSSA
from vyper.venom.passes.mem2var import Mem2Var
Expand Down Expand Up @@ -54,9 +55,12 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None:
SimplifyCFGPass(ac, fn).run_pass()
AlgebraicOptimizationPass(ac, fn).run_pass()
BranchOptimizationPass(ac, fn).run_pass()

RemoveUnusedVariablesPass(ac, fn).run_pass()

StoreExpansionPass(ac, fn).run_pass()
CSE(ac, fn).run_pass(2, 10)
RemoveUnusedVariablesPass(ac, fn).run_pass()
DFTPass(ac, fn).run_pass()


Expand Down
251 changes: 251 additions & 0 deletions vyper/venom/analysis/available_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
from collections import deque
from dataclasses import dataclass

from vyper.utils import OrderedSet
from vyper.venom.analysis.analysis import IRAnalysesCache, IRAnalysis
from vyper.venom.analysis.cfg import CFGAnalysis
from vyper.venom.analysis.dfg import DFGAnalysis
from vyper.venom.basicblock import (
BB_TERMINATORS,
IRBasicBlock,
IRInstruction,
IROperand,
IRVariable,
)
from vyper.venom.context import IRFunction
from vyper.venom.effects import EMPTY, Effects

_MAX_DEPTH = 5
_MIN_DEPTH = 2


@dataclass
class _Expression:
first_inst: IRInstruction
opcode: str
operands: list["IROperand | _Expression"]

def __eq__(self, other):
if not isinstance(other, _Expression):
return False
return self.first_inst == other.first_inst

def __hash__(self) -> int:
return hash((self.opcode, *self.operands))

def __repr__(self) -> str:
if self.opcode == "store":
assert len(self.operands) == 1, "wrong store"
return repr(self.operands[0])
res = self.opcode + " [ "
for op in self.operands:
res += repr(op) + " "
res += "]"
return res

def same(self, other: "_Expression") -> bool:
if self.opcode != other.opcode:
return False
for self_op, other_op in zip(self.operands, other.operands):
if type(self_op) is not type(other_op):
return False
if isinstance(self_op, _Expression):
assert isinstance(other_op, _Expression)
if not self_op.same(other_op):
return False
else:
assert isinstance(self_op, IROperand)
assert isinstance(other_op, IROperand)
if self_op != other_op:
return False
return True

def contains_expr(self, expr: "_Expression") -> bool:
for op in self.operands:
if op == expr:
return True
if isinstance(op, _Expression) and op.contains_expr(expr):
return True
return False

def get_depth(self) -> int:
max_depth = 0
for op in self.operands:
if isinstance(op, _Expression):
d = op.get_depth()
if d > max_depth:
max_depth = d
return max_depth + 1

def get_reads(self, ignore_msize: bool) -> Effects:
tmp_reads = self.first_inst.get_read_effects()
for op in self.operands:
if isinstance(op, _Expression):
tmp_reads = tmp_reads | op.get_reads(ignore_msize)
if ignore_msize:
tmp_reads &= ~Effects.MSIZE
return tmp_reads

def get_writes(self, ignore_msize: bool) -> Effects:
tmp_reads = self.first_inst.get_write_effects()
for op in self.operands:
if isinstance(op, _Expression):
tmp_reads = tmp_reads | op.get_writes(ignore_msize)
if ignore_msize:
tmp_reads &= ~Effects.MSIZE
return tmp_reads


class _BBLattice:
data: dict[IRInstruction, OrderedSet[_Expression]]
out: OrderedSet[_Expression]
in_cache: OrderedSet[_Expression] | None

def __init__(self, bb: IRBasicBlock):
self.data = dict()
self.out = OrderedSet()
self.in_cache = None
for inst in bb.instructions:
self.data[inst] = OrderedSet()


_UNINTERESTING_OPCODES = ["store", "param", "offset", "phi", "nop"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this up to the top of the file and make public?



class _FunctionLattice:
data: dict[IRBasicBlock, _BBLattice]

def __init__(self, function: IRFunction):
self.data = dict()
for bb in function.get_basic_blocks():
self.data[bb] = _BBLattice(bb)


class AvailableExpressionAnalysis(IRAnalysis):
expressions: OrderedSet[_Expression] = OrderedSet()
inst_to_expr: dict[IRInstruction, _Expression] = dict()
dfg: DFGAnalysis
lattice: _FunctionLattice
min_depth: int
max_depth: int
ignore_msize: bool

def __init__(
self,
analyses_cache: IRAnalysesCache,
function: IRFunction,
min_depth: int = _MIN_DEPTH,
max_depth: int = _MAX_DEPTH,
):
super().__init__(analyses_cache, function)
self.analyses_cache.request_analysis(CFGAnalysis)
dfg = self.analyses_cache.request_analysis(DFGAnalysis)
assert isinstance(dfg, DFGAnalysis)
self.dfg = dfg

self.min_depth = min_depth
self.max_depth = max_depth

self.lattice = _FunctionLattice(function)

self.ignore_msize = not self._contains_msize()

def analyze(self, min_depth: int = _MIN_DEPTH, max_depth: int = _MAX_DEPTH):
self.min_depth = min_depth
self.max_depth = max_depth
worklist: deque = deque()
worklist.append(self.function.entry)
while len(worklist) > 0:
bb: IRBasicBlock = worklist.popleft()
changed = self._handle_bb(bb)

if changed:
for out in bb.cfg_out:
if out not in worklist:
worklist.append(out)

def _contains_msize(self) -> bool:
for bb in self.function.get_basic_blocks():
for inst in bb.instructions:
if inst.opcode == "msize":
return True
return False

def _handle_bb(self, bb: IRBasicBlock) -> bool:
available_expr: OrderedSet[_Expression] = OrderedSet()
if len(bb.cfg_in) > 0:
available_expr = OrderedSet.intersection(
*(self.lattice.data[in_bb].out for in_bb in bb.cfg_in)
)

bb_lat = self.lattice.data[bb]
if bb_lat.in_cache is not None and available_expr == bb_lat.in_cache:
return False
bb_lat.in_cache = available_expr
change = False
for inst in bb.instructions:
if inst.opcode in _UNINTERESTING_OPCODES or inst.opcode in BB_TERMINATORS:
continue
if available_expr != bb_lat.data[inst]:
bb_lat.data[inst] = available_expr.copy()
change |= True

inst_expr = self.get_expression(inst, available_expr)
# write_effects = inst.get_write_effects() # writes.get(inst_expr.opcode, ())
write_effects = inst_expr.get_writes(self.ignore_msize)
for expr in available_expr.copy():
read_effects = expr.get_reads(self.ignore_msize)
if read_effects & write_effects != EMPTY:
available_expr.remove(expr)
continue
write_effects_expr = expr.get_writes(self.ignore_msize)
if write_effects_expr & write_effects != EMPTY:
available_expr.remove(expr)

if (
inst_expr.get_depth() in range(self.min_depth, self.max_depth + 1)
and write_effects & inst_expr.get_reads(self.ignore_msize) == EMPTY
):
available_expr.add(inst_expr)

if available_expr != bb_lat.out:
bb_lat.out = available_expr.copy()
change |= True

return change

def _get_operand(
self, op: IROperand, available_exprs: OrderedSet[_Expression], depth: int
) -> IROperand | _Expression:
if depth > 0 and isinstance(op, IRVariable):
inst = self.dfg.get_producing_instruction(op)
assert inst is not None
if not inst.is_volatile:
return self.get_expression(inst, available_exprs, depth - 1)
return op

def _get_operands(
self, inst: IRInstruction, available_exprs: OrderedSet[_Expression], depth: int
) -> list[IROperand | _Expression]:
return [self._get_operand(op, available_exprs, depth) for op in inst.operands]

def get_expression(
self,
inst: IRInstruction,
available_exprs: OrderedSet[_Expression] | None = None,
depth: int | None = None,
) -> _Expression:
if available_exprs is None:
available_exprs = self.lattice.data[inst.parent].data[inst]
if depth is None:
depth = self.max_depth
operands: list[IROperand | _Expression] = self._get_operands(inst, available_exprs, depth)
expr = _Expression(inst, inst.opcode, operands)
for e in available_exprs:
if expr.same(e):
return e

return expr

def get_available(self, inst: IRInstruction) -> OrderedSet[_Expression]:
return self.lattice.data[inst.parent].data[inst]
Loading
Loading