Skip to content

Commit

Permalink
Add type information to Relay's Python AST (#8)
Browse files Browse the repository at this point in the history
* begins typing expr.py

* minor style fix

* improves typing and adds more

* type -> typ to avoid keyword clash

* reverts prev change

* fixes typo

* merge python/expr conflicts

* incorporate changes

* lint test

* comment out instance var lint test

* undo tests

* adds more types

* more types

* disable specific pylint invalid-name warnings

* minor comments on particular nodes

* minor clean up
  • Loading branch information
joshpoll authored and jroesch committed Aug 16, 2018
1 parent 7bef386 commit 731361f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 38 deletions.
94 changes: 62 additions & 32 deletions relay/python/relay/expr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck
"""All the expression nodes"""
from enum import IntEnum
from typing import Union
from typing import Union, Tuple, Dict
from .base import register_nnvm_node, NodeBase
from . import make

Expand All @@ -13,6 +13,8 @@ class Environment(NodeBase):
"""Global Environment
"""

items: Dict["GlobalId", "Item"]

def add(self, func: "GlobalId") -> None:
return make.Environment_add(self, func)

Expand All @@ -33,6 +35,9 @@ class Item(NodeBase):
"""Base class of all expressions.
"""

#pylint: disable=invalid-name
id: "GlobalId"

def __eq__(self, other):
if type(self) != type(other):
return False
Expand All @@ -53,11 +58,13 @@ class Type(NodeBase):
def __eq__(self, other):
return alpha_eq(self, other)


@register_nnvm_node
class Attributes(NodeBase):
def __getitem__(self, index):
return self.attributes[index]


class Builder(object):
"""A helper class for building partial AST fragments."""

Expand Down Expand Up @@ -128,88 +135,97 @@ def __eq__(self, other):

@register_nnvm_node
class Defn(Item):
pass
type: Type
body: Expr


@register_nnvm_node
class Primitive(Item):
pass
type: Type


@register_nnvm_node
class String(Expr):
pass
value: str


@register_nnvm_node
class IntLit(Expr):
pass
value: int


@register_nnvm_node
class FloatLit(Expr):
pass
value: float


@register_nnvm_node
class BoolLit(Expr):
pass
value: bool


@register_nnvm_node
class TensorLit(Expr):
pass
data: Tuple[NodeBase, ...]


@register_nnvm_node
class ProductLit(Expr):
pass
fields: Tuple[NodeBase, ...]


@register_nnvm_node
class BaseType(Type):
pass
type: str


@register_nnvm_node
class Cast(Expr):
pass
target: Type
node: Expr


@register_nnvm_node
class LocalId(Expr):
name: str

def __hash__(self) -> int:
return hash(self.name)


@register_nnvm_node
class GlobalId(Expr):
pass
name: str


@register_nnvm_node
class IntrinsicId(Expr):
pass
name: str


@register_nnvm_node
class Param(Expr):
pass
#pylint: disable=invalid-name
id: LocalId
type: Type


@register_nnvm_node
class Function(Expr):
pass
params: Tuple[Param, ...]
body: Expr


@register_nnvm_node
class Call(Expr):
pass
#pylint: disable=invalid-name
fn: Expr
args: Tuple[Expr, ...]


@register_nnvm_node
class Debug(Expr):
pass
node: Expr


class UOp(IntEnum):
Expand All @@ -218,11 +234,11 @@ class UOp(IntEnum):

@register_nnvm_node
class UnaryOp(Expr):
pass

#pylint: disable=invalid-name
op: UOp
node: Expr


#pylint: disable=invalid-name
class BOp(IntEnum):
"""The set of builtin binary ops supported by Relay."""
PLUS = 0
Expand All @@ -239,12 +255,16 @@ class BOp(IntEnum):

@register_nnvm_node
class BinaryOp(Expr):
pass
op: BOp
left: Expr
right: Expr


@register_nnvm_node
class Let(Expr):
pass
id: LocalId
value: Expr
body: Expr


@register_nnvm_node
Expand All @@ -254,51 +274,61 @@ class Functor(Expr):

@register_nnvm_node
class Reverse(Expr):
pass
node: Expr


@register_nnvm_node
class Accumulate(Expr):
pass
update_binders: Tuple[LocalId, ...]
value: NodeBase


@register_nnvm_node
class Zero(Expr):
pass
type: Type

# TODO(@jroesch): todo move

# TODO(@jroesch): todo move
# values

@register_nnvm_node
class IntValue(Value):
pass
value: int


@register_nnvm_node
class FloatValue(Value):
pass
value: float


@register_nnvm_node
class BoolValue(Value):
pass
value: bool


@register_nnvm_node
class FnValue(Value):
pass
env: Dict[LocalId, Value]
func: Function

# unsorted

@register_nnvm_node
class If(Expr):
pass
guard: Expr
true_b: Expr
false_b: Expr


@register_nnvm_node
class Shape(NodeBase):
pass


@register_nnvm_node
class TensorType(Type):
pass
dtype: BaseType
shape: Shape


def alpha_eq(left: Union[Expr, Type], right: Union[Expr, Type]) -> bool:
Expand Down
6 changes: 0 additions & 6 deletions relay/python/relay/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ def ensure(expected, value, msg):
if value != expected:
raise Exception("{}: found {}".format(msg, value))

# TODO: it would be great to report errors in the Python program


def compile_args_to_params(args):
"""This function will convert Python arguments into Relay parameters."""
ensure([], args.defaults, "relay decorator does not support default arguments")
Expand Down Expand Up @@ -87,7 +84,6 @@ def relay_type_from_annotation(annotation):
#
# Process a single definition and produce a single Relay Defunc.


class DefToRelay(ast.NodeVisitor):
"""Compiles a Python definition to a Relay definition."""
# local_scopes: List[Dict[LocalId, Expr]]
Expand Down Expand Up @@ -242,8 +238,6 @@ def wrapper(*_):

# Store Python line and columb information for errors
#
# Make it possible to handle calls with keywords arguments
#
# Ensure names prefixed with relay. becomes a Intrinsic
#
# Handle assignments
Expand Down

0 comments on commit 731361f

Please sign in to comment.