Skip to content

Commit

Permalink
feat(qtensor): add MarlinQBitsTensor
Browse files Browse the repository at this point in the history
Adding more tests revealed a bug in the Marlin int4 kernel when the
weights and inputs are large enough.
Failing configurations are marked as xfail.
  • Loading branch information
dacorvo committed Oct 10, 2024
1 parent b688051 commit 852bb9c
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 0 deletions.
1 change: 1 addition & 0 deletions optimum/quanto/tensor/weights/marlin/int4/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .packed import *
from .qbits import *
168 changes: 168 additions & 0 deletions optimum/quanto/tensor/weights/marlin/int4/qbits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast

import torch
from torch.autograd import Function

from ....function import QuantizedLinearFunction
from ....grouped import group, ungroup
from ....qtype import qtypes
from ...qbits import WeightQBitsTensor
from ..permutations import marlin_permute
from .packed import MarlinInt4PackedTensor


__all__ = ["MarlinInt4WeightQBitsTensor"]


class MarlinQBitsDequantizer(Function):
@staticmethod
def forward(ctx, t):
unpacked = t._data.unpack()
scale = t._scale
shift = t._shift
unpacked = group(unpacked, axis=0, group_size=t._group_size)
# Apply inverted permutations
scale = marlin_permute(scale, reverse=True)
shift = marlin_permute(shift, reverse=True)
n_scales = scale.numel()
scale = scale.t().reshape((n_scales, 1))
shift = shift.t().reshape((n_scales, 1))
# Shift is already scaled and negated
dqt = scale * unpacked + shift
return ungroup(dqt, axis=t.axis, orig_shape=t.shape)

@staticmethod
def backward(ctx, gO):
return gO


class MarlinQBitsLinearFunction(QuantizedLinearFunction):
@staticmethod
def forward(ctx, input, other, bias):
ctx.save_for_backward(input, other)
if type(input) is not torch.Tensor:
input = input.dequantize()
out_features, in_features = other.shape
output = torch.ops.quanto.gemm_f16i4_marlin(
input,
other._data._data,
other._scale,
other._shift,
other._workspace,
)
if bias is not None:
output = output + bias
return output


class MarlinInt4WeightQBitsTensor(WeightQBitsTensor):
@staticmethod
def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
assert data.device.type == "cuda"
assert data.device == scale.device
assert data.device == shift.device
return torch.Tensor._make_wrapper_subclass(
cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
)

def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
assert axis == 0
out_features, in_features = size
if not isinstance(data, MarlinInt4PackedTensor):
assert type(data) is torch.Tensor
# Format data, scale and shift for optimized CUDA gemm
ungrouped = ungroup(data, axis=0, orig_shape=size)
data = MarlinInt4PackedTensor.pack(ungrouped)
scale = scale.reshape(out_features, in_features // group_size).t().contiguous()
shift = shift.reshape(out_features, in_features // group_size).t()
if not shift.dtype.is_floating_point:
# Integer shift must be scaled
shift = scale * shift
# Shift must be negated
shift = -shift.contiguous()
# Finally, apply scale and shift permutations
scale = marlin_permute(scale)
shift = marlin_permute(shift)
super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)
self._workspace = torch.zeros(out_features // 128 * 16, dtype=torch.int, device=data.device)

def dequantize(self):
return MarlinQBitsDequantizer.apply(self)

def weight_qbits_tensor(self):
"""Convert back to a WeightQBitsTensor
This is required to make sure only standard packing is used when serializing.
"""
data = group(self._data.unpack(), axis=self.axis, group_size=self._group_size)
scale = marlin_permute(self._scale, reverse=True)
shift = marlin_permute(self._shift, reverse=True)
n_scales = scale.numel()
scale = scale.t().reshape((n_scales, 1))
shift = -shift.t().reshape((n_scales, 1))
return WeightQBitsTensor(
self._qtype, self._axis, self._group_size, self.size(), self.stride(), data, scale, shift
)

def __tensor_flatten__(self):
inner_tensors = ["_data", "_scale", "_shift"]
# Since meta can be used for serialization, use only strings
meta = {
"qtype": self._qtype.name,
"axis": str(self._axis),
"group_size": str(self._group_size),
"size": str(list(self.size())),
"stride": str(list(self.stride())),
}
return inner_tensors, meta

@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert len(inner_tensors) == 3
assert len(meta) == 5
data, scale, shift = inner_tensors["_data"], inner_tensors["_scale"], inner_tensors["_shift"]
# Meta should only contain strings, AST compatible except qtype
qtype = qtypes[meta["qtype"]]
axis = ast.literal_eval(meta["axis"])
group_size = ast.literal_eval(meta["group_size"])
size = ast.literal_eval(meta["size"])
stride = ast.literal_eval(meta["stride"])
return MarlinInt4WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
"""Dispatch torch functions applied on this subtensor
This method is called whenever a torch function (such as `torch.nn.functional.linear`)
is called with at least one parameter coresponding to this subtensor:
- if a quantized implementation exists for the selected function, it is called,
- otherwise, the original implementation is called, deactivating further functional dispatch.
During the execution of the standard torch function, a second-level of dispatch will
happen, but this time directly on individual torch Tensor operations (mainly ATEN).
"""
kwargs = kwargs or {}
if func is torch.nn.functional.linear:

def qlinear(input, other, bias=None):
return MarlinQBitsLinearFunction.apply(input, other, bias)

return qlinear(*args, **kwargs)
# Defer to operations dispatcher
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
150 changes: 150 additions & 0 deletions test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from helpers import device_eq, random_qweight
from tensor.weights.weight_helpers import check_weight_qtensor_linear

from optimum.quanto import qint4
from optimum.quanto.library.extensions import is_extension_available
from optimum.quanto.tensor.weights import WeightQBitsTensor
from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4WeightQBitsTensor


@pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= sm80 not available"
)
@pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
def test_marlin_int4_weight_qbits_tensor_from_qbits_tensor(in_features, out_features):
qtype = qint4
group_size = 128
dtype = torch.float16
shape = (out_features, in_features)
device = torch.device("cuda")
qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)
# Create a MarlinInt4WeightQBitsTensor from the WeightQBitsTensor members
marlinqbt = MarlinInt4WeightQBitsTensor(
qtype=qbt.qtype,
axis=qbt.axis,
group_size=qbt._group_size,
size=qbt.size(),
stride=qbt.stride(),
data=qbt._data.unpack(),
scale=qbt._scale,
shift=qbt._shift,
)
assert marlinqbt.dtype == dtype
assert marlinqbt.qtype == qtype
assert marlinqbt.shape == shape
assert device_eq(marlinqbt.device, device)
# Verify the dequantized tensors are identical
assert torch.equal(marlinqbt.dequantize(), qbt.dequantize())
# Now verify that we can reconstruct the WeightQBitsTensor
new_qbt = marlinqbt.weight_qbits_tensor()
assert type(new_qbt) is WeightQBitsTensor
assert new_qbt.dtype == dtype
assert new_qbt.qtype == qtype
assert new_qbt.shape == shape
assert torch.equal(new_qbt._data, qbt._data)
assert torch.equal(new_qbt._scale, qbt._scale)
assert torch.equal(new_qbt._shift, qbt._shift)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_marlin_int4_weight_qbits_tensor_move(device):
qtype = qint4
group_size = 128
dtype = torch.float16
shape = (1024, 1024)
device = torch.device("cuda")
# Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device("cuda"))
marlinqbt = MarlinInt4WeightQBitsTensor(
qtype=qbt.qtype,
axis=qbt.axis,
group_size=qbt._group_size,
size=qbt.size(),
stride=qbt.stride(),
data=qbt._data.unpack(),
scale=qbt._scale,
shift=qbt._shift,
)
# Move to device, dequantize and compare
moved_qbt = marlinqbt.to(device)
assert isinstance(moved_qbt, WeightQBitsTensor)
if device.type != "cuda":
assert type(moved_qbt) is not MarlinInt4WeightQBitsTensor
assert marlinqbt.dtype == moved_qbt.dtype
assert marlinqbt.qtype == moved_qbt.qtype
assert marlinqbt.shape == moved_qbt.shape
assert torch.equal(marlinqbt.dequantize().to(device), moved_qbt.dequantize())


def _test_marlin_int4_weight_qbits_tensor_linear(
dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
):
# Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA
qbt = random_qweight(
(out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda")
)
marlin_qweight = MarlinInt4WeightQBitsTensor(
qtype=qbt.qtype,
axis=qbt.axis,
group_size=qbt._group_size,
size=qbt.size(),
stride=qbt.stride(),
data=qbt._data.unpack(),
scale=qbt._scale,
shift=qbt._shift,
)
check_weight_qtensor_linear(marlin_qweight, batch_size, tokens, use_bias)


@pytest.mark.skipif(
not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
reason="CUDA >= sm80 not available",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [16, 32])
@pytest.mark.parametrize("in_features", [1024])
@pytest.mark.parametrize("out_features", [1024, 2048, 4096])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):
dtype = torch.float16
weight_qtype = qint4
group_size = 128
_test_marlin_int4_weight_qbits_tensor_linear(
dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
)


@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)
@pytest.mark.skipif(
not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
reason="CUDA >= sm80 not available",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [48, 64])
# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
@pytest.mark.parametrize("in_features", [4096, 16384])
@pytest.mark.parametrize("out_features", [2048, 4096])
def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
dtype = torch.float16
weight_qtype = qint4
group_size = 128
_test_marlin_int4_weight_qbits_tensor_linear(
dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias=False
)

0 comments on commit 852bb9c

Please sign in to comment.