Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Prepare for new plan_devices.cc (part II) #9130

Merged
merged 6 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,31 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v
return input;
}

/*!
* \brief Copy the function or module, but overrides the attributes with the entries from \p attrs.
*
* \param input The thing to annotate (BaseFunc or IRModule)
* \param attrs Key/values attributes to add to \p input.
*
* \tparam TFunc The corresponding function or module type.
*
* \returns The new function or module with updated attributes.
*/
template <typename TFunc>
inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
for (const auto& pair : attrs) {
node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
}
} else {
node->attrs = DictAttrs(std::move(attrs));
}
return input;
}

// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
Expand Down
21 changes: 21 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,27 @@ constexpr const char* kTarget = "target";
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

/*!
* \brief The device type which will hold each of the functions parameters.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Array<Integer> (but interpreted as Array<DLDeviceType>)
*/
constexpr const char* kParamDeviceTypes = "param_device_types";

/*!
* \brief The device type which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Integer (but interpreted as DLDeviceType)
*/
constexpr const char* kResultDeviceType = "result_device_type";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
8 changes: 6 additions & 2 deletions include/tvm/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* \file parser.h
* \brief A parser for TVM IR.
*/
#include <tvm/ir/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

Expand All @@ -32,8 +33,11 @@
namespace tvm {
namespace parser {

IRModule ParseModule(std::string file_name, std::string file_content,
Optional<IRModule> init_module = Optional<IRModule>());
using MetaTable = Map<String, Array<ObjectRef>>;

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module = Optional<IRModule>(),
const MetaTable& init_meta_table = MetaTable());

} // namespace parser
} // namespace tvm
Expand Down
41 changes: 34 additions & 7 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,44 @@ namespace tvm {
namespace relay {

/*!
* \brief Attributes for the "on_device" operator.
* \brief Attributes for the "on_device" special operator.
*
* The relay call
* The Relay call (aka 'annotation'):
* \code
* on_device(expr, device_type=2)
* on_device(sub_expr, device_type=2)
* \endcode
* denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
* (i.e. \p kDLCuda). Semantically the operator is the identity function.
* constrains \p sub_expr to execute and store its result on a device with \p DLDeviceType \p 2
* (i.e. a \p kDLCuda device). However the annotation itself may appear in an expression to be
* executed and stored on a different device. If so the compiler will automatically insert a
* "device_copy" call to mediate the transition between devices.
*
* See also FunctionOnDeviceAttrs in include/relay/attrs/function.h for the function-level
* companion.
* E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then:
* \code
* multiply(on_device(add(%x, %y), device_type=2), %z)
* \endcode
* indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU.
* The compiler will rewrite this to:
* \code
* multiply(device_copy(add(%x, %y), src_dev_type=2, dst_dev_type=1), %z)
* \endcode
*
* The Relay call
* \code
* on_device(sub_expr, device_type=2, is_fixed=True)
* \endcode
* is similar to the above, however the annotation itself must appear in an expression on the
* same device. The compiler will check the devices are consistent, and will not insert any
* "device_copy" call. This form of annotation shouldn't be necessary in user programs. However
* it is needed by the \p PlanDevices pass to fully specify the results of device planning so that
* the pass is idempotent.
*
* E.g.: The following program is equivalent to the above:
* \code
* let %a = on_device(add(%x, %y), device_type=2, is_fixed=True)
* multiply(device_copy(%a, src_dev_type=2, dst_dev_type=1), %z)
* \endcode
* The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored
* on the GPU.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
// TODO(mbs): Replace device types with TargetDevice.
Expand Down
66 changes: 0 additions & 66 deletions include/tvm/relay/attrs/function.h

This file was deleted.

6 changes: 4 additions & 2 deletions python/tvm/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def add(self, name, content):
return _ffi.get_global_func("SourceMapAdd")(self, name, content)


def parse(source, source_name="from_string"):
return _ffi_api.ParseModule(source_name, source)
def parse(source, source_name="from_string", init_module=None, init_meta_table=None):
if init_meta_table is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be defaulted in the arguments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. (I had the obvious default but clion was warning about mutable defaults and suggest this instead.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, the linter complains about it too "Dangerous default value {} as argument" - I'll go back to the silly form.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't default Python arguments to anything but atomic values. If you use an object or other aggregate data structure the default will be allocated a single time, and only a single time. If you happen to mutate it you will observe the entire history of mutations across all invocations of the function inside the process.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks! Not so silly a rule after all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest the language behaviour is the silly part? 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest the language behaviour is the silly part? 👍

Yeah its not a good design, but alas you gotta love the Python you are with 😆 it has burned many people many times, super sharp edge.

init_meta_table = {}
return _ffi_api.ParseModuleInContext(source_name, source, init_module, init_meta_table)


def parse_expr(source):
Expand Down
6 changes: 3 additions & 3 deletions src/ir/diagnostic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s
}

auto source = (*it).second;
DLOG(INFO) << "Source: " << std::endl << source->source;
VLOG(1) << "Source: " << std::endl << source->source;

DLOG(INFO) << "ReportAt "
<< "span = " << span << " msg = " << diagnostic->message;
VLOG(1) << "ReportAt "
<< "span = " << span << " msg = " << diagnostic->message;

auto line_text = source.GetLine(span->line);

Expand Down
3 changes: 1 addition & 2 deletions src/parser/meta_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_PARSER_META_REF_H_

#include <tvm/ir/attrs.h>
#include <tvm/parser/parser.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>

Expand All @@ -36,8 +37,6 @@ namespace parser {

using namespace relay;

using MetaTable = Map<String, Array<ObjectRef>>;

/*!
* \brief Options for allocating storage.
*/
Expand Down
52 changes: 37 additions & 15 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1092,8 +1092,6 @@ class Parser {

Array<TypeVar> generics;
if (Peek()->token_type == TokenType::kLSquare) {
// If we have generics we need to add a type scope.
PushTypeScope();
generics = ParseSequence<TypeVar>(
TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
auto type_var_name = Match(TokenType::kIdentifier).ToString();
Expand Down Expand Up @@ -1444,6 +1442,10 @@ class Parser {
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
}
} else {
this->diag_ctx.EmitFatal(Diagnostic::Error(op->span)
<< "unable to determine the 'attrs_type_key' with which "
"to represent the call attributes for this operator");
}
}
return true;
Expand Down Expand Up @@ -1867,7 +1869,7 @@ class Parser {
};

Parser InitParser(const std::string& file_name, const std::string& file_content,
Optional<IRModule> init_module) {
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size();
SourceName src_name = SourceName::Get(file_name);
Source source(src_name, file_content);
Expand All @@ -1886,42 +1888,62 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,
auto tokens_and_table = Tokenize(diag_ctx, source);

auto tokens = tokens_and_table.first;
auto meta_data_table = tokens_and_table.second;
MetaTable meta_data_table = tokens_and_table.second.ToMetadata();

// Merge any entries in init_meta_table into anything captured in the #[metadata] section
// of the file_content. Metadata references within file_content must use indexes which account
// for this ordering.
for (const auto& pair : init_meta_table) {
Array<ObjectRef> items;
if (meta_data_table.count(pair.first)) {
items = meta_data_table[pair.first];
}
for (const auto& obj : pair.second) {
items.push_back(obj);
}
meta_data_table.Set(pair.first, items);
}
Comment on lines +1896 to +1905
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't see any test cases for this logic, is it possible for you to add one so I can see it in action? 😸

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call - done.


return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), meta_data_table.ToMetadata());
return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table));
}

IRModule ParseModule(std::string file_name, std::string file_content,
Optional<IRModule> init_module) {
IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "ParseModule";
auto parser = InitParser(file_name, file_content, init_module);
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
// NB(@jroesch): it is very important that we render any errors before we procede
// if there were any errors which allow the parser to procede we must render them
// NB(@jroesch): it is very important that we render any errors before we proceed
// if there were any errors which allow the parser to proceed we must render them
// here.
parser.diag_ctx.Render();
auto infer_type = tvm::relay::transform::InferType();
ICHECK(infer_type.defined()) << "The type inferencer must be non-null.";
return infer_type(mod);
}

Expr ParseExpr(std::string file_name, std::string file_content) {
Expr ParseExpr(const std::string& file_name, const std::string& file_content) {
VLOG(0) << "ParseExpr";
auto parser = InitParser(file_name, file_content, Optional<IRModule>());
auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable());
parser.ParseSemVer(false);
parser.PushScope();
auto expr = parser.ParseExpr();
parser.Match(TokenType::kEndOfFile);
// NB(@jroesch): it is very important that we render any errors before we procede
// if there were any errors which allow the parser to procede we must render them
// NB(@jroesch): it is very important that we render any errors before we proceed
// if there were any errors which allow the parser to proceed we must render them
// here.
parser.diag_ctx.Render();
return expr;
}

TVM_REGISTER_GLOBAL("parser.ParseModuleInContext")
.set_body_typed([](const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
return ParseModule(file_name, file_content, init_module, init_meta_table);
});

TVM_REGISTER_GLOBAL("parser.ParseModule")
.set_body_typed([](tvm::String file_name, tvm::String file_content) {
.set_body_typed([](const std::string& file_name, const std::string& file_content) {
return ParseModule(file_name, file_content);
});

Expand Down
6 changes: 3 additions & 3 deletions src/parser/source_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Source::Source(SourceName src_name, std::string source) {
}

tvm::String Source::GetLine(int line) {
DLOG(INFO) << "Source::GetLine: line=" << line;
VLOG(1) << "Source::GetLine: line=" << line;
ICHECK(line - 1 < static_cast<int64_t>((*this)->line_map.size()))
<< "requested line: " << line << "at index: " << (line - 1)
<< "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source;
Expand All @@ -69,10 +69,10 @@ tvm::String Source::GetLine(int line) {
auto range = (*this)->line_map.at(line - 1);
int line_start = range.first;
int line_length = range.second;
DLOG(INFO) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length;
VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length;
// TODO(@jroesch): expose substring on tvm::String.
auto line_text = std::string((*this)->source).substr(line_start, line_length);
DLOG(INFO) << "Source::GetLine: line_text=" << line_text;
VLOG(1) << "Source::GetLine: line_text=" << line_text;
return line_text;
}

Expand Down
Loading