Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API/Refactor] Unified PackedFunc for API and Generated Functions #26

Merged
merged 1 commit into from
Jan 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ language: cpp

os:
- linux
- osx
# - osx

env:
# code analysis
Expand Down
85 changes: 85 additions & 0 deletions include/tvm/api_registry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*!
* Copyright (c) 2016 by Contributors
* \file api_registry.h
* \brief This file defines the TVM API registry.
*
* The API registry stores type-erased functions.
* Each registered function is automatically exposed
* to front-end language(e.g. python).
* Front-end can also pass callbacks as PackedFunc, or register
* then into the same global registry in C++.
* The goal is to mix the front-end language and the TVM back-end.
*
* \code
* // register the function as MyAPIFuncName
* TVM_REGISTER_API(MyAPIFuncName)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_

#include <dmlc/base.h>
#include <string>
#include "./base.h"
#include "./runtime/packed_func.h"
#include "./packed_func_ext.h"

namespace tvm {

/*! \brief Utility to register API. */
class APIRegistry {
public:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/*!
* \brief Register a function with given name
* \param name The name of the function.
*/
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)

private:
/*! \brief name of the function */
std::string name_;
};

/*!
* \brief Get API function by name.
*
* \param name The name of the function.
* \return the corresponding API function.
* \note It is really PackedFunc::GetGlobal under the hood.
*/
inline PackedFunc GetAPIFunc(const std::string& name) {
return PackedFunc::GetGlobal(name);
}

#define _TVM_REGISTER_VAR_DEF_ \
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_

/*!
* \brief Register API function globally.
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) \
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
::tvm::APIRegistry::__REGISTER__(#OpName)
} // namespace tvm
#endif // TVM_API_REGISTRY_H_
74 changes: 7 additions & 67 deletions include/tvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,83 +2,23 @@
* Copyright (c) 2016 by Contributors
* \file c_api.h
* \brief C API of TVM DSL
*
* \note The API is designed in a minimum way.
* Most of the API functions are registered and can be pulled out.
*
* The common flow is:
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/
#ifndef TVM_C_API_H_
#define TVM_C_API_H_

#include "./runtime/c_runtime_api.h"

TVM_EXTERN_C {
/*! \brief handle to functions */
typedef void* APIFuncHandle;
/*! \brief handle to node */
typedef void* NodeHandle;

/*!
* \brief List all the node function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMListAPIFuncNames(int *out_size,
const char*** out_array);
/*!
* \brief get function handle by name
* \param name The name of function
* \param handle The returning function handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetAPIFuncHandle(const char* name,
APIFuncHandle *handle);

/*!
* \brief Get the detailed information about function.
* \param handle The operator handle.
* \param real_name The returned name of the function.
* This name is not the alias name of the atomic symbol.
* \param description The returned description of the symbol.
* \param num_doc_args Number of arguments that contain documents.
* \param arg_names Name of the arguments of doc args
* \param arg_type_infos Type informations about the arguments.
* \param arg_descriptions Description information about the arguments.
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);

/*!
* \brief Push an argument to the function calling stack.
* If push fails, the stack will be reset to empty
*
* \param arg The argument
* \param type_code The type_code of argument as in TVMTypeCode
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMAPIPushStack(TVMValue arg,
int type_code);

/*!
* \brief call a function by using arguments in the stack.
* The stack will be cleanup to empty after this call, whether the call is successful.
*
* \param handle The function handle
* \param ret_val The return value.
* \param ret_type_code the type code of return value.
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle,
TVMValue* ret_val,
int* ret_type_code);

/*!
* \brief free the node handle
* \param handle The node handle to be freed.
Expand Down
1 change: 1 addition & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include <algorithm>
#include "./base.h"
#include "./runtime/packed_func.h"

namespace tvm {

Expand Down
196 changes: 196 additions & 0 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*!
* Copyright (c) 2016 by Contributors
* \file packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass NodeRef types into/from PackedFunc.
*/
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <sstream>
#include <string>
#include <memory>
#include <type_traits>

#include "./base.h"
#include "./expr.h"

namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct NodeTypeChecker {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};

template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};

template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};

template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}

// extensions for tvm arg value

template<typename TNodeRef, typename>
inline TVMArgValue::operator TNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key();
return TNodeRef(sptr);
}

inline TVMArgValue::operator Halide::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kInt) {
return Expr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kFloat) {
return Expr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
}
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Expr(sptr);
}

inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<std::shared_ptr<Node> >();
}


template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr =
*ptr<std::shared_ptr<Node> >();
return NodeTypeChecker<TNodeRef>::Check(sptr.get());
}

// extensions for TVMRetValue
inline TVMRetValue& TVMRetValue::operator=(
const std::shared_ptr<Node>& other) {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
return *this;
}

inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
return *this;
}

template<typename TNodeRef, typename>
inline TVMRetValue::operator TNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return TNodeRef(*ptr<std::shared_ptr<Node> >());
}

inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*)
values_[i].v_handle = &(other.node_);
type_codes_[i] = kNodeHandle;
}

// Type related stuffs
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}

inline TVMType Type2TVMType(Type t) {
TVMType ret;
ret.code = static_cast<uint8_t>(t.code());
ret.bits = static_cast<uint8_t>(t.bits());
ret.lanes = static_cast<uint16_t>(t.lanes());
return ret;
}

inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
return this->operator=(Type2TVMType(t));
}

inline TVMRetValue::operator Halide::Type() const {
return TVMType2Type(operator TVMType());
}

inline TVMArgValue::operator Halide::Type() const {
return TVMType2Type(operator TVMType());
}

inline void TVMArgsSetter::operator()(
size_t i, const Halide::Type& t) const {
this->operator()(i, Type2TVMType(t));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
Loading