Skip to content

Commit

Permalink
[Torch] Support returning quantized weights and bias for BYOC use cas…
Browse files Browse the repository at this point in the history
…es (apache#9135)

* [Torch] refactored the way is bias quantization done

* support returning 8bit weight

* add test

* add doc

* pylint

* return_int8_weight -> keep_quantized_weight

* fixed for dynamic linear case

* remove test function call

* simplifying
  • Loading branch information
masahi authored and ylc committed Sep 29, 2021
1 parent 2d37f2c commit 9eea017
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 33 deletions.
25 changes: 22 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3713,6 +3713,7 @@ def from_pytorch(
custom_convert_map=None,
default_dtype="float32",
use_parser_friendly_name=False,
keep_quantized_weight=False,
):
"""Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
Expand Down Expand Up @@ -3745,6 +3746,16 @@ def from_pytorch(
so a variable name like "dense.weight" cannot be parsed correctly.
Use this option when you want to run the AnnotateSpans pass on the imported module.
keep_quantized_weight : bool
Return quantized weights and bias, rather than float ones. PyTorch stores quantized weights
in a custom format, so we cannot directly access 8 bit weights as Numpy arrays. We use
a PyTorch function to unpack quantized weights into float32 arrays and quantization
parameters. By default, we return float32 weights and rely on the QNN lowering and the
Relay constant folding pass to quantize weights at compile time. In BYOC use cases, however,
we cannot apply the constant folding pass on a QNN graph. If keep_quantized_weight is True,
we quantize weights in the frontend using a function that is equivalent to
qnn.op.quantize(...) operating on Numpy arrays.
Returns
-------
mod : tvm.IRModule
Expand Down Expand Up @@ -3789,9 +3800,17 @@ def from_pytorch(
# For quantized models
quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
if len(quantized_ops.intersection(set(op_names))) > 0:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params)
weight_quant_params = qnn_torch.get_weight_quant_params(
script_module, packed_param_map.values()
)
input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(
outputs,
packed_param_map,
weight_quant_params,
input_scales_for_bias,
keep_quantized_weight,
)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
converter.update_convert_map(qnn_torch.convert_map)

Expand Down
115 changes: 87 additions & 28 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,12 @@
class QNNParam:
"""A placeholder for weight quantization parameters"""

def __init__(self, weight, bias, scale, zero_point, param_key):
param_prefix = param_key[: -len("._packed_params")]
self.weight_var = _expr.var(param_prefix + "_weight", shape=weight.shape)
def __init__(self, weight, bias, scale, zero_point):
self.weight = weight

if bias is not None:
self.bias_var = _expr.var(param_prefix + "_bias", shape=bias.shape)
self.bias = bias.detach().numpy()
else:
self.bias_var = None
self.bias = None

self.scale = _expr.const(scale)
Expand All @@ -55,10 +51,8 @@ class ConvPackedParam(QNNParam):
together with weights and quantization parameters
"""

def __init__(
self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
):
super().__init__(weight_np, bias, scale, zero_point, param_name)
def __init__(self, weight_np, bias, scale, zero_point, stride, padding, dilation, groups):
super().__init__(weight_np, bias, scale, zero_point)
self.stride = stride
self.padding = padding
self.dilation = dilation
Expand All @@ -81,23 +75,21 @@ def _get_quant_params(qweight):
return weight_np, scales, 0


def make_qnn_param(param_name, qweight, bias):
def make_qnn_param(qweight, bias):
weight_np, scale, zero_point = _get_quant_params(qweight)
return QNNParam(weight_np, bias, scale, zero_point, param_name)
return QNNParam(weight_np, bias, scale, zero_point)


def make_conv_packed_param(param_name, qweight, bias, packed_params):
def make_conv_packed_param(qweight, bias, packed_params):
weight_np, scale, zero_point = _get_quant_params(qweight)
stride = packed_params.stride()
padding = packed_params.padding()
dilation = packed_params.dilation()
groups = packed_params.groups()
return ConvPackedParam(
weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups
)
return ConvPackedParam(weight_np, bias, scale, zero_point, stride, padding, dilation, groups)


def get_weight_quant_params(script_module):
def get_weight_quant_params(script_module, packed_param_names):
"""Retrive and unpack weight parameters from quantized modules"""
import torch

Expand All @@ -114,6 +106,9 @@ def filter_func(named_module):
key = name + "." + param_name
state_dict = m.state_dict()

if key not in packed_param_names:
continue

if len(state_dict) == 0 and not hasattr(m, param_name):
# for v1.6 and above
# This case seems to happen if a model is serialized
Expand All @@ -130,28 +125,87 @@ def filter_func(named_module):

if "Conv" in m.original_name and len(state_dict) == 0:
qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
quant_params[key] = make_conv_packed_param(key, qweight, bias, packed_params)
quant_params[key] = make_conv_packed_param(qweight, bias, packed_params)
elif "Conv" in m.original_name:
qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params)
quant_params[key] = make_qnn_param(key, qweight, bias)
quant_params[key] = make_qnn_param(qweight, bias)
elif m.original_name == "LinearPackedParams":
qweight, bias = torch.ops.quantized.linear_unpack(packed_params)
quant_params[key] = make_qnn_param(key, qweight, bias)
quant_params[key] = make_qnn_param(qweight, bias)

return quant_params


def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
def quantize_numpy(weight, scale, zero_point, out_dtype_np):
iinfo = np.iinfo(out_dtype_np)
clip_min = iinfo.min
clip_max = iinfo.max
if len(scale.shape) > 0:
scale = np.reshape(scale, [weight.shape[0]] + [1] * (len(weight.shape) - 1))
transformed = zero_point + weight / scale
return np.clip(np.round(transformed), clip_min, clip_max).astype(out_dtype_np)


def add_quant_params_to_outputs(
outputs, packed_param_map, quant_params, input_scales_for_bias, keep_quantized_weight=False
):
"""
Add quant params to outputs so that they can be referenced by other
ops later. Weights are quantized here.
"""
for node_name, packed_param_name in packed_param_map.items():
qparam = quant_params[packed_param_name]
qweight = relay.qnn.op.quantize(
qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0
)
params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var]
weight_scale = _get_numpy(qparam.scale)
param_prefix = packed_param_name[: -len("._packed_params")]

if keep_quantized_weight:
qparam.weight_var = _expr.var(
param_prefix + "_weight", shape=qparam.weight.shape, dtype="int8"
)
qparam.weight = quantize_numpy(
qparam.weight, weight_scale, _get_numpy(qparam.zero_point), np.int8
)
qweight = qparam.weight_var
else:
qparam.weight_var = _expr.var(
param_prefix + "_weight", shape=qparam.weight.shape, dtype="float32"
)
qweight = relay.qnn.op.quantize(
qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0
)

if qparam.bias is not None:
float_bias_var = _expr.var(
param_prefix + "_bias", shape=qparam.bias.shape, dtype="float32"
)
if node_name not in input_scales_for_bias:
# This case is for dynamic quantization, where the input activation scale is
# unknown until runtime.
qparam.bias_var = float_bias_var
qbias = qparam.bias_var
elif keep_quantized_weight:
qparam.bias_var = _expr.var(
param_prefix + "_bias", shape=qparam.bias.shape, dtype="int32"
)
qparam.bias = quantize_numpy(
qparam.bias, input_scales_for_bias[node_name] * weight_scale, 0, np.int32
)
qbias = qparam.bias_var
else:
qparam.bias_var = float_bias_var
qbias = relay.qnn.op.quantize(
qparam.bias_var,
_expr.const(input_scales_for_bias[node_name] * weight_scale),
_expr.const(0, "int32"),
out_dtype="int32",
axis=0,
)
else:
qbias = None

quant_params[packed_param_name] = qparam

params = [qweight, qparam.scale, qparam.zero_point, qbias]

if isinstance(quant_params[packed_param_name], ConvPackedParam):
params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups]
Expand Down Expand Up @@ -367,6 +421,8 @@ def add_input_quant_params_to_op_inputs(graph):
need_input_quant_param = set(num_quantized_inputs.keys())
need_input_quant_param.add("quantized::cat")

input_scales_for_bias = {}

for node in graph.nodes():
operator = node.kind()
if operator not in need_input_quant_param:
Expand Down Expand Up @@ -401,6 +457,12 @@ def add_input_quant_params_to_op_inputs(graph):
node.addInput(scale)
node.addInput(zp)

if "conv2d" in operator or "linear" in operator:
# This is required for quantizing the bias
input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value")

return input_scales_for_bias


def add_quant_params(params, quant_params):
"""Add quant parameters to TVM param map"""
Expand Down Expand Up @@ -478,10 +540,7 @@ def _do_bias_and_requantize(
# Instead, the torch way requires rounding of activation at runtime

if bias is not None:
qbias = relay.qnn.op.quantize(
bias, requant_input_scale, _expr.const(0, "int32"), out_dtype="int32", axis=0
)
requantize_input = _op.nn.bias_add(output, qbias)
requantize_input = _op.nn.bias_add(output, bias)
else:
requantize_input = output

Expand Down
43 changes: 41 additions & 2 deletions tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,15 @@ def torch_version_check():
return version.parse(torch.__version__) > version.parse("1.4.0")


def get_tvm_runtime(script_module, input_name, ishape):
def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False):
input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
mod, params = relay.frontend.from_pytorch(
script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
)

if keep_quantized_weight:
for p in params.values():
assert p.dtype in ["int8", "int32"]

with tvm.transform.PassContext(opt_level=3):
# test on only cpu for now, torch cannot run quant models on cuda
Expand Down Expand Up @@ -609,3 +615,36 @@ def test_qnn_mergecomposite():

input_name = "image"
run_qnn_mergecomposite(script_module, input_name, inp.shape)


def test_keep_quantized_weight():
qmodules = []

for per_channel in [False, True]:
qmodules += [
((1, 3, 224, 224), ConvBn(), per_channel),
((16, 16), Linear(), per_channel),
]

for (ishape, raw_module, per_channel) in qmodules:
raw_module.eval()
inp = torch.rand(ishape)

quantize_model(raw_module, inp, per_channel=per_channel)
script_module = torch.jit.trace(raw_module, inp).eval()

input_name = "input"

runtime = get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False)
runtime.set_input(input_name, inp.numpy().copy())
runtime.run()
tvm_result = runtime.get_output(0).numpy()

runtime_int8_weight = get_tvm_runtime(
script_module, input_name, ishape, keep_quantized_weight=True
)
runtime_int8_weight.set_input(input_name, inp.numpy().copy())
runtime_int8_weight.run()
tvm_result_int8_weight = runtime_int8_weight.get_output(0).numpy()

tvm.testing.assert_allclose(tvm_result, tvm_result_int8_weight)

0 comments on commit 9eea017

Please sign in to comment.