Skip to content

Commit

Permalink
Two-Stage Frontend Compiler (#9)
Browse files Browse the repository at this point in the history
* Two-Stage Frontend Compiler (WIP)

* Fixes some bugs and changes tests

* Changes to fit in with the new repo

* Fix linting issues

* Remove attempt to do eager type inference in frontend

* Repair tests for MiniPy

* Fix linting issues

* Use Pytest over nose

* Add pytest to packages
  • Loading branch information
jroesch committed Aug 16, 2018
1 parent 6866e60 commit 85009f7
Show file tree
Hide file tree
Showing 10 changed files with 1,668 additions and 2 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ pylint:
jnilint:
python3 dmlc-core/scripts/lint.py tvm4j-jni cpp jvm/native/src

lint: cpplint pylint jnilint
mypy:
python3.6 -m mypy --ignore-missing-imports relay/python/relay relay/tests/python

lint: cpplint pylint jnilint mypy

doc:
doxygen docs/Doxyfile
Expand Down
Empty file.
65 changes: 65 additions & 0 deletions relay/python/relay/frontend/mini_python/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# restore at some point
# pylint: disable-all
"""A decorator for rewriting Python code into Relay."""
import inspect
from typing import Dict, List, Tuple
from collections import OrderedDict
import typed_ast.ast3 as ast
import numpy as np
import tvm
from relay.typing import Int, UInt, Float, Bool
import relay.eval as re
import relay.ir as ir
from relay.make import *
from relay.ir import BOp, UOp, Value
from relay.operators import __relay_tvm_context__
from .py_to_mini_py import compile_func
from .mini_py_to_relay import compile_global_defn

def relay_compile(f):
mini_py_ast = compile_func(f)
relay_ast = compile_global_defn(mini_py_ast)
return relay_ast

def marshal_argument(arg, _) -> ir.Value:
"""Convert Python values into the appropriate types
for the Relay evaluator.
"""
if isinstance(arg, int):
return IntValue(arg)
elif isinstance(arg, float):
return FloatValue(arg)
elif isinstance(arg, np.ndarray):
tvm_array = tvm.nd.array(arg.astype(np.float32), __relay_tvm_context__)
return ir.TensorValue.from_ndarray(tvm_array)
elif isinstance(arg, Value):
return arg
elif isinstance(arg, tvm.ndarray.NDArray):
return ir.TensorValue.from_ndarray(arg)
else:
raise Exception(f"unsupported argument type {type(arg)}")

# TODO(@weberlo): Replace current decorator usage with this one.
def relay(func):
"""The Relay decorator.
Retrieves the source code of `func`, compiles it to Relay, adds it to
Relay's global environment. When the decorated function is called, it will
use Relay's evaluator.
"""
env = get_env()
try:
defn = compile_func(func)
env.add(defn)
except FrontendError as e:
get_env().display_errors()

def wrapper(*py_args):
# TODO: Check types of Python args against args in the Relay definition.
args = []
assert len(py_args) == len(defn.body.params)
for arg, param in zip(py_args, defn.body.params):
args.append(marshal_argument(arg, param.type))
return re.invoke(env, defn.id, args)

return wrapper
120 changes: 120 additions & 0 deletions relay/python/relay/frontend/mini_python/ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""A small intermediate imperative AST.
Primarily for doing transformations before converting to Relay.
"""
from typing import List, Tuple, Optional
import relay.ir as relay


class Stmt:
span: relay.Span

def set_span(self, span: relay.Span):
self.span = span


class LocalDefn(Stmt):
"""Local Function Definition"""
ident: relay.LocalId
typ: relay.Type
body: "Function"

def __init__(
self,
ident: relay.LocalId,
typ: relay.Type,
body: "Function") -> None:
self.ident = ident
self.typ = typ
self.body = body


class GlobalDefn(Stmt):
"""Global Function Definition"""
ident: relay.GlobalId
typ: relay.Type
body: "Function"

def __init__(
self,
ident: relay.GlobalId,
typ: relay.Type,
body: "Function") -> None:
self.ident = ident
self.typ = typ
self.body = body

# pylint: disable=missing-docstring


class Assign(Stmt):
ident: relay.LocalId
typ: relay.Type
expr: relay.Expr

def __init__(
self,
ident: relay.LocalId,
typ: relay.Type,
expr: relay.Expr) -> None:
self.ident = ident
self.typ = typ
self.expr = expr


class Return(Stmt):
expr: relay.Expr

def __init__(
self,
expr: relay.Expr) -> None:
self.expr = expr

# pylint: disable=missing-docstring


class Function:
params: Tuple[relay.Param]
ret_typ: relay.Type
body: List[Stmt]
span: relay.Span

def __init__(
self,
params: Tuple[relay.Param],
ret_typ: relay.Type,
body: List[Stmt], span: Optional[relay.Span] = None) -> None:
self.params = params
self.ret_typ = ret_typ
self.body = body
self.span = span

def set_span(self, span: relay.Span):
self.span = span


class If(Stmt):
guard: relay.Expr
true_body: List[Stmt]
false_body: List[Stmt]

def __init__(
self,
guard: relay.Expr,
true_body: List[Stmt],
false_body: List[Stmt]) -> None:
self.guard = guard
self.true_body = true_body
self.false_body = false_body


class While(Stmt):
guard: relay.Expr
body: List[Stmt]

def __init__(
self,
guard: relay.Expr,
body: List[Stmt]) -> None:
self.guard = guard
self.body = body
Loading

0 comments on commit 85009f7

Please sign in to comment.