Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LANG] Include buffer semnatics, introduce pylint #11

Merged
merged 3 commits into from
Jan 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)

lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src
python2 dmlc-core/scripts/lint.py tvm all include src python

doc:
doxygen docs/Doxyfile
Expand Down
2 changes: 1 addition & 1 deletion dmlc-core
Submodule dmlc-core updated 1 files
+1 −1 include/dmlc/json.h
98 changes: 98 additions & 0 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@

/*!
* Copyright (c) 2016 by Contributors
* \file buffer.h
* \brief Symbolic n-dimensional array, to represent a memory buffer.
*/
#ifndef TVM_BUFFER_H_
#define TVM_BUFFER_H_

#include <tvm/container.h>
#include <string>

#include "./base.h"
#include "./expr.h"

namespace tvm {

// Internal node container Buffer
class BufferNode;
/*!
* \brief Buffer is a symbolic n-darray structure.
* It is a composition of primitive symbolic types,
* used to specify input/output strcuture of the program.
*/
class Buffer : public NodeRef {
public:
Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct a new buffer based on shape and strides.
*/
explicit Buffer(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "buffer");
/*!
* \brief Generate a load expression loading the index location of buffer.
* \param index The index to the buffer.
* \return The load expression.
*/
Expr MakeLoad(Array<Expr> index) const;
/*!
* \brief Generate a store statement.
* \param index The index to the buffer.
* \param value The value to be stored.
* \return The load expression.
*/
Stmt MakeStore(Array<Expr> index, Expr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;
};

/*! \brief Node to represent a buffer */
class BufferNode : public Node {
public:
/*! \brief optional name of the buffer */
std::string name;
/*! \brief The pointer to the head of the data */
Var ptr;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
* \brief The strides of each dimension
* This can be an empty array, indicating array is contiguous
*/
Array<Expr> strides;
/*! \brief data type in the content of the tensor */
Type dtype;
// Maybe need more information(alignment) later
/*! \brief constructor */
BufferNode() {}

void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("ptr", &ptr);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("dtype", &dtype);
}

static Buffer make(std::string name,
Var ptr,
Array<Expr> shape,
Array<Expr> strides,
Type dtype);

static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
};

inline const BufferNode* Buffer::operator->() const {
return static_cast<const BufferNode*>(node_.get());
}

} // namespace tvm
#endif // TVM_BUFFER_H_
19 changes: 16 additions & 3 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <unordered_map>
#include <vector>
#include "./expr.h"
#include "./buffer.h"
#include "./schedule.h"

namespace tvm {
Expand Down Expand Up @@ -56,10 +57,22 @@ Stmt ConvertSSA(Stmt stmt);
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt Inline(FunctionRef f,
Stmt Inline(Stmt stmt,
FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt);
Expr body);


/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
* \param stmt The stmt to be trasnformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);

} // namespace ir
} // namespace tvm
Expand Down
1 change: 1 addition & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=redefined-builtin, wildcard-import
"""C++ backend related python scripts"""
from __future__ import absolute_import as _abs
from ._ctypes._api import register_node
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# coding: utf-8
# pylint: disable=invalid-name
# pylint: disable=invalid-name, no-member
""" ctypes library of nnvm and helper functions """
from __future__ import absolute_import

import sys
import os
import ctypes
import numpy as np
from . import libinfo
Expand Down
15 changes: 9 additions & 6 deletions python/tvm/_ctypes/_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs

Expand All @@ -14,6 +15,7 @@
from .. import _function_internal

class ArgVariant(ctypes.Union):
"""ArgVariant in C API"""
_fields_ = [("v_long", ctypes.c_long),
("v_double", ctypes.c_double),
("v_str", ctypes.c_char_p),
Expand All @@ -30,8 +32,8 @@ class ArgVariant(ctypes.Union):

def _return_node(x):
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
ret_success = ctypes.c_int()
Expand All @@ -47,7 +49,7 @@ def _return_node(x):
kLong: lambda x: x.v_long,
kDouble: lambda x: x.v_double,
kStr: lambda x: py_str(x.v_str),
kNodeHandle: lambda x: _return_node(x)
kNodeHandle: _return_node
}

class SliceBase(object):
Expand Down Expand Up @@ -251,6 +253,7 @@ def register_node(type_key=None):
"""
if isinstance(type_key, str):
def register(cls):
"""internal register function"""
NODE_TYPE[type_key] = cls
return cls
return register
Expand All @@ -273,9 +276,9 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace]
namespace_match = {
"_make_" : sys.modules["%s.make" % root_namespace],
"_pass_" : sys.modules["%s.ir_pass" % root_namespace],
"_schedule_" : sys.modules["%s.schedule" % root_namespace]
"_make_": sys.modules["%s.make" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
}

for name in op_names:
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/collections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
Expand All @@ -6,6 +7,7 @@

@register_node
class Array(NodeBase):
"""Array container of TVM"""
def __getitem__(self, i):
if i >= len(self):
raise IndexError("array index out ot range")
Expand All @@ -19,13 +21,15 @@ def __repr__(self):

@register_node
class Map(NodeBase):
"""Map container of TVM"""
def __getitem__(self, k):
return _function_internal._MapGetItem(self, k)

def __contains__(self, k):
return _function_internal._MapCount(self, k) != 0

def items(self):
"""Get the items from the map"""
akvs = _function_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]

Expand All @@ -38,9 +42,17 @@ def __repr__(self):

@register_node
class Range(NodeBase):
"""Represent range in TVM"""
pass


@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable."""
pass


@register_node
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass
3 changes: 2 additions & 1 deletion python/tvm/expr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=protected-access, no-member, missing-docstring
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import make as _make
Expand Down Expand Up @@ -174,7 +175,7 @@ class Call(Expr):
Halide = 3
Intrinsic = 4
PureIntrinsic = 5
pass


@register_node
class Let(Expr):
Expand Down
57 changes: 49 additions & 8 deletions python/tvm/function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# pylint: disable=protected-access, no-member, invalid-name
# pylint: disable=redefined-builtin, undefined-variable
"""Functions defined in TVM."""
from __future__ import absolute_import as _abs
from numbers import Number as _Number, Integral as _Integral
from numbers import Integral as _Integral
from ._ctypes._api import _init_function_module, convert
from . import _function_internal
from . import make as _make
Expand All @@ -8,6 +11,7 @@

int32 = "int32"
float32 = "float32"
handle = "handle"

def const(value, dtype=None):
"""construct a constant"""
Expand Down Expand Up @@ -65,7 +69,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype)


def placeholder(shape, dtype = None, name="placeholder"):
def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object.

Parameters
Expand All @@ -84,6 +88,7 @@ def placeholder(shape, dtype = None, name="placeholder"):
tensor: tensor.Tensor
The created tensor
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
return _function_internal._Placeholder(
shape, dtype, name)
Expand Down Expand Up @@ -111,8 +116,7 @@ def compute(shape, fcompute, name="compute"):
tensor: tensor.Tensor
The created tensor
"""
if isinstance(shape, _expr.Expr):
shape = (shape, )
shape = (shape,) if isinstance(shape, _expr.Expr) else shape

ndim = len(shape)
arg_names = fcompute.__code__.co_varnames
Expand All @@ -125,7 +129,44 @@ def compute(shape, fcompute, name="compute"):
op_node = _function_internal._ComputeOp(
name, dim_var, body)
return _function_internal._Tensor(
shape, name, body.dtype, op_node, 0)
shape, body.dtype, op_node, 0)


def Buffer(shape, dtype=None,
name="buffer", ptr=None,
strides=None):
"""Create a new buffer

Parameters
----------
shape : tuple of Expr
The shape of the buffer.

dtype : str, optional
The data type of the buffer.

name : str, optional
The name of the buffer.

ptr : Var, optional
The data pointer in the buffer.

strides: array of Expr
The stride of the buffer.

Returns
-------
buffer : Buffer
The created buffer
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if ptr is None:
ptr = Var(name, "handle")

return _function_internal._Buffer(
name, ptr, shape, strides, dtype)


def IterVar(dom, name='iter', thread_tag=''):
Expand Down Expand Up @@ -170,7 +211,7 @@ def sum(expr, rdom):
The reduction domainx
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Add", expr, rdom)
x = _make.Reduce("Add", expr, rdom)
return x


Expand All @@ -186,7 +227,7 @@ def min(expr, rdom):
The reduction domainx
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Min", expr, rdom)
x = _make.Reduce("Min", expr, rdom)
return x


Expand All @@ -202,7 +243,7 @@ def max(expr, rdom):
The reduction domainx
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Max", expr, rdom)
x = _make.Reduce("Max", expr, rdom)
return x


Expand Down
Loading