[Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass (apache#8069)

* Initial skeleton for fp16 pass.

initial green gray and red lists

move fp16 conversion to own folder

second pass example

split up files a bit more

cool nodes bro

initial transform pass

* Working python version of fp16 pass.

fix topi conv2d not casting kernel to output type

working resnet, but conv2d topi intrinsics need work

tests for resnet

add more tests, extend coverage for converter

update tests, ensure red ops convert back to fp32

clean up code a bit

simplify fp16 output dtype examination

fix pass

update tests

initial coloring

* Rewrite python passes in C++

inspect arg fields

add propagate colors pass

private -> public inheritance

rewrite draft

full transformation in c++

remove prints

fp16 pass the proper wrapping

insert extra cast to pass type checking

fix previously broken test by removing cast in wrong scenario

remove old python_files

* Extend support to things besides CallNodes. E.g. tuples and lets

fp32 invalidate typing instead of cast adding

basic tests

skeleton code out

Stash work -- casting based on checked types

working let statements

add more ops, handle functions more generally

add multiply, fix broken case

support TupleNodes properly, move hash function for datatypes into data_type.h

update simple let test with structural expectation

cleanup p1

remove old file

* Rewrite how and when casting is done by checking types directly.

add support for GPT2, BERT

add some more comments

new single pass version

formatting

make a lot of things const references

clean up tests

more cleanup

more comments

final comment

add newline

* linting and formatting

* add AST header

* remove todo

* lint errors2

* remove i386 incompatible features

* Trigger CI again

* set seed

* lint

* address animesh's initial comments

* mutate attributes only if they were originally floats

* initial comments from matthew

* add comment on hashing strat

* add missing ;

* edge case when mutating attrs

* Cody's easy to address comments

* add test to show green-red casting works

* remove np.random seed from each test

* remove as many references to fp16 types in favor of generic mixed types

* rename RED, GREEN, GRAY to MIXED_PRECISION_ALLOW, etc.

* skeleton for supporting arbitrary mixed types

* cool tests

* Using MixedModeMutator

* rename things ToMixedPrecision

* rename passes to amp.cc

* rename tests to match transform

* clean up typos

* rename even better to_mixed_precision

* don't insert into cache when dtypes equal

* new python interface for registering ops

* cleaner registering ops

* add fp64 structural test

* clean up and comments

* make copy of attributes

* asf header

* pylint

* remove TODO which is solved

* Apply nits from code review (comaniac)

Co-authored-by: Cody Yu <[email protected]>

* change cast_node_cache --> cast_node_cache_

* add check for returned vals

* better error msg

* docstring for pass in python

* fix default behavior to be proper

* better error reporting via single flag

* priority to 0

* address more nits

* fix story telling slightly

* restart

* correct docstring

* change class fields to have _ at end

* add class docstring

* add comment on accumulation dtype hack

* ADT warnings

* add todo

* fix linter

Co-authored-by: Cody Yu <[email protected]>
2 people authored and ylc committed Jan 13, 2022
1 parent b10c3bb commit 1876964
Showing 9 changed files with 1,189 additions and 21 deletions.
15 changes: 15 additions & 0 deletions include/tvm/runtime/data_type.h
@@ -389,4 +389,19 @@ inline DLDataType String2DLDataType(std::string s) {
using DataType = runtime::DataType;

} // namespace tvm

namespace std {
template <>
struct hash<tvm::DataType> {
  inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; }
  std::size_t operator()(tvm::DataType const& dtype) const {
    int a = dtype.code();
    int b = dtype.bits();
    int c = dtype.lanes();
    int d = cantor_pairing_function(a, b);
    return cantor_pairing_function(c, d);
  }
};
}  // namespace std

#endif // TVM_RUNTIME_DATA_TYPE_H_
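The hash specialization above combines a DataType's code, bits, and lanes with the Cantor pairing function so that distinct (code, bits, lanes) triples map to distinct hash inputs. A minimal Python sketch of the same scheme, for illustration only (not part of the patch):

```python
def cantor_pairing(a: int, b: int) -> int:
    # Bijection from pairs of non-negative integers to a single non-negative integer.
    return (a + b) * (a + b + 1) // 2 + b


def datatype_hash(code: int, bits: int, lanes: int) -> int:
    # Mirrors std::hash<tvm::DataType>: pair code with bits, then pair lanes with the result.
    return cantor_pairing(lanes, cantor_pairing(code, bits))


# e.g. float16 (code=2, bits=16, lanes=1) and float32 (code=2, bits=32, lanes=1) hash differently.
assert datatype_hash(2, 16, 1) != datatype_hash(2, 32, 1)
```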
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
@@ -29,6 +29,7 @@
debug,
register_external_compiler,
register_fake_quantization_to_integer,
register_mixed_precision_conversion,
)
from . import strategy

33 changes: 30 additions & 3 deletions python/tvm/relay/op/op.py
@@ -18,10 +18,11 @@
"""The base node types for the Relay language."""
import tvm._ffi
import tvm.ir
from tvm.driver import lower, build
from tvm.target import get_native_generic_func, GenericFunc
from tvm.runtime import Object
import tvm.ir._ffi_api
from tvm.driver import build, lower
from tvm.runtime import Object
from tvm.target import GenericFunc, get_native_generic_func

from . import _make


@@ -457,6 +458,32 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10):
return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)


def register_mixed_precision_conversion(op_name, func=None, level=10):
    """Register mixed precision conversion function for an op.

    Given an op, the function should return information on how its value should be
    converted. Specifically, the function should take a call node and the target
    mixed precision datatype (e.g. FP16) and return the conversion category
    (see python/tvm/relay/transform/mixed_precision.py) as well as the accumulation
    and output datatype of the operation in the mixed precision dtype space.

    Parameters
    ----------
    op_name : str
        The name of the operator

    func: function (call_node: relay.Call, target_dtype: string)
          -> [conversion category, accumulation dtype, output dtype]: [int, string, string]
        A function which, given a call_node and target_dtype (e.g. FP16), returns the
        conversion category and the associated accumulation/output dtypes of the operation
        when transformed into the mixed precision dtype space.

    level : int
        The priority level
    """
    return tvm.ir.register_op_attr(op_name, "FTVMMixedPrecisionConversionType", func, level)


@tvm._ffi.register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
    return lower(schedule, list(inputs) + list(outputs), name=name)
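The hook registered above can also be used to override a default conversion rule. A hedged sketch: re-registering nn.conv2d at a higher priority level so it accumulates in the mixed precision type itself rather than float32, using the conversion categories from the mixed_precision defaults file shown below (the level value and policy here are illustrative, not part of the patch):

```python
from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion
from tvm.relay.transform.mixed_precision import MIXED_PRECISION_ALWAYS


# level=11 outranks the default registration at level 10.
@register_mixed_precision_conversion("nn.conv2d", level=11)
def conv2d_mixed_precision_rule(call_node: relay.Call, mixed_precision_type: str):
    # [conversion category, accumulation dtype, output dtype]
    return [MIXED_PRECISION_ALWAYS, mixed_precision_type, mixed_precision_type]
```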
195 changes: 195 additions & 0 deletions python/tvm/relay/transform/mixed_precision.py
@@ -0,0 +1,195 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long,unused-argument
"""Default behavior for ops in mixed_precision pass. Import this file to use."""
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion

# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
# numerical reasons.
MIXED_PRECISION_ALWAYS = 0
MIXED_PRECISION_FOLLOW = 1
MIXED_PRECISION_NEVER = 2

# Default lists inspired from TF's classifications:
# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
# They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
DEFAULT_ALWAYS_LIST = [
"nn.conv1d",
"nn.conv2d",
"nn.conv3d",
"nn.conv1d_transpose",
"nn.conv2d_transpose",
"nn.conv3d_transpose",
"nn.dense",
# "nn.batch_matmul", # Handled by a special case
]
DEFAULT_FOLLOW_LIST = [
# These ops add new data or change shape
"nn.pad",
"nn.batch_flatten",
"concatenate",
"zeros",
"split",
"squeeze",
"transpose",
"expand_dims",
"reshape",
"dyn.reshape",
"broadcast_to_like",
"dyn.broadcast_to",
"strided_slice",
"dyn.strided_slice",
"take",
"argwhere",
"where",
"tile",
"dyn.tile",
"scatter",
"full",
"dyn.full",
# Comparison
"less",
"greater",
"less_equal",
"greater_equal",
# By definition copy and cast will depend on inputs for output.
"copy",
"cast",
"cast_like",
# Simple arithmetic
"add",
"subtract",
"multiply",
"divide",
"nn.bias_add",
"nn.batch_norm",
"sum",
"mean",
"sqrt",
"shape_of",
# Simple activations
"max",
"min",
"maximum",
"minimum",
"nn.relu",
"nn.leaky_relu",
"nn.prelu",
"nn.dropout",
# Complicated activations which saturate in a narrow range
"sigmoid",
"tanh",
# Pooling operations
"nn.max_pool1d",
"nn.max_pool2d",
"nn.max_pool3d",
"nn.avg_pool1d",
"nn.avg_pool2d",
"nn.avg_pool3d",
# "nn.global_max_pool1d", # does not exist yet
"nn.global_max_pool2d",
# "nn.global_max_pool3d", # does not exist yet
# "nn.global_avg_pool1d", # does not exist yet
"nn.global_avg_pool2d",
# "nn.global_avg_pool3d", # does not exist yet
"nn.adaptive_max_pool1d",
"nn.adaptive_max_pool2d",
"nn.adaptive_max_pool3d",
"nn.adaptive_avg_pool1d",
"nn.adaptive_avg_pool2d",
"nn.adaptive_avg_pool3d",
]
DEFAULT_NEVER_LIST = [
# In general if |f(x)| >> |x| for expected inputs then put the op here.
"exp",
"power",
"nn.cross_entropy",
"nn.cross_entropy_with_logits",
"nn.softmax",
"nn.l2_normalize",
# The error function does not currently lower into an fp16 version in LLVM.
# Move to the follow list when it does.
"erf",
]


# Returns a decorator which registers, for every op in list_ops,
# the given function under FTVMMixedPrecisionConversionType.
def register_func_to_op_list(list_ops: List):
    def decorator(func):
        for op_name in list_ops:
            register_mixed_precision_conversion(op_name, func=func)

    return decorator


def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
    """A function which returns output dtypes in a way which works for most ops.

    Parameters
    ----------
    call_node: relay.Call
        The call node containing the op.
    mixed_precision_type: str
        The target type to run the operation in.

    Returns
    -------
    output_dtypes : [str, str]
        A list of two strings. The first represents the datatype used for accumulation
        in the operation. The second represents the actual output datatype.
    """
    # Assume support for accumulation dtypes <---> has out_dtype attr.
    # This is because there is no better way right now to tell which ops support accumulating
    # at different data types.
    # Some discussion about making this better is here:
    # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
    if hasattr(call_node.attrs, "out_dtype"):
        return ["float32", mixed_precision_type]

    # [accumulation_dtype, output_dtype] for the operations
    return [mixed_precision_type, mixed_precision_type]


# Functions for FTVMMixedPrecisionConversionType which take in a CallNode and a dtype
# and return a conversion category, an accumulation dtype, and an output dtype.
@register_func_to_op_list(list_ops=DEFAULT_ALWAYS_LIST)
def generic_always_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_ALWAYS] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_FOLLOW_LIST)
def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_FOLLOW] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
    return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_mixed_precision_conversion("nn.batch_matmul")
def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List:
# TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
# Batched matmul has inconsistent support for mixed precision operations.
# Many schedules ignore the out_dtype attribute which leads to errors when
# input types do not match the out_dtype. Therefore, accumulate to output_dtype.
return [MIXED_PRECISION_ALWAYS, "float16", "float16"]
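A short sketch of how the generic dtype helper behaves for two representative ops; the small Relay expressions here are illustrative only:

```python
from tvm import relay
from tvm.relay.transform.mixed_precision import get_generic_out_dtypes

data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
act = relay.nn.relu(conv)

# nn.conv2d carries an out_dtype attribute, so accumulation stays in float32.
print(get_generic_out_dtypes(conv, "float16"))  # ['float32', 'float16']
# nn.relu has no out_dtype attribute, so both dtypes become the mixed precision type.
print(get_generic_out_dtypes(act, "float16"))   # ['float16', 'float16']
```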
35 changes: 30 additions & 5 deletions python/tvm/relay/transform/transform.py
@@ -18,16 +18,15 @@
"""
Relay pass transformation infrastructure.
"""
import types
import inspect
import functools
import inspect
import types
import warnings

import tvm.ir
from tvm import te
from tvm import relay, te
from tvm.runtime import ndarray as _nd

from tvm import relay
from . import _ffi_api


@@ -1168,7 +1167,7 @@ def AnnotateSpans():
Returns
-------
ret : tvm.transform.Pass
The regsistered AnnotateSpans pass.
The registered AnnotateSpans pass.
"""
return _ffi_api.AnnotateSpans()

@@ -1199,3 +1198,29 @@ def FakeQuantizationToInteger():
The registered SimplifyExpr pass.
"""
return _ffi_api.FakeQuantizationToInteger()


def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
    """
    Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
    where as many operations as possible are in the target mixed_precision_type.

    Parameters
    ----------
    mixed_precision_type: str
        The target datatype to transform operations in the graph to use.

    missing_op_mode: int
        Determines how to handle ops not registered with FTVMMixedPrecisionConversionType
            0: Does not allow any missing ops. Will throw errors when encountering any.
            1: Allow missing ops but emit warnings.
            2: Allow missing ops and silently ignore them.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass.
    """
    if missing_op_mode < 0 or missing_op_mode > 2:
        raise ValueError("Missing op mode is either 0, 1, or 2")
    return _ffi_api.ToMixedPrecision(mixed_precision_type, missing_op_mode)
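A minimal usage sketch for the new pass, assuming the default op lists in tvm.relay.transform.mixed_precision are imported explicitly (the tiny network is illustrative):

```python
import tvm
from tvm import relay
from tvm.relay.transform import mixed_precision  # pylint: disable=unused-import  # registers default rules

data = relay.var("data", shape=(1, 16), dtype="float32")
weight = relay.var("weight", shape=(8, 16), dtype="float32")
out = relay.nn.relu(relay.nn.dense(data, weight))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

mod = relay.transform.InferType()(mod)
# missing_op_mode=1: warn on ops without a registered FTVMMixedPrecisionConversionType.
fp16_mod = relay.transform.ToMixedPrecision("float16", missing_op_mode=1)(mod)
print(fp16_mod)
```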
10 changes: 7 additions & 3 deletions python/tvm/topi/nn/conv2d.py
@@ -18,13 +18,15 @@
# pylint: disable=unused-argument, redefined-builtin
"""Conv2D operators"""
from __future__ import absolute_import as _abs

from collections import namedtuple

import tvm
from tvm import te, auto_scheduler
from tvm import auto_scheduler, te

from ..utils import get_const_int, get_const_tuple, simplify, tag
from .pad import pad
from .utils import get_pad_tuple
from ..utils import simplify, get_const_tuple, get_const_int, tag
from .winograd_util import winograd_transform_matrices

# workload description of conv2d
@@ -548,7 +550,9 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
ow * WSTR + kw * dilation_w,
idxmod(ic, ic_bn),
].astype(out_dtype)
* kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block],
* kernel[oc_chunk, idxdiv(ic, ic_bn), kh, kw, idxmod(ic, ic_bn), oc_block].astype(
out_dtype
),
axis=[ic, kh, kw],
),
name="conv2d_NCHWc",
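The conv2d_NCHWc change casts the kernel operand to out_dtype so that both multiply operands agree when accumulating in a wider type (e.g. fp16 inputs accumulated in fp32). A minimal TE sketch of the same pattern, not the actual conv2d_NCHWc compute (shapes and names are illustrative):

```python
from tvm import te

M, N, K = 8, 4, 16
in_dtype, out_dtype = "float16", "float32"

data = te.placeholder((M, K), dtype=in_dtype, name="data")
kernel = te.placeholder((N, K), dtype=in_dtype, name="kernel")
k = te.reduce_axis((0, K), name="k")

# Cast *both* operands to out_dtype before the reduction, mirroring the fix above.
out = te.compute(
    (M, N),
    lambda i, j: te.sum(
        data[i, k].astype(out_dtype) * kernel[j, k].astype(out_dtype),
        axis=k,
    ),
    name="matmul_acc_fp32",
)
print(out.dtype)  # float32
```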