Refactor layout implementation (#491)
Summary:
TODO

Test Plan:
TODO

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Jul 16, 2024
1 parent 6e7cf71 commit aef7e09
Showing 4 changed files with 137 additions and 61 deletions.
11 changes: 10 additions & 1 deletion torchao/dtypes/__init__.py
@@ -1,12 +1,21 @@
from .nf4tensor import NF4Tensor, to_nf4
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor
from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
from .affine_quantized_tensor import (
AffineQuantizedTensor,
to_affine_quantized,
LayoutType,
PlainLayoutType,
TensorCoreTiledLayoutType,
)

__all__ = [
"NF4Tensor",
"to_nf4",
"UInt4Tensor"
"AffineQuantizedTensor",
"to_affine_quantized",
"LayoutType",
"PlainLayoutType",
"TensorCoreTiledLayoutType",
]
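
For context, a minimal usage sketch (not part of the commit) of the newly exported layout types. It assumes `to_affine_quantized` forwards to `AffineQuantizedTensor.from_float` (whose updated signature appears in the diff below) and that `MappingType` and `ZeroPointDomain` live in `torchao.quantization.quant_primitives`; the concrete argument values are illustrative only.

# Hedged sketch: int4 weight-only style quantization using the new layout_type argument.
# Argument values below are assumptions for illustration, not taken from the commit.
import torch
from torchao.dtypes import to_affine_quantized, TensorCoreTiledLayoutType
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

weight = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")  # CUDA needed for int4 packing

aqt = to_affine_quantized(
    weight,
    MappingType.ASYMMETRIC,
    (1, 32),                                   # block_size: one scale/zero_point per 32 input features
    torch.int32,                               # target_dtype used by the int4 tinygemm path
    quant_min=0,
    quant_max=15,
    eps=1e-6,
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
    layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8),
)
print(aqt.layout_type)                         # TensorCoreTiledLayoutType(inner_k_tiles=8)
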
137 changes: 88 additions & 49 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -20,10 +20,35 @@
_ATEN_OP_OR_TORCH_FN_TABLE,
_register_layout_cls,
_get_layout_tensor_constructor,
LayoutType,
)
from typing import ClassVar
from dataclasses import dataclass

aten = torch.ops.aten

@dataclass(frozen=True)
class PlainLayoutType(LayoutType):
pass

@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
inner_k_tiles: int = 8

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
orig_out_features, orig_in_features = input.shape
in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)
input = torch.nn.functional.pad(
input,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)
return input

def extra_repr(self):
return f"inner_k_tiles={self.inner_k_tiles}"


def _aqt_is_int8(aqt):
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
return (
@@ -52,10 +77,10 @@ class AQTLayout(torch.Tensor):
"""
Base class for the layout tensor for `AffineQuantizedTensor`
"""
# this should be set for each layout class during registration
extended_layout: Optional[str] = None
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pass

def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_layout_type(self) -> LayoutType:
pass

@classmethod
@@ -64,9 +89,15 @@ def from_plain(
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
pass

def __repr__(self):
int_data, scale, zero_point = self.get_plain()
layout_type = self.get_layout_type()
return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
@@ -194,30 +225,17 @@ def from_float(
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
extended_layout: str = "plain",
# TODO: this is only for "tensor_core_tiled", need to figure out
# the proper API for this arg
inner_k_tiles: Optional[int] = None,
layout_type: LayoutType = PlainLayoutType(),
):
original_shape = input_float.shape
if extended_layout == "tensor_core_tiled":
orig_out_features, orig_in_features = input_float.shape
in_features = find_multiple(orig_in_features, 1024)
out_features = find_multiple(orig_out_features, 8)
input_float = torch.nn.functional.pad(
input_float,
(0, in_features - orig_in_features, 0, out_features - orig_out_features),
)
input_float = layout_type.pre_process(input_float)

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
int_data = layout_type.post_process(int_data)

layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
# TODO: this is temporary, need to come up with the proper UX
if extended_layout == "tensor_core_tiled":
layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
else:
layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
return cls(
layout_tensor,
block_size,
@@ -229,8 +247,8 @@ def from_float(
)

@property
def extended_layout(self) -> str:
return self.layout_tensor.extended_layout
def layout_type(self) -> LayoutType:
return self.layout_tensor.layout_type

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -308,13 +326,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
def implements(aten_ops_or_torch_fn):
return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)

def register_layout_cls(extended_layout: str):
return _register_layout_cls(AffineQuantizedTensor, extended_layout)
def register_layout_cls(layout_type_class: type(LayoutType)):
return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

def get_layout_tensor_constructor(extended_layout: str):
return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

@register_layout_cls("plain")
@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
"""
Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point
@@ -330,6 +348,7 @@ def __new__(
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
kwargs = {}
kwargs["device"] = int_data.device
@@ -346,34 +365,39 @@ def __init__(
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
self.int_data = int_data
self.scale = scale
self.zero_point = zero_point
self.layout_type = layout_type

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point"], []
return ["int_data", "scale", "zero_point"], [self.layout_type]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
return cls(int_data, scale, zero_point)
layout_type, = tensor_attributes
return cls(int_data, scale, zero_point, layout_type)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]),
self.layout_type,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
self.layout_type,
)

@classmethod
@@ -398,19 +422,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self):
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.int_data, self.scale, self.zero_point

def get_layout_type(self) -> LayoutType:
return self.layout_type

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
return cls(int_data, scale, zero_point)
assert isinstance(layout_type, PlainLayoutType)
return cls(int_data, scale, zero_point, layout_type)

@register_layout_cls("tensor_core_tiled")
@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
"""
Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
@@ -427,6 +456,7 @@ def __new__(
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
transposed: bool,
layout_type: LayoutType,
):
kwargs = {}
kwargs["device"] = packed_weight.device
@@ -443,31 +473,40 @@ def __init__(
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
transposed: bool,
layout_type: LayoutType,
):
self.packed_weight = packed_weight
self.scale_and_zero = scale_and_zero
self.transposed = transposed
self.layout_type = layout_type

def __tensor_flatten__(self):
return ["packed_weight", "scale_and_zero"], [self.transposed]
return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
transposed, = tensor_attributes
return cls(packed_weight, scale_and_zero, transposed)
transposed, layout_type, = tensor_attributes
return cls(packed_weight, scale_and_zero, transposed, layout_type)

@classmethod
def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType
):
assert isinstance(layout_type, TensorCoreTiledLayoutType)
# assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
# packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
return cls(packed_weight, scale_and_zero, False)
return cls(packed_weight, scale_and_zero, False, layout_type)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -477,18 +516,15 @@ def to(self, *args, **kwargs):
return self.__class__(
self.packed_weight.to(device),
self.scale_and_zero.to(device),
self.transposed
self.transposed,
self.layout_type,
)

def _apply_fn_to_data(self, fn):
self.packed_weight = fn(self.packed_weight)
self.scale_and_zero = fn(self.scale_and_zero)
return self

def __repr__(self):
int_data, scale, zero_point = self.get_plain()
return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
Expand All @@ -511,7 +547,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self):
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
quantize_affine,
@@ -542,6 +578,9 @@ def get_plain(self):
int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain)
return int_data, scale, zero

def get_layout_type(self) -> LayoutType:
return self.layout_type

def _quantized_linear_op(input_tensor, weight_qtensor, bias):
"""
Quantized version of F.linear operator
Expand All @@ -565,8 +604,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
is_cuda and
input_is_int8 and
input_tensor.dtype == weight_qtensor.dtype and
input_tensor.extended_layout == "plain" and
weight_qtensor.extended_layout == "plain"
isinstance(input_tensor.layout_type, PlainLayoutType) and
isinstance(weight_qtensor.layout_type, PlainLayoutType)
):
#
# 1. do the matrix form of dot(X_i, W_j)
@@ -608,7 +647,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
weight_qtensor.extended_layout == "tensor_core_tiled"
isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
):
assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -651,7 +690,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
weight_qtensor.extended_layout == "plain"
isinstance(weight_qtensor.layout_type, PlainLayoutType)
):
# TODO: enable cpu and mps efficient path
# per channel int8 weight-only quantized mm
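
As an aside (not part of the diff), a small self-contained sketch of the padding that `TensorCoreTiledLayoutType.pre_process` applies before int4 packing; `find_multiple` is re-implemented locally here under the assumption that it rounds its first argument up to the nearest multiple of its second.

import torch
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    # Assumed behaviour of torchao's find_multiple: round n up to a multiple of k.
    return n if n % k == 0 else n + k - (n % k)

# Illustrative weight shape (out_features, in_features); values are made up.
w = torch.randn(4000, 500)
in_features = find_multiple(w.shape[1], 1024)   # 500  -> 1024
out_features = find_multiple(w.shape[0], 8)     # 4000 -> 4000 (already a multiple of 8)
w_padded = F.pad(w, (0, in_features - w.shape[1], 0, out_features - w.shape[0]))
print(w_padded.shape)                           # torch.Size([4000, 1024])
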