Skip to content

Commit

Permalink
[INFA][IR] Build and Evolve Low-level IR. Remove HalideIR dep. (apach…
Browse files Browse the repository at this point in the history
…e#3533)

* [INFA][IR] Build and Evolve Low-level IR. Remove dep from HalideIR.


* Update include/tvm/node/ir_functor.h

Co-Authored-By: Jared Roesch <[email protected]>

* Update include/tvm/node/ir_functor.h

Co-Authored-By: Jared Roesch <[email protected]>
  • Loading branch information
tqchen and jroesch authored Jul 11, 2019
1 parent 2d53f84 commit 0218557
Show file tree
Hide file tree
Showing 50 changed files with 4,176 additions and 405 deletions.
21 changes: 4 additions & 17 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ if(MSVC)
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
add_definitions(-DHalide_SHARED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
Expand Down Expand Up @@ -112,8 +111,8 @@ else(MSVC)
endif(MSVC)

# add source group
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "3rdparty/HalideIR/src/*.cpp" "nnvm/src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" "3rdparty/HalideIR/src/*.h"
FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "nnvm/src/*.cc")
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h"
"nnvm/src/*.h" "nnvm/include/*.h")
assign_source_group("Source" ${GROUP_SOURCE})
assign_source_group("Include" ${GROUP_INCLUDE})
Expand All @@ -127,6 +126,7 @@ file(GLOB COMPILER_SRCS
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
src/node/*.cc
src/schedule/*.cc
)

Expand Down Expand Up @@ -154,12 +154,7 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
file(GLOB TOPI_SRCS
topi/src/*.cc
)
file(GLOB_RECURSE HALIDEIR_SRCS
3rdparty/HalideIR/src/base/*.cpp
3rdparty/HalideIR/src/ir/*.cpp
3rdparty/HalideIR/src/tvm/*.cpp
)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})

file(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
Expand Down Expand Up @@ -245,7 +240,6 @@ target_link_libraries(nnvm_compiler tvm)
# Related headers
target_include_directories(
tvm
PUBLIC "3rdparty/HalideIR/src"
PUBLIC "topi/include")
target_include_directories(
tvm_topi
Expand Down Expand Up @@ -294,11 +288,6 @@ if (INSTALL_DEV)
FILES_MATCHING
PATTERN "*.h"
)
install(
DIRECTORY "3rdparty/HalideIR/src/." DESTINATION "include/HalideIR"
FILES_MATCHING
PATTERN "*.h"
)
install(
DIRECTORY "3rdparty/dlpack/include/." DESTINATION "include"
FILES_MATCHING
Expand All @@ -319,8 +308,6 @@ endif(INSTALL_DEV)

# More target definitions
if(MSVC)
target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
Expand Down
4 changes: 2 additions & 2 deletions apps/howto_deploy/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);

/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ inline TNodeRef NullValue() {
}

template<>
inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
inline DataType NullValue<DataType>() {
return DataType(kHandle, 0, 0);
}

/*! \brief Error thrown during attribute checking. */
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class Layout : public NodeRef {
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
}
return -1;
}
Expand All @@ -243,7 +243,7 @@ class Layout : public NodeRef {
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
for (const IterVar var : operator->()->axes) {
if (var->var.get()->name_hint == axis.name()) {
if (var->var->name_hint == axis.name()) {
return true;
}
}
Expand Down
246 changes: 246 additions & 0 deletions include/tvm/dtype.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file tvm/dtype.h
* \brief Data type used in IR.
*/
#ifndef TVM_DTYPE_H_
#define TVM_DTYPE_H_

#include "runtime/packed_func.h"

namespace tvm {
class Expr;

/*!
* \brief Primitive data types in tvm.
*/
class DataType {
public:
/*! \brief default constructor */
DataType() {}
/*!
* \brief Constructor
* \param dtype The DLDataType
*/
explicit DataType(DLDataType dtype)
: data_(dtype) {}
/*!
* \brief Constructor
* \param code The type code.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
*/
DataType(int code, int bits, int lanes) {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
data_.lanes = static_cast<uint16_t>(lanes);
}
/*! \return The type code. */
int code() const {
return static_cast<int>(data_.code);
}
/*! \return number of bits in the data. */
int bits() const {
return static_cast<int>(data_.bits);
}
/*! \return number of bytes to store each scalar. */
int bytes() const {
return (bits() + 7) / 8;
}
/*! \return number of lanes in the data. */
int lanes() const {
return static_cast<int>(data_.lanes);
}
/*! \return whether type is a scalar type. */
bool is_scalar() const {
return lanes() == 1;
}
/*! \return whether type is a scalar type. */
bool is_bool() const {
return code() == kDLUInt && bits() == 1;
}
/*! \return whether type is a float type. */
bool is_float() const {
return code() == kDLFloat;
}
/*! \return whether type is an int type. */
bool is_int() const {
return code() == kDLInt;
}
/*! \return whether type is an uint type. */
bool is_uint() const {
return code() == kDLUInt;
}
/*! \return whether type is a handle type. */
bool is_handle() const {
return code() == kHandle;
}
/*! \return whether type is a vector type. */
bool is_vector() const {
return lanes() > 1;
}
/*!
* \brief Create a new data type by change lanes to a specified value.
* \param lanes The target number of lanes.
* \return the result type.
*/
DataType with_lanes(int lanes) const {
return DataType(data_.code, data_.bits, lanes);
}
/*!
* \brief Create a new data type by change bits to a specified value.
* \param bits The target number of bits.
* \return the result type.
*/
DataType with_bits(int bits) const {
return DataType(data_.code, bits, data_.lanes);
}
/*!
* \brief Get the scalar version of the type.
* \return the result type.
*/
DataType element_of() const {
return with_lanes(1);
}
// operator overloadings
bool operator==(const DataType& other) const {
return
data_.code == other.data_.code &&
data_.bits == other.data_.bits &&
data_.lanes == other.data_.lanes;
}
bool operator!=(const DataType& other) const {
return !operator==(other);
}
operator DLDataType () const {
return data_;
}
/*! \return the maximum possible value in this format. */
TVM_DLL Expr max() const;
/*! \return the minimum possible value in this format. */
TVM_DLL Expr min() const;

private:
DLDataType data_;
};

/*!
* \brief Construct an int type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
* \return The constructed data type.
*/
inline DataType Int(int bits, int lanes = 1) {
return DataType(kDLInt, bits, lanes);
}

/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType UInt(int bits, int lanes = 1) {
return DataType(kDLUInt, bits, lanes);
}

/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Bool(int lanes = 1) {
return UInt(1, lanes);
}

/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Float(int bits, int lanes = 1) {
return DataType(kDLFloat, bits, lanes);
}

/*!
* \brief Construct a handle type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
inline DataType Handle(int bits = 64, int lanes = 1) {
return DataType(kHandle, bits, lanes);
}

/*!
* \brief Get the corresponding type of TVMShapeIndex.
* \return The type of TVM shape index.
*/
inline DataType TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
} else {
return UInt(sizeof(tvm_index_t) * 8);
}
}

/*!
* \brief Convert DLDataType to DataType.
* \param t The original type.
* \return The conversion result.
*/
inline DataType TVMType2Type(DLDataType t) {
return DataType(t.code, t.bits, t.lanes);
}

/*!
* \brief Convert DataType to DataType.
* \param t The original type.
* \return The conversion result.
*/
inline DLDataType Type2TVMType(DataType t) {
return t.operator DLDataType();
}

/*!
* \brief Get the number of bytes needed in a vector.
* \param dtype The data type.
* \return Number of bytes needed.
*/
inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
}

// Overload print function.
inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*)
using namespace tvm::runtime;
return os << dtype.operator DLDataType();
}

// Backward compatibility
using Type = DataType;
} // namespace tvm
#endif // TVM_DTYPE_H_
Loading

0 comments on commit 0218557

Please sign in to comment.