From 0a13e6a343e187310e9469ec2c028f51f7ce88cb Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 3 Apr 2024 18:18:19 -0700 Subject: [PATCH 01/17] proof of concept for FSDP2 + NF4Tensor Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/nf4tensor.py | 380 +++++++++++++++++++++++++++++------- 1 file changed, 310 insertions(+), 70 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index ea45a6c0d..2c0f7bfe6 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F from torch import Tensor +import math aten = torch.ops.aten @@ -35,88 +36,323 @@ def decorator(func): return decorator - -@implements([torch.ops.aten.detach.default, torch.ops.aten.detach]) -def noop_detach(func, *args, **kwargs): - return args[0][0] - - -@implements([torch.ops.aten._to_copy.default]) -def _to_copy(func, *args, **kwargs): - if not args[0][0].is_contiguous(): - assert args[0][0].t().is_contiguous() - return func(args[0][0].t()).t() - return args[0][0].get_original_weight().to(args[1]["dtype"]) - - -@implements([torch.ops.aten.to.dtype]) -def to_dtype(func, *args, **kwargs): - if not args[0][0].is_contiguous(): - assert args[0][0].t().is_contiguous() - return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() - return args[0][0].get_original_weight().to(args[0][1]) - - -@implements([torch.ops.aten.t.default]) -def t_default(func, *args, **kwargs): - a = args[0][0] +@implements( + [ + aten.detach.default, + ] +) +def nf4_detach(aten_op, args, kwargs=None): + # nn.Parameter need detach + quantized_scalers = aten_op(args[0].quantized_scalers, *args[1:], **kwargs) + quantization_factor = aten_op(args[0].quantization_factor, *args[1:], **kwargs) + quantized_data = aten_op(args[0].quantized_data, *args[1:], **kwargs) + scaler_mean = aten_op(args[0].scaler_mean, *args[1:], **kwargs) + nf4 = aten_op(args[0].nf4, *args[1:], **kwargs) tensor_meta = SubclassTensorArgs( - a.size(), - (a.stride(1), a.stride(0)), - a.storage_offset(), - torch.bits2x4, - a.device, - a.requires_grad, + args[0].size(), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, ) - b = NF4Tensor( + return NF4Tensor( tensor_meta, - a.block_size, - a.n_blocks, - a.scaler_block_size, - a.quantized_scalers, - a.quantization_factor, - a.scaler_mean, - a.quantized_data, - a.nf4, + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + quantized_scalers, + quantization_factor, + scaler_mean, + quantized_data, + nf4, ) - return b +@implements( + [ + aten.split.Tensor, + ] +) +def nf4_split(aten_op, args, kwargs=None): + # torch.chunk + # TODO: find if there are other args/kwargs in aten.split + assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.split.Tensor with 2 args" + # TODO: assert on dim-0 sharding. how to get dim from torch.chunk? 
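+    # aten.split.Tensor receives (tensor, split_size, dim); here args[1] is the
+    # dim-0 split size, so the chunk count follows from the dim-0 length, and each
+    # flattened inner tensor is split into the same number of equal pieces below.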
+ num_chunks = args[0].size(0) // args[1] + + # TODO: assert numel % num_chunks == 0 + quantized_scalers_chunks = aten_op(args[0].quantized_scalers, args[0].quantized_scalers.numel() // num_chunks, **kwargs) + quantization_factor_chunks = aten_op(args[0].quantization_factor, args[0].quantization_factor.numel() // num_chunks, **kwargs) + quantized_data_chunks = aten_op(args[0].quantized_data, args[0].quantized_data.numel() // num_chunks, **kwargs) + + + assert len(args) == 2, "only support 2d because of tensor meta" + return [ + NF4Tensor( + SubclassTensorArgs( + (args[0].size(0) // num_chunks, args[0].size(1)), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + quantized_scalers, + quantization_factor, + args[0].scaler_mean, + quantized_data, + args[0].nf4, + ) for quantized_scalers, quantization_factor, quantized_data in zip( + quantized_scalers_chunks, quantization_factor_chunks, quantized_data_chunks + ) + ] -@implements([torch.ops.aten.mm.default]) -def mm_default(func, *args, **kwargs): - return linear_nf4(args[0][0], args[0][1]) +@implements( + [ + aten.new_zeros.default, + ] +) +def nf4_new_zeros(aten_op, args, kwargs=None): + assert len(args[0].shape) == 2 and len(args[1]) == 2, "only support new zeros on 2D" + assert args[0].numel() % math.prod(args[1]) == 0 + ratio = args[0].numel() // math.prod(args[1]) + + assert args[0].quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}" + quantized_scalers_new_zeros = aten_op(args[0].quantized_scalers, [args[0].quantized_scalers.size(0) // ratio], **kwargs) + + assert args[0].quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}" + quantization_factor_new_zeros = aten_op(args[0].quantization_factor, [args[0].quantization_factor.size(0) // ratio], **kwargs) + + assert args[0].quantized_data.size(0) % ratio == 0, f"quantized_data.size(0) must be divisible by {ratio}" + quantized_data_new_zeros = aten_op(args[0].quantized_data, [args[0].quantized_data.size(0) // ratio], **kwargs) + + + return NF4Tensor( + SubclassTensorArgs( + (args[1][0], args[1][1]), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + quantized_scalers_new_zeros, + quantization_factor_new_zeros, + args[0].scaler_mean, + quantized_data_new_zeros, + args[0].nf4, + ) +@implements( + [ + aten.slice.Tensor, + ] +) +def nf4_slice(aten_op, args, kwargs=None): + assert len(args) == 4 + assert args[1] == 0, f"only support dim=0 but got dim={args[1]}" + # TODO: maybe relax? 
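+    # Only the identity slice is accepted for now: dim 0, start 0, and an end equal
+    # to size(0), so the asserts below reject any slice that would actually narrow
+    # the tensor.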
+ assert args[2] == 0, f"only support start=0 but got start={args[2]}" + assert args[3] == args[0].size(0), f"only support end == size(0) but got end={args[3]} and size(0)={args[0].size(0)}" + return NF4Tensor( + SubclassTensorArgs( + args[0].size(), + args[0].stride(), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + args[0].quantized_data, + args[0].nf4, + ) @implements( [ aten.copy_.default, ] ) -def copy_(func, *args, **kwargs): - original: NF4Tensor = args[0][0] - copy_in: torch.Tensor = args[0][1] - - # Base Case - - if same_metadata(original, copy_in): - original_tensors = original.__tensor_flatten__()[0] - for tensor_name in original_tensors: - getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) - return - - # Convert Non NF4Tensor into NF4 for copy in - if not isinstance(copy_in, NF4Tensor): - copy_in_nf4 = NF4Tensor.from_tensor( - copy_in, original.block_size, original.scaler_block_size - ) - return original.copy_(copy_in_nf4) +def nf4_copy_(aten_op, args, kwargs=None): + assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.copy_.default with 2 args" + quantized_scalers = aten_op(args[0].quantized_scalers, args[1].quantized_scalers, **kwargs) + quantization_factor = aten_op(args[0].quantization_factor, args[1].quantization_factor, **kwargs) + quantized_data = aten_op(args[0].quantized_data, args[1].quantized_data, **kwargs) + scaler_mean = aten_op(args[0].scaler_mean, args[1].scaler_mean, **kwargs) + nf4 = aten_op(args[0].nf4, args[1].nf4, **kwargs) + tensor_meta = SubclassTensorArgs( + args[1].size(), + args[1].stride(), + args[1].storage_offset(), + args[1].dtype, + args[1].device, + args[1].requires_grad, + ) + return NF4Tensor( + tensor_meta, + args[1].block_size, + args[1].n_blocks, + args[1].scaler_block_size, + quantized_scalers, + quantization_factor, + scaler_mean, + quantized_data, + nf4, + ) + +@implements( + [ + aten.view.default, + ] +) +def nf4_view(aten_op, args, kwargs=None): + assert len(args) == 2, args[1] == -1 + quantized_scalers = aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs) + quantization_factor = aten_op(args[0].quantization_factor, *(args[1:]), **kwargs) + quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs) + tensor_meta = SubclassTensorArgs( + [args[0].numel()], + (1, ), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ) + return NF4Tensor( + tensor_meta, + args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + quantized_scalers, + quantization_factor, + args[0].scaler_mean, + quantized_data, + args[0].nf4, + ) - # Other Tensor is not a NF4Tensor - full_precision = copy_in.get_original_weight() - same_meta_nf4 = NF4Tensor.from_tensor( - full_precision, original.block_size, original.scaler_block_size +@implements( + [ + aten.as_strided.default, + ] +) +def nf4_as_strided(aten_op, args, kwargs=None): + assert len(args[1]) == 2 and math.prod(args[1]) == args[0].numel(), "only support same numel" + assert args[2] == [args[1][1], 1], f"only support stride {[args[1][1], 1]}" + assert args[0].storage_offset() == args[3], f"only support same storage offset" + return NF4Tensor( + SubclassTensorArgs( + torch.Size(args[1]), + tuple(args[2]), + args[0].storage_offset(), + args[0].dtype, + args[0].device, + args[0].requires_grad, + ), + 
args[0].block_size, + args[0].n_blocks, + args[0].scaler_block_size, + args[0].quantized_scalers, + args[0].quantization_factor, + args[0].scaler_mean, + args[0].quantized_data, + args[0].nf4, ) - return original.copy_(same_meta_nf4) + +# @implements([torch.ops.aten.detach]) +# def noop_detach(func, *args, **kwargs): +# assert False +# return args[0][0] + + +# @implements([torch.ops.aten._to_copy.default]) +# def _to_copy(func, *args, **kwargs): +# if not args[0][0].is_contiguous(): +# assert args[0][0].t().is_contiguous() +# return func(args[0][0].t()).t() +# return args[0][0].get_original_weight().to(args[1]["dtype"]) + + +# @implements([torch.ops.aten.to.dtype]) +# def to_dtype(func, *args, **kwargs): +# if not args[0][0].is_contiguous(): +# assert args[0][0].t().is_contiguous() +# return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() +# return args[0][0].get_original_weight().to(args[0][1]) + + +# @implements([torch.ops.aten.t.default]) +# def t_default(func, *args, **kwargs): +# a = args[0][0] +# tensor_meta = SubclassTensorArgs( +# a.size(), +# (a.stride(1), a.stride(0)), +# a.storage_offset(), +# torch.bits2x4, +# # a.dtype, +# a.device, +# a.requires_grad, +# ) +# b = NF4Tensor( +# tensor_meta, +# a.block_size, +# a.n_blocks, +# a.scaler_block_size, +# a.quantized_scalers, +# a.quantization_factor, +# a.scaler_mean, +# a.quantized_data, +# a.nf4, +# ) +# return b + + +# @implements([torch.ops.aten.mm.default]) +# def mm_default(func, *args, **kwargs): +# return linear_nf4(args[0][0], args[0][1]) + + +# TODO: merge with above +# @implements( +# [ +# aten.copy_.default, +# ] +# ) +# def copy_(func, *args, **kwargs): +# original: NF4Tensor = args[0][0] +# copy_in: torch.Tensor = args[0][1] + +# # Base Case + +# if same_metadata(original, copy_in): +# original_tensors = original.__tensor_flatten__()[0] +# for tensor_name in original_tensors: +# getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) +# return + +# # Convert Non NF4Tensor into NF4 for copy in +# if not isinstance(copy_in, NF4Tensor): +# copy_in_nf4 = NF4Tensor.from_tensor( +# copy_in, original.block_size, original.scaler_block_size +# ) +# return original.copy_(copy_in_nf4) + +# # Other Tensor is not a NF4Tensor +# full_precision = copy_in.get_original_weight() +# same_meta_nf4 = NF4Tensor.from_tensor( +# full_precision, original.block_size, original.scaler_block_size +# ) +# return original.copy_(same_meta_nf4) @dataclass @@ -187,7 +423,8 @@ def __new__( tensor_meta.original_strides, tensor_meta.storage_offset, # Picked some floating dtype, but we need dtype extensibility - dtype=torch.float8_e5m2fnuz, + dtype=tensor_meta.dtype, + # dtype=torch.float8_e5m2fnuz, device=tensor_meta.device, requires_grad=tensor_meta.requires_grad, ) @@ -562,7 +799,9 @@ class LinearNF4(torch.autograd.Function): def forward(ctx, input: torch.Tensor, weight: NF4Tensor): """Save the quantized nf4 weight for backward pass""" ctx.nf4_weight = weight - return F.linear(input, weight.to(input.dtype)) + assert input.dtype == torch.bfloat16 and input.dtype == weight.dtype + return F.linear(input, weight.get_original_weight()) + # return F.linear(input, weight.to(input.dtype)) @staticmethod @@ -571,7 +810,8 @@ def forward(ctx, input: torch.Tensor, weight: NF4Tensor): def backward(ctx, grad_output): """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)""" weight: NF4Tensor = ctx.nf4_weight - return grad_output @ weight.to(grad_output.dtype), None + # return grad_output @ 
weight.to(grad_output.dtype), None + return grad_output @ weight.get_original_weight(), None def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: From 818054015f1648a74d1ba8cff6b743321350682b Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 11 Apr 2024 11:53:52 -0700 Subject: [PATCH 02/17] fsdp extention for tensor subclass Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/nf4tensor.py | 61 ++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 2c0f7bfe6..7745483c9 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -11,7 +11,7 @@ c10d_functional = torch.ops.c10d_functional -from typing import Any +from typing import Any, Optional, Tuple, Union, List NF4_OPS_TABLE: Dict[Any, Any] = {} @@ -790,6 +790,65 @@ def allowed_subclasses(type): __torch_function__ = torch._C._disabled_torch_function_impl + def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]: + return ( + self.quantized_scalers, + self.quantization_factor, + self.quantized_data, + ), ( + SubclassTensorArgs( + self.size(), + self.stride(), + self.storage_offset(), + self.dtype, + self.device, + self.requires_grad, + ), + self.block_size, + self.n_blocks, + self.scaler_block_size, + self.scaler_mean, + self.nf4, + ) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]: + (quantized_scalers, quantization_factor, quantized_data) = all_gather_outputs + (tensor_meta, block_size, n_blocks, scaler_block_size, scaler_mean, nf4) = metadata + world_size = quantized_data.numel() * 2 // math.prod(tensor_meta.original_shape) + len(tensor_meta.original_shape) == 2, "only support 2D shape" + tensor_meta.original_shape = torch.Size((tensor_meta.original_shape[0] * world_size, tensor_meta.original_shape[1])) + if out is not None: + # TODO: add param dtype + assert isinstance(out, NF4Tensor), f"{type(out)}" + assert ( + quantized_scalers.untyped_storage().data_ptr() + == out.quantized_scalers.untyped_storage().data_ptr() and + quantization_factor.untyped_storage().data_ptr() + == out.quantization_factor.untyped_storage().data_ptr() and + quantized_data.untyped_storage().data_ptr() + == out.quantized_data.untyped_storage().data_ptr() + ), f"Expects out's data to be the all-gather output" + return + + return NF4Tensor( + tensor_meta, + block_size, + n_blocks, + scaler_block_size, + quantized_scalers, + quantization_factor, + scaler_mean, + quantized_data, + nf4, + ), (quantized_scalers, quantization_factor, quantized_data) + class LinearNF4(torch.autograd.Function): @staticmethod From 95b03e19e8a2118c70bb279c56967c37440aeca8 Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 15 Apr 2024 16:30:11 -0700 Subject: [PATCH 03/17] support fp32 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 49 +++++++++++++++++++++---------------- torchao/dtypes/nf4tensor.py | 16 ++++++------ 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index c71fcd25b..16ede73df 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -9,6 +9,7 @@ import io from collections import OrderedDict import torchao +import pytest bnb_available = False @@ -24,10 +25,10 @@ ) -def 
_build_input_weight(embed_dim: int, device: torch.device): +def _build_input_weight(embed_dim: int, device: torch.device, dtype: torch.dtype): torch.manual_seed(0) input_weight = torch.empty( - embed_dim, embed_dim, device=device, dtype=torch.bfloat16 + embed_dim, embed_dim, device=device, dtype=dtype ) input_weight.normal_(0, 1) return input_weight @@ -45,7 +46,7 @@ def _build_bnb_linear(input_weight, device): return bnb_linear -class TestNF4Linear(TestCase): +class TestNF4Linear(): class TestMod(nn.Module): def __init__(self, tensor, block_size, scaler_block_size): super().__init__() @@ -57,42 +58,46 @@ def save_state_dict_to_buffer(self, state_dict: OrderedDict): buffer.seek(0) return buffer - def test_register_nf4_as_param(self): - nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16)) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_register_nf4_as_param(self, dtype: torch.dtype): + nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype)) # Would raise if nn.Parameter registration fails, such as no detach() # impl when calling __torch_dispatch__ param = torch.nn.Parameter(nf4_tensor, requires_grad=False) assert not param.requires_grad - def test_output_bf16(self): + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_output_dtype_match(self, dtype:torch.dtype): # Test to ensure W4 A16 produces A16 - inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True) - nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16)) + inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) + nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype)) out = linear_nf4(input=inp, weight=nf4_tensor) - assert out.dtype == torch.bfloat16 + assert out.dtype == dtype - def test_backward_bf16(self): + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_backward_dtype_match(self, dtype:torch.dtype): # Test to ensure backward pass gives activation a bf16 gradient and no gradient # to the linear's weight, as it is frozen. 
- nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16)) - inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True) + nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype)) + inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) linear_nf4(inp, nf4_tensor).sum().backward() - assert inp.grad is not None and inp.grad.dtype == torch.bfloat16 + assert inp.grad is not None and inp.grad.dtype == dtype assert nf4_tensor.grad is None @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_reconstruction_qlora_vs_bnb(self): + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype): # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47 torch.manual_seed(0) device = "cuda" embed_dim = 512 - input_weight = _build_input_weight(embed_dim, device) + input_weight = _build_input_weight(embed_dim, device, dtype) nf4_weight = to_nf4(input_weight) bnb_linear = _build_bnb_linear(input_weight, device) bnb_reconstruction = bnb_linear( - torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device) + torch.eye(embed_dim, embed_dim, dtype=dtype, device=device) ) bnb_diff = (bnb_reconstruction.T - input_weight).abs().max() nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max() @@ -104,7 +109,8 @@ def test_reconstruction_qlora_vs_bnb(self): @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_nf4_bnb_linear(self): + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_nf4_bnb_linear(self, dtype: torch.dtype): """ This test ensures that nf4_linear is "no worse" than BNB by ensuring the error compared to a bf16 linear is not more than BNB's implementation. 
@@ -112,11 +118,11 @@ def test_nf4_bnb_linear(self): torch.manual_seed(0) dim = 512 device = "cuda" - input_weight = _build_input_weight(dim, device) + input_weight = _build_input_weight(dim, device, dtype) nf4_weight = to_nf4(input_weight) bnb_linear = _build_bnb_linear(input_weight, device) - inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda") + inp = torch.randn(2, 512, dtype=dtype, device="cuda") out_nf4 = linear_nf4(inp, nf4_weight).sum() out_bnb = bnb_linear(inp).sum() @@ -185,8 +191,9 @@ def test_to_bfloat16(self): assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_smoketest_linear(self): - a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda') + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_smoketest_linear(self, dtype: torch.dtype): + a = torch.randn(32, 32, dtype=dtype, device='cuda') a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device) out1 = torch.nn.functional.linear(inp, a) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 7745483c9..4db41dca0 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -461,11 +461,11 @@ def from_tensor( scaler_block_size: int, ): assert inpt_tensor.dim() <= 2 - assert inpt_tensor.dtype == torch.bfloat16 + # assert inpt_tensor.dtype == torch.bfloat16 assert ( inpt_tensor.numel() % block_size == 0 ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}" - assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16" + # assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16" assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!" 
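         # NOTE: is_contiguous is referenced without parentheses here, so the assert
         # checks the truthiness of the bound method and always passes.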
# I think I want do this # assert not inpt_tensor.requires_grad, "Input tensor must not require grad" @@ -491,7 +491,7 @@ def from_tensor( 1.0000, ], device=device, - dtype=torch.bfloat16, + dtype=inpt_tensor.dtype, ) n_blocks = inpt_tensor.numel() // block_size # Double quantization @@ -607,7 +607,7 @@ def dequantize_scalers( n_scaler_blocks = inpt_tensor.numel() // scaler_block_size inpt_tensor = inpt_tensor.view(n_scaler_blocks, scaler_block_size) dequantized = (inpt_tensor / quantization_factor.unsqueeze(-1)).flatten().to( - torch.bfloat16 + self.dtype ) + self.scaler_mean return dequantized @@ -858,7 +858,8 @@ class LinearNF4(torch.autograd.Function): def forward(ctx, input: torch.Tensor, weight: NF4Tensor): """Save the quantized nf4 weight for backward pass""" ctx.nf4_weight = weight - assert input.dtype == torch.bfloat16 and input.dtype == weight.dtype + # assert input.dtype == torch.bfloat16 and input.dtype == weight.dtype + assert input.dtype == weight.dtype return F.linear(input, weight.get_original_weight()) # return F.linear(input, weight.to(input.dtype)) @@ -884,5 +885,6 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): - tensor1 = tensor.to(torch.bfloat16) - return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size) + # tensor1 = tensor.to(torch.bfloat16) + # return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size) + return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) From 38461b31eba7bb8d6a97b385b7c59458437d5527 Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 16 Apr 2024 14:45:26 -0700 Subject: [PATCH 04/17] UNIT TEST FOR STATE DICT Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 16ede73df..8e2918ece 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -134,13 +134,15 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype): assert err_bnb < 0.5 * dim @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test") - def test_load_from_bfloat16(self): + # @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + def test_load_from_state_dicts(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" - inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16) + inpt_tensor = torch.rand(64, device='cuda', dtype=dtype) base_mod = self.TestMod(inpt_tensor, 32, 2) - bf16_dummy_dict = {"param": inpt_tensor} - base_mod.load_state_dict(bf16_dummy_dict) + dummy_dict = {"param": inpt_tensor} + base_mod.load_state_dict(dummy_dict) assert base_mod.param.block_size == 32 assert base_mod.param.scaler_block_size == 2 From bc7a7649b34312b5619de433fe66841feee6c3b0 Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 16 Apr 2024 20:39:55 -0700 Subject: [PATCH 05/17] implement to Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 13 +- torchao/dtypes/nf4tensor.py | 243 +++++++++++++++++++++--------------- 2 files changed, 147 insertions(+), 109 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 8e2918ece..723c5167e 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -134,8 +134,7 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype): assert err_bnb < 0.5 * dim 
@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test") - # @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) - @pytest.mark.parametrize("dtype", [torch.bfloat16]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_load_from_state_dicts(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" inpt_tensor = torch.rand(64, device='cuda', dtype=dtype) @@ -148,9 +147,10 @@ def test_load_from_state_dicts(self, dtype: torch.dtype): assert base_mod.param.scaler_block_size == 2 @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test") - def test_load_from_nf4_same_meta(self): + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_load_from_nf4_same_meta(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" - inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16) + inpt_tensor = torch.rand(64, device='cuda', dtype=dtype) base_mod = self.TestMod(inpt_tensor, 32, 2) state_dict = base_mod.state_dict() saved_state_dict = self.save_state_dict_to_buffer(state_dict) @@ -161,9 +161,10 @@ def test_load_from_nf4_same_meta(self): assert other_mod.param.scaler_block_size == 2 @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test") - def test_load_from_nf4_diff_meta(self): + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_load_from_nf4_diff_meta(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" - inpt_tensor = torch.rand(128, device='cuda', dtype=torch.bfloat16) + inpt_tensor = torch.rand(128, device='cuda', dtype=dtype) base_mod = self.TestMod(inpt_tensor, 32, 2) state_dict = base_mod.state_dict() saved_state_dict = self.save_state_dict_to_buffer(state_dict) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 4db41dca0..d36867c75 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -1,5 +1,7 @@ from dataclasses import dataclass from typing import Dict, Tuple +import functools +from torch.overrides import get_default_nowrap_functions import torch import torch.nn.functional as F @@ -14,6 +16,29 @@ from typing import Any, Optional, Tuple, Union, List NF4_OPS_TABLE: Dict[Any, Any] = {} +NF4_TORCH_FUNCTIONS = {} + +def implements_torch_function(torch_function): + def decorator(func): + functools.update_wrapper(func, torch_function) + NF4_TORCH_FUNCTIONS[torch_function] = func + return func + return decorator + +@implements_torch_function(torch.Tensor.to) +def function_to_dtype(*args, **kwargs): + if not args[0].is_contiguous(): + breakpoint() + assert args[0].t().is_contiguous() + return torch.ops.aten.to.dtype(args[0].t(), args[1]).t() + return args[0].get_original_weight().to(args[1]) + +# @implements_torch_function([torch.Tensor.to]) +# def to_dtype(func, *args, **kwargs): +# if not args[0][0].is_contiguous(): +# assert args[0][0].t().is_contiguous() +# return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() +# return args[0][0].get_original_weight().to(args[0][1]) def same_metadata(a: "NF4Tensor", b: "NF4Tensor"): @@ -186,30 +211,49 @@ def nf4_slice(aten_op, args, kwargs=None): ) def nf4_copy_(aten_op, args, kwargs=None): assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.copy_.default with 2 args" - quantized_scalers = aten_op(args[0].quantized_scalers, args[1].quantized_scalers, **kwargs) - quantization_factor = 
aten_op(args[0].quantization_factor, args[1].quantization_factor, **kwargs) - quantized_data = aten_op(args[0].quantized_data, args[1].quantized_data, **kwargs) - scaler_mean = aten_op(args[0].scaler_mean, args[1].scaler_mean, **kwargs) - nf4 = aten_op(args[0].nf4, args[1].nf4, **kwargs) - tensor_meta = SubclassTensorArgs( - args[1].size(), - args[1].stride(), - args[1].storage_offset(), - args[1].dtype, - args[1].device, - args[1].requires_grad, - ) - return NF4Tensor( - tensor_meta, - args[1].block_size, - args[1].n_blocks, - args[1].scaler_block_size, - quantized_scalers, - quantization_factor, - scaler_mean, - quantized_data, - nf4, + # TODO: use original and copy_in in same_meta + original: NF4Tensor = args[0] + copy_in: torch.Tensor = args[1] + + if same_metadata(original, copy_in): + quantized_scalers = aten_op(args[0].quantized_scalers, args[1].quantized_scalers, **kwargs) + quantization_factor = aten_op(args[0].quantization_factor, args[1].quantization_factor, **kwargs) + quantized_data = aten_op(args[0].quantized_data, args[1].quantized_data, **kwargs) + scaler_mean = aten_op(args[0].scaler_mean, args[1].scaler_mean, **kwargs) + nf4 = aten_op(args[0].nf4, args[1].nf4, **kwargs) + tensor_meta = SubclassTensorArgs( + args[1].size(), + args[1].stride(), + args[1].storage_offset(), + args[1].dtype, + args[1].device, + args[1].requires_grad, + ) + return NF4Tensor( + tensor_meta, + args[1].block_size, + args[1].n_blocks, + args[1].scaler_block_size, + quantized_scalers, + quantization_factor, + scaler_mean, + quantized_data, + nf4, + ) + + # Convert Non NF4Tensor into NF4 for copy in + if not isinstance(copy_in, NF4Tensor): + copy_in_nf4 = NF4Tensor.from_tensor( + copy_in, original.block_size, original.scaler_block_size + ) + return original.copy_(copy_in_nf4) + + # Other Tensor is not a NF4Tensor + full_precision = copy_in.get_original_weight() + same_meta_nf4 = NF4Tensor.from_tensor( + full_precision, original.block_size, original.scaler_block_size ) + return original.copy_(same_meta_nf4) @implements( [ @@ -269,90 +313,59 @@ def nf4_as_strided(aten_op, args, kwargs=None): args[0].nf4, ) -# @implements([torch.ops.aten.detach]) -# def noop_detach(func, *args, **kwargs): -# assert False -# return args[0][0] +@implements([torch.ops.aten.detach]) +def noop_detach(func, *args, **kwargs): + assert False + return args[0][0] -# @implements([torch.ops.aten._to_copy.default]) -# def _to_copy(func, *args, **kwargs): -# if not args[0][0].is_contiguous(): -# assert args[0][0].t().is_contiguous() -# return func(args[0][0].t()).t() -# return args[0][0].get_original_weight().to(args[1]["dtype"]) +@implements([torch.ops.aten._to_copy.default]) +def _to_copy(func, *args, **kwargs): + breakpoint() + if not args[0][0].is_contiguous(): + assert args[0][0].t().is_contiguous() + return func(args[0][0].t()).t() + return args[0][0].get_original_weight().to(args[1]["dtype"]) -# @implements([torch.ops.aten.to.dtype]) -# def to_dtype(func, *args, **kwargs): -# if not args[0][0].is_contiguous(): -# assert args[0][0].t().is_contiguous() -# return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() -# return args[0][0].get_original_weight().to(args[0][1]) +@implements([torch.ops.aten.to.dtype]) +def to_dtype(func, *args, **kwargs): + breakpoint() + if not args[0][0].is_contiguous(): + assert args[0][0].t().is_contiguous() + return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() + return args[0][0].get_original_weight().to(args[0][1]) -# @implements([torch.ops.aten.t.default]) -# def t_default(func, 
*args, **kwargs): -# a = args[0][0] -# tensor_meta = SubclassTensorArgs( -# a.size(), -# (a.stride(1), a.stride(0)), -# a.storage_offset(), -# torch.bits2x4, -# # a.dtype, -# a.device, -# a.requires_grad, -# ) -# b = NF4Tensor( -# tensor_meta, -# a.block_size, -# a.n_blocks, -# a.scaler_block_size, -# a.quantized_scalers, -# a.quantization_factor, -# a.scaler_mean, -# a.quantized_data, -# a.nf4, -# ) -# return b - - -# @implements([torch.ops.aten.mm.default]) -# def mm_default(func, *args, **kwargs): -# return linear_nf4(args[0][0], args[0][1]) - - -# TODO: merge with above -# @implements( -# [ -# aten.copy_.default, -# ] -# ) -# def copy_(func, *args, **kwargs): -# original: NF4Tensor = args[0][0] -# copy_in: torch.Tensor = args[0][1] - -# # Base Case - -# if same_metadata(original, copy_in): -# original_tensors = original.__tensor_flatten__()[0] -# for tensor_name in original_tensors: -# getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) -# return - -# # Convert Non NF4Tensor into NF4 for copy in -# if not isinstance(copy_in, NF4Tensor): -# copy_in_nf4 = NF4Tensor.from_tensor( -# copy_in, original.block_size, original.scaler_block_size -# ) -# return original.copy_(copy_in_nf4) - -# # Other Tensor is not a NF4Tensor -# full_precision = copy_in.get_original_weight() -# same_meta_nf4 = NF4Tensor.from_tensor( -# full_precision, original.block_size, original.scaler_block_size -# ) -# return original.copy_(same_meta_nf4) +@implements([torch.ops.aten.t.default]) +def t_default(func, *args, **kwargs): + a = args[0][0] + tensor_meta = SubclassTensorArgs( + a.size(), + (a.stride(1), a.stride(0)), + a.storage_offset(), + # torch.bits2x4, + a.dtype, + a.device, + a.requires_grad, + ) + b = NF4Tensor( + tensor_meta, + a.block_size, + a.n_blocks, + a.scaler_block_size, + a.quantized_scalers, + a.quantization_factor, + a.scaler_mean, + a.quantized_data, + a.nf4, + ) + return b + + +@implements([torch.ops.aten.mm.default]) +def mm_default(func, *args, **kwargs): + return linear_nf4(args[0][0], args[0][1]) @dataclass @@ -759,6 +772,32 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride def __str__(self): return self.to(torch.float32).__str__() + # Do not force the Float8Tensor type on the returned tensor + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func in NF4_TORCH_FUNCTIONS: + return NF4_TORCH_FUNCTIONS[func](*args, **kwargs) + + # if not all( + # issubclass(t, (torch.Tensor, NF4Tensor)) + # for t in types + # ): + # return NotImplemented + if not all(issubclass(cls, t) for t in types): + return NotImplemented + + with torch._C.DisableTorchFunctionSubclass(): + ret = func(*args, **kwargs) + if func in get_default_nowrap_functions(): + return ret + else: + return torch._tensor._convert(ret, cls) + + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): """TODO we are not supporting torch dispatch at the moment @@ -786,9 +825,7 @@ def allowed_subclasses(type): f"NF4Tensor dispatch: attempting to run {func}, this is not supported" ) - # Do not force the Float8Tensor type on the returned tensor - __torch_function__ = torch._C._disabled_torch_function_impl def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]: return ( From 8b1d037f416192afd6e2e6029e2cda744f5a70e9 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 11:38:00 -0700 Subject: [PATCH 06/17] remove torch.override from torch function Summary: Test Plan: Reviewers: 
Subscribers: Tasks: Tags: --- torchao/dtypes/nf4tensor.py | 40 ++++++++++++------------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index d36867c75..df69bfe09 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Dict, Tuple import functools -from torch.overrides import get_default_nowrap_functions import torch import torch.nn.functional as F @@ -772,32 +771,6 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride def __str__(self): return self.to(torch.float32).__str__() - # Do not force the Float8Tensor type on the returned tensor - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - if func in NF4_TORCH_FUNCTIONS: - return NF4_TORCH_FUNCTIONS[func](*args, **kwargs) - - # if not all( - # issubclass(t, (torch.Tensor, NF4Tensor)) - # for t in types - # ): - # return NotImplemented - if not all(issubclass(cls, t) for t in types): - return NotImplemented - - with torch._C.DisableTorchFunctionSubclass(): - ret = func(*args, **kwargs) - if func in get_default_nowrap_functions(): - return ret - else: - return torch._tensor._convert(ret, cls) - - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): """TODO we are not supporting torch dispatch at the moment @@ -826,6 +799,19 @@ def allowed_subclasses(type): ) + # Do not force the Float8Tensor type on the returned tensor + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func in NF4_TORCH_FUNCTIONS: + return NF4_TORCH_FUNCTIONS[func](*args, **kwargs) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]: return ( From 7ff68550047cc221f32012fca8cc0ca8ba146315 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 12:19:36 -0700 Subject: [PATCH 07/17] use dtype in compile unit test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 16 ++++++++-------- torchao/dtypes/nf4tensor.py | 23 ++--------------------- 2 files changed, 10 insertions(+), 29 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 723c5167e..d82cbde61 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -203,14 +203,14 @@ def test_smoketest_linear(self, dtype: torch.dtype): out2 = torch.nn.functional.linear(inp, a_nf4) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_smoketest_linear_compile(self): - for dtype in [torch.bfloat16, torch.float16]: - if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16: - self.skipTest("test requires SM capability of at least (8, 0).") - a = torch.randn(32, 32, dtype=dtype, device='cuda') - a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) - inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device) - out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_smoketest_linear_compile(self, dtype: torch.dtype): + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16: + self.skipTest("test requires SM capability of at least (8, 0).") + a = torch.randn(32, 32, 
dtype=dtype, device='cuda') + a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) + inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device) + out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index df69bfe09..746851bc6 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -26,19 +26,8 @@ def decorator(func): @implements_torch_function(torch.Tensor.to) def function_to_dtype(*args, **kwargs): - if not args[0].is_contiguous(): - breakpoint() - assert args[0].t().is_contiguous() - return torch.ops.aten.to.dtype(args[0].t(), args[1]).t() return args[0].get_original_weight().to(args[1]) -# @implements_torch_function([torch.Tensor.to]) -# def to_dtype(func, *args, **kwargs): -# if not args[0][0].is_contiguous(): -# assert args[0][0].t().is_contiguous() -# return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() -# return args[0][0].get_original_weight().to(args[0][1]) - def same_metadata(a: "NF4Tensor", b: "NF4Tensor"): both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor) @@ -320,7 +309,6 @@ def noop_detach(func, *args, **kwargs): @implements([torch.ops.aten._to_copy.default]) def _to_copy(func, *args, **kwargs): - breakpoint() if not args[0][0].is_contiguous(): assert args[0][0].t().is_contiguous() return func(args[0][0].t()).t() @@ -329,7 +317,6 @@ def _to_copy(func, *args, **kwargs): @implements([torch.ops.aten.to.dtype]) def to_dtype(func, *args, **kwargs): - breakpoint() if not args[0][0].is_contiguous(): assert args[0][0].t().is_contiguous() return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t() @@ -881,10 +868,7 @@ class LinearNF4(torch.autograd.Function): def forward(ctx, input: torch.Tensor, weight: NF4Tensor): """Save the quantized nf4 weight for backward pass""" ctx.nf4_weight = weight - # assert input.dtype == torch.bfloat16 and input.dtype == weight.dtype - assert input.dtype == weight.dtype - return F.linear(input, weight.get_original_weight()) - # return F.linear(input, weight.to(input.dtype)) + return F.linear(input, weight.to(input.dtype)) @staticmethod @@ -893,8 +877,7 @@ def forward(ctx, input: torch.Tensor, weight: NF4Tensor): def backward(ctx, grad_output): """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)""" weight: NF4Tensor = ctx.nf4_weight - # return grad_output @ weight.to(grad_output.dtype), None - return grad_output @ weight.get_original_weight(), None + return grad_output @ weight.to(grad_output.dtype), None def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: @@ -908,6 +891,4 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): - # tensor1 = tensor.to(torch.bfloat16) - # return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size) return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) From d9bcf7101aca62d3f69759a9cfd87713dc022ccf Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 13:51:16 -0700 Subject: [PATCH 08/17] add dtype in all unit test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index d82cbde61..2637a655e 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -45,13 +45,12 @@ def 
_build_bnb_linear(input_weight, device): bnb_linear.to(device) return bnb_linear +class TestMod(nn.Module): + def __init__(self, tensor, block_size, scaler_block_size): + super().__init__() + self.param = torch.nn.Parameter(to_nf4(tensor, block_size, scaler_block_size)) class TestNF4Linear(): - class TestMod(nn.Module): - def __init__(self, tensor, block_size, scaler_block_size): - super().__init__() - self.param = torch.nn.Parameter(to_nf4(tensor, block_size, scaler_block_size)) - def save_state_dict_to_buffer(self, state_dict: OrderedDict): buffer = io.BytesIO() torch.save(state_dict, buffer) @@ -138,7 +137,7 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype): def test_load_from_state_dicts(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" inpt_tensor = torch.rand(64, device='cuda', dtype=dtype) - base_mod = self.TestMod(inpt_tensor, 32, 2) + base_mod = TestMod(inpt_tensor, 32, 2) dummy_dict = {"param": inpt_tensor} base_mod.load_state_dict(dummy_dict) @@ -151,11 +150,11 @@ def test_load_from_state_dicts(self, dtype: torch.dtype): def test_load_from_nf4_same_meta(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" inpt_tensor = torch.rand(64, device='cuda', dtype=dtype) - base_mod = self.TestMod(inpt_tensor, 32, 2) + base_mod = TestMod(inpt_tensor, 32, 2) state_dict = base_mod.state_dict() saved_state_dict = self.save_state_dict_to_buffer(state_dict) - other_mod = self.TestMod(inpt_tensor, 32, 2) + other_mod = TestMod(inpt_tensor, 32, 2) other_mod.load_state_dict(torch.load(saved_state_dict)) assert other_mod.param.block_size == 32 assert other_mod.param.scaler_block_size == 2 @@ -165,33 +164,35 @@ def test_load_from_nf4_same_meta(self, dtype: torch.dtype): def test_load_from_nf4_diff_meta(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" inpt_tensor = torch.rand(128, device='cuda', dtype=dtype) - base_mod = self.TestMod(inpt_tensor, 32, 2) + base_mod = TestMod(inpt_tensor, 32, 2) state_dict = base_mod.state_dict() saved_state_dict = self.save_state_dict_to_buffer(state_dict) - other_mod = self.TestMod(inpt_tensor, 64, 1) + other_mod = TestMod(inpt_tensor, 64, 1) other_mod.load_state_dict(torch.load(saved_state_dict)) assert other_mod.param.block_size == 64 assert other_mod.param.scaler_block_size == 1 - def test_to_copy(self): + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_to_copy(self, dtype: torch.dtype): inpt_tensor = torch.rand(128, device='cpu') inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2) - inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16) - torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13) + nf4_to_dtype = inpt_tensor_nf4.to(dtype) + torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13) if torch.cuda.is_available(): inpt_tensor = torch.rand(128, device='cuda') inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2) - inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16) - torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13) + nf4_to_dtype = inpt_tensor_nf4.to(dtype) + torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13) - def test_to_bfloat16(self): - inpt_tensor = torch.rand(128, dtype=torch.bfloat16) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_to_dtype(self, dtype: torch.dtype): + inpt_tensor = torch.rand(128, dtype=dtype) inpt_tensor_nf4 = 
to_nf4(inpt_tensor, 32, 2) assert type(inpt_tensor_nf4) != torch.Tensor - assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor - assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16 + assert type(inpt_tensor_nf4.to(dtype)) == torch.Tensor + assert inpt_tensor_nf4.to(dtype).dtype == dtype @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) From 923bef27498a8335eed1b209d0b3e602dc86b511 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 14:03:30 -0700 Subject: [PATCH 09/17] keep original dtype Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/nf4tensor.py | 376 +++++------------------------------- 1 file changed, 47 insertions(+), 329 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 746851bc6..efb327af8 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -5,7 +5,6 @@ import torch import torch.nn.functional as F from torch import Tensor -import math aten = torch.ops.aten @@ -15,18 +14,6 @@ from typing import Any, Optional, Tuple, Union, List NF4_OPS_TABLE: Dict[Any, Any] = {} -NF4_TORCH_FUNCTIONS = {} - -def implements_torch_function(torch_function): - def decorator(func): - functools.update_wrapper(func, torch_function) - NF4_TORCH_FUNCTIONS[torch_function] = func - return func - return decorator - -@implements_torch_function(torch.Tensor.to) -def function_to_dtype(*args, **kwargs): - return args[0].get_original_weight().to(args[1]) def same_metadata(a: "NF4Tensor", b: "NF4Tensor"): @@ -49,261 +36,10 @@ def decorator(func): return decorator -@implements( - [ - aten.detach.default, - ] -) -def nf4_detach(aten_op, args, kwargs=None): - # nn.Parameter need detach - quantized_scalers = aten_op(args[0].quantized_scalers, *args[1:], **kwargs) - quantization_factor = aten_op(args[0].quantization_factor, *args[1:], **kwargs) - quantized_data = aten_op(args[0].quantized_data, *args[1:], **kwargs) - scaler_mean = aten_op(args[0].scaler_mean, *args[1:], **kwargs) - nf4 = aten_op(args[0].nf4, *args[1:], **kwargs) - tensor_meta = SubclassTensorArgs( - args[0].size(), - args[0].stride(), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, - ) - return NF4Tensor( - tensor_meta, - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, - quantized_scalers, - quantization_factor, - scaler_mean, - quantized_data, - nf4, - ) - -@implements( - [ - aten.split.Tensor, - ] -) -def nf4_split(aten_op, args, kwargs=None): - # torch.chunk - # TODO: find if there are other args/kwargs in aten.split - assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.split.Tensor with 2 args" - # TODO: assert on dim-0 sharding. how to get dim from torch.chunk? 
- num_chunks = args[0].size(0) // args[1] - - # TODO: assert numel % num_chunks == 0 - quantized_scalers_chunks = aten_op(args[0].quantized_scalers, args[0].quantized_scalers.numel() // num_chunks, **kwargs) - quantization_factor_chunks = aten_op(args[0].quantization_factor, args[0].quantization_factor.numel() // num_chunks, **kwargs) - quantized_data_chunks = aten_op(args[0].quantized_data, args[0].quantized_data.numel() // num_chunks, **kwargs) - - - assert len(args) == 2, "only support 2d because of tensor meta" - return [ - NF4Tensor( - SubclassTensorArgs( - (args[0].size(0) // num_chunks, args[0].size(1)), - args[0].stride(), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, - ), - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, - quantized_scalers, - quantization_factor, - args[0].scaler_mean, - quantized_data, - args[0].nf4, - ) for quantized_scalers, quantization_factor, quantized_data in zip( - quantized_scalers_chunks, quantization_factor_chunks, quantized_data_chunks - ) - ] - -@implements( - [ - aten.new_zeros.default, - ] -) -def nf4_new_zeros(aten_op, args, kwargs=None): - assert len(args[0].shape) == 2 and len(args[1]) == 2, "only support new zeros on 2D" - assert args[0].numel() % math.prod(args[1]) == 0 - ratio = args[0].numel() // math.prod(args[1]) - - assert args[0].quantized_scalers.size(0) % ratio == 0, f"quantized_scalers.numel() must be divisible by {ratio}" - quantized_scalers_new_zeros = aten_op(args[0].quantized_scalers, [args[0].quantized_scalers.size(0) // ratio], **kwargs) - - assert args[0].quantization_factor.size(0) % ratio == 0, f"quantization_factor.size(0) must be divisible by {ratio}" - quantization_factor_new_zeros = aten_op(args[0].quantization_factor, [args[0].quantization_factor.size(0) // ratio], **kwargs) - - assert args[0].quantized_data.size(0) % ratio == 0, f"quantized_data.size(0) must be divisible by {ratio}" - quantized_data_new_zeros = aten_op(args[0].quantized_data, [args[0].quantized_data.size(0) // ratio], **kwargs) - - - return NF4Tensor( - SubclassTensorArgs( - (args[1][0], args[1][1]), - args[0].stride(), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, - ), - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, - quantized_scalers_new_zeros, - quantization_factor_new_zeros, - args[0].scaler_mean, - quantized_data_new_zeros, - args[0].nf4, - ) - -@implements( - [ - aten.slice.Tensor, - ] -) -def nf4_slice(aten_op, args, kwargs=None): - assert len(args) == 4 - assert args[1] == 0, f"only support dim=0 but got dim={args[1]}" - # TODO: maybe relax? 
- assert args[2] == 0, f"only support start=0 but got start={args[2]}" - assert args[3] == args[0].size(0), f"only support end == size(0) but got end={args[3]} and size(0)={args[0].size(0)}" - return NF4Tensor( - SubclassTensorArgs( - args[0].size(), - args[0].stride(), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, - ), - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, - args[0].quantized_scalers, - args[0].quantization_factor, - args[0].scaler_mean, - args[0].quantized_data, - args[0].nf4, - ) - -@implements( - [ - aten.copy_.default, - ] -) -def nf4_copy_(aten_op, args, kwargs=None): - assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.copy_.default with 2 args" - # TODO: use original and copy_in in same_meta - original: NF4Tensor = args[0] - copy_in: torch.Tensor = args[1] - - if same_metadata(original, copy_in): - quantized_scalers = aten_op(args[0].quantized_scalers, args[1].quantized_scalers, **kwargs) - quantization_factor = aten_op(args[0].quantization_factor, args[1].quantization_factor, **kwargs) - quantized_data = aten_op(args[0].quantized_data, args[1].quantized_data, **kwargs) - scaler_mean = aten_op(args[0].scaler_mean, args[1].scaler_mean, **kwargs) - nf4 = aten_op(args[0].nf4, args[1].nf4, **kwargs) - tensor_meta = SubclassTensorArgs( - args[1].size(), - args[1].stride(), - args[1].storage_offset(), - args[1].dtype, - args[1].device, - args[1].requires_grad, - ) - return NF4Tensor( - tensor_meta, - args[1].block_size, - args[1].n_blocks, - args[1].scaler_block_size, - quantized_scalers, - quantization_factor, - scaler_mean, - quantized_data, - nf4, - ) - - # Convert Non NF4Tensor into NF4 for copy in - if not isinstance(copy_in, NF4Tensor): - copy_in_nf4 = NF4Tensor.from_tensor( - copy_in, original.block_size, original.scaler_block_size - ) - return original.copy_(copy_in_nf4) - - # Other Tensor is not a NF4Tensor - full_precision = copy_in.get_original_weight() - same_meta_nf4 = NF4Tensor.from_tensor( - full_precision, original.block_size, original.scaler_block_size - ) - return original.copy_(same_meta_nf4) - -@implements( - [ - aten.view.default, - ] -) -def nf4_view(aten_op, args, kwargs=None): - assert len(args) == 2, args[1] == -1 - quantized_scalers = aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs) - quantization_factor = aten_op(args[0].quantization_factor, *(args[1:]), **kwargs) - quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs) - tensor_meta = SubclassTensorArgs( - [args[0].numel()], - (1, ), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, - ) - return NF4Tensor( - tensor_meta, - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, - quantized_scalers, - quantization_factor, - args[0].scaler_mean, - quantized_data, - args[0].nf4, - ) -@implements( - [ - aten.as_strided.default, - ] -) -def nf4_as_strided(aten_op, args, kwargs=None): - assert len(args[1]) == 2 and math.prod(args[1]) == args[0].numel(), "only support same numel" - assert args[2] == [args[1][1], 1], f"only support stride {[args[1][1], 1]}" - assert args[0].storage_offset() == args[3], f"only support same storage offset" - return NF4Tensor( - SubclassTensorArgs( - torch.Size(args[1]), - tuple(args[2]), - args[0].storage_offset(), - args[0].dtype, - args[0].device, - args[0].requires_grad, - ), - args[0].block_size, - args[0].n_blocks, - args[0].scaler_block_size, - args[0].quantized_scalers, - 
args[0].quantization_factor, - args[0].scaler_mean, - args[0].quantized_data, - args[0].nf4, - ) -@implements([torch.ops.aten.detach]) +@implements([torch.ops.aten.detach.default, torch.ops.aten.detach]) def noop_detach(func, *args, **kwargs): - assert False return args[0][0] @@ -330,7 +66,6 @@ def t_default(func, *args, **kwargs): a.size(), (a.stride(1), a.stride(0)), a.storage_offset(), - # torch.bits2x4, a.dtype, a.device, a.requires_grad, @@ -354,6 +89,38 @@ def mm_default(func, *args, **kwargs): return linear_nf4(args[0][0], args[0][1]) +@implements( + [ + aten.copy_.default, + ] +) +def copy_(func, *args, **kwargs): + original: NF4Tensor = args[0][0] + copy_in: torch.Tensor = args[0][1] + + # Base Case + + if same_metadata(original, copy_in): + original_tensors = original.__tensor_flatten__()[0] + for tensor_name in original_tensors: + getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) + return + + # Convert Non NF4Tensor into NF4 for copy in + if not isinstance(copy_in, NF4Tensor): + copy_in_nf4 = NF4Tensor.from_tensor( + copy_in, original.block_size, original.scaler_block_size + ) + return original.copy_(copy_in_nf4) + + # Other Tensor is not a NF4Tensor + full_precision = copy_in.get_original_weight() + same_meta_nf4 = NF4Tensor.from_tensor( + full_precision, original.block_size, original.scaler_block_size + ) + return original.copy_(same_meta_nf4) + + @dataclass class SubclassTensorArgs: original_shape: torch.Size @@ -423,7 +190,6 @@ def __new__( tensor_meta.storage_offset, # Picked some floating dtype, but we need dtype extensibility dtype=tensor_meta.dtype, - # dtype=torch.float8_e5m2fnuz, device=tensor_meta.device, requires_grad=tensor_meta.requires_grad, ) @@ -460,11 +226,9 @@ def from_tensor( scaler_block_size: int, ): assert inpt_tensor.dim() <= 2 - # assert inpt_tensor.dtype == torch.bfloat16 assert ( inpt_tensor.numel() % block_size == 0 ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}" - # assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16" assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!" 
# I think I want do this # assert not inpt_tensor.requires_grad, "Input tensor must not require grad" @@ -800,66 +564,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return func(*args, **kwargs) - def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]: - return ( - self.quantized_scalers, - self.quantization_factor, - self.quantized_data, - ), ( - SubclassTensorArgs( - self.size(), - self.stride(), - self.storage_offset(), - self.dtype, - self.device, - self.requires_grad, - ), - self.block_size, - self.n_blocks, - self.scaler_block_size, - self.scaler_mean, - self.nf4, - ) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, - *, - out: Optional[torch.Tensor] = None, - ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]: - (quantized_scalers, quantization_factor, quantized_data) = all_gather_outputs - (tensor_meta, block_size, n_blocks, scaler_block_size, scaler_mean, nf4) = metadata - world_size = quantized_data.numel() * 2 // math.prod(tensor_meta.original_shape) - len(tensor_meta.original_shape) == 2, "only support 2D shape" - tensor_meta.original_shape = torch.Size((tensor_meta.original_shape[0] * world_size, tensor_meta.original_shape[1])) - if out is not None: - # TODO: add param dtype - assert isinstance(out, NF4Tensor), f"{type(out)}" - assert ( - quantized_scalers.untyped_storage().data_ptr() - == out.quantized_scalers.untyped_storage().data_ptr() and - quantization_factor.untyped_storage().data_ptr() - == out.quantization_factor.untyped_storage().data_ptr() and - quantized_data.untyped_storage().data_ptr() - == out.quantized_data.untyped_storage().data_ptr() - ), f"Expects out's data to be the all-gather output" - return - - return NF4Tensor( - tensor_meta, - block_size, - n_blocks, - scaler_block_size, - quantized_scalers, - quantization_factor, - scaler_mean, - quantized_data, - nf4, - ), (quantized_scalers, quantization_factor, quantized_data) - - class LinearNF4(torch.autograd.Function): @staticmethod @@ -892,3 +596,17 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) + + +NF4_TORCH_FUNCTIONS = {} + +def implements_torch_function(torch_function): + def decorator(func): + functools.update_wrapper(func, torch_function) + NF4_TORCH_FUNCTIONS[torch_function] = func + return func + return decorator + +@implements_torch_function(torch.Tensor.to) +def function_to_dtype(*args, **kwargs): + return args[0].get_original_weight().to(args[1]) From e15d244fc44ac9af2fac623a22e6fce6d6dcc80a Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 14:11:00 -0700 Subject: [PATCH 10/17] fix linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/nf4tensor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index efb327af8..8f7afbee3 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -1,6 +1,6 @@ +import functools from dataclasses import dataclass from typing import Dict, Tuple -import functools import torch import torch.nn.functional as F @@ -11,7 +11,7 @@ c10d_functional = torch.ops.c10d_functional -from typing import Any, Optional, Tuple, Union, List +from typing import Any, Tuple NF4_OPS_TABLE: Dict[Any, Any] = {} @@ -37,7 +37,6 @@ def decorator(func): return 
decorator - @implements([torch.ops.aten.detach.default, torch.ops.aten.detach]) def noop_detach(func, *args, **kwargs): return args[0][0] @@ -549,7 +548,6 @@ def allowed_subclasses(type): f"NF4Tensor dispatch: attempting to run {func}, this is not supported" ) - # Do not force the Float8Tensor type on the returned tensor @classmethod @@ -600,13 +598,16 @@ def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): NF4_TORCH_FUNCTIONS = {} + def implements_torch_function(torch_function): def decorator(func): functools.update_wrapper(func, torch_function) NF4_TORCH_FUNCTIONS[torch_function] = func return func + return decorator + @implements_torch_function(torch.Tensor.to) def function_to_dtype(*args, **kwargs): return args[0].get_original_weight().to(args[1]) From d4beb8fff9e116c22e4570892f0a94506d112002 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 14:43:25 -0700 Subject: [PATCH 11/17] use torch testing @parametrize Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 54 +++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 2637a655e..e599e7b06 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -3,7 +3,12 @@ import torch from torch import nn -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4 import torch.nn.functional as F import io @@ -45,19 +50,19 @@ def _build_bnb_linear(input_weight, device): bnb_linear.to(device) return bnb_linear -class TestMod(nn.Module): - def __init__(self, tensor, block_size, scaler_block_size): - super().__init__() - self.param = torch.nn.Parameter(to_nf4(tensor, block_size, scaler_block_size)) +class TestNF4Linear(TestCase): + class TestMod(nn.Module): + def __init__(self, tensor, block_size, scaler_block_size): + super().__init__() + self.param = torch.nn.Parameter(to_nf4(tensor, block_size, scaler_block_size)) -class TestNF4Linear(): def save_state_dict_to_buffer(self, state_dict: OrderedDict): buffer = io.BytesIO() torch.save(state_dict, buffer) buffer.seek(0) return buffer - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_register_nf4_as_param(self, dtype: torch.dtype): nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype)) @@ -66,7 +71,7 @@ def test_register_nf4_as_param(self, dtype: torch.dtype): param = torch.nn.Parameter(nf4_tensor, requires_grad=False) assert not param.requires_grad - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_output_dtype_match(self, dtype:torch.dtype): # Test to ensure W4 A16 produces A16 inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) @@ -74,7 +79,7 @@ def test_output_dtype_match(self, dtype:torch.dtype): out = linear_nf4(input=inp, weight=nf4_tensor) assert out.dtype == dtype - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_backward_dtype_match(self, dtype:torch.dtype): # Test to ensure backward pass gives activation a bf16 gradient and no gradient # to the linear's weight, as it is 
frozen. @@ -86,7 +91,7 @@ def test_backward_dtype_match(self, dtype:torch.dtype): @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype): # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47 torch.manual_seed(0) @@ -108,7 +113,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype): @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_nf4_bnb_linear(self, dtype: torch.dtype): """ This test ensures that nf4_linear is "no worse" than BNB by ensuring the @@ -133,11 +138,11 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype): assert err_bnb < 0.5 * dim @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_load_from_state_dicts(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" inpt_tensor = torch.rand(64, device='cuda', dtype=dtype) - base_mod = TestMod(inpt_tensor, 32, 2) + base_mod = self.TestMod(inpt_tensor, 32, 2) dummy_dict = {"param": inpt_tensor} base_mod.load_state_dict(dummy_dict) @@ -146,34 +151,34 @@ def test_load_from_state_dicts(self, dtype: torch.dtype): assert base_mod.param.scaler_block_size == 2 @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_load_from_nf4_same_meta(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" inpt_tensor = torch.rand(64, device='cuda', dtype=dtype) - base_mod = TestMod(inpt_tensor, 32, 2) + base_mod = self.TestMod(inpt_tensor, 32, 2) state_dict = base_mod.state_dict() saved_state_dict = self.save_state_dict_to_buffer(state_dict) - other_mod = TestMod(inpt_tensor, 32, 2) + other_mod = self.TestMod(inpt_tensor, 32, 2) other_mod.load_state_dict(torch.load(saved_state_dict)) assert other_mod.param.block_size == 32 assert other_mod.param.scaler_block_size == 2 @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_load_from_nf4_diff_meta(self, dtype: torch.dtype): """Tests loading to and from different module state dicts""" inpt_tensor = torch.rand(128, device='cuda', dtype=dtype) - base_mod = TestMod(inpt_tensor, 32, 2) + base_mod = self.TestMod(inpt_tensor, 32, 2) state_dict = base_mod.state_dict() saved_state_dict = self.save_state_dict_to_buffer(state_dict) - other_mod = TestMod(inpt_tensor, 64, 1) + other_mod = self.TestMod(inpt_tensor, 64, 1) other_mod.load_state_dict(torch.load(saved_state_dict)) assert other_mod.param.block_size == 64 assert other_mod.param.scaler_block_size == 1 - @pytest.mark.parametrize("dtype", 
[torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_to_copy(self, dtype: torch.dtype): inpt_tensor = torch.rand(128, device='cpu') inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2) @@ -186,7 +191,7 @@ def test_to_copy(self, dtype: torch.dtype): nf4_to_dtype = inpt_tensor_nf4.to(dtype) torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13) - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_to_dtype(self, dtype: torch.dtype): inpt_tensor = torch.rand(128, dtype=dtype) inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2) @@ -195,7 +200,7 @@ def test_to_dtype(self, dtype: torch.dtype): assert inpt_tensor_nf4.to(dtype).dtype == dtype @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_smoketest_linear(self, dtype: torch.dtype): a = torch.randn(32, 32, dtype=dtype, device='cuda') a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) @@ -204,7 +209,7 @@ def test_smoketest_linear(self, dtype: torch.dtype): out2 = torch.nn.functional.linear(inp, a_nf4) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_smoketest_linear_compile(self, dtype: torch.dtype): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16: self.skipTest("test requires SM capability of at least (8, 0).") @@ -214,6 +219,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype): out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4) +instantiate_parametrized_tests(TestNF4Linear) if __name__ == "__main__": - unittest.main() + run_tests() From f41cb3da4d3a8a6addcaa45ec20e4236654fca4e Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 14:47:08 -0700 Subject: [PATCH 12/17] remove unused import Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index e599e7b06..0732b6974 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -14,7 +14,6 @@ import io from collections import OrderedDict import torchao -import pytest bnb_available = False From 950d9fda480a7383e2564989e059b27858426b38 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 16:05:12 -0700 Subject: [PATCH 13/17] sm8 for fp16 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 0732b6974..12c885f50 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -210,7 +210,7 @@ def test_smoketest_linear(self, dtype: torch.dtype): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_smoketest_linear_compile(self, dtype: torch.dtype): - if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16: + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype in 
[torch.bfloat16, torch.float16]: self.skipTest("test requires SM capability of at least (8, 0).") a = torch.randn(32, 32, dtype=dtype, device='cuda') a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) From d4eae0b951e683922c13c40a06291319cf1f6eac Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 17 Apr 2024 17:18:09 -0700 Subject: [PATCH 14/17] remove sm check for fp16 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 12c885f50..0732b6974 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -210,7 +210,7 @@ def test_smoketest_linear(self, dtype: torch.dtype): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_smoketest_linear_compile(self, dtype: torch.dtype): - if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype in [torch.bfloat16, torch.float16]: + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16: self.skipTest("test requires SM capability of at least (8, 0).") a = torch.randn(32, 32, dtype=dtype, device='cuda') a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) From 9444f2c698aa13de657f22fd975a94317ffab6d9 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 18 Apr 2024 16:12:13 -0700 Subject: [PATCH 15/17] skip 2.2.2 and below for tracing tensor subclass Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 0732b6974..39818a9df 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -1,5 +1,6 @@ import logging import unittest +from packaging import version import torch from torch import nn @@ -15,6 +16,7 @@ from collections import OrderedDict import torchao + bnb_available = False try: @@ -212,6 +214,8 @@ def test_smoketest_linear(self, dtype: torch.dtype): def test_smoketest_linear_compile(self, dtype: torch.dtype): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16: self.skipTest("test requires SM capability of at least (8, 0).") + if version.parse(torch.__version__) <= version.parse("2.2.2"): + self.skipTest("test requires 2.3.0+ for tracing NF4Tensor") a = torch.randn(32, 32, dtype=dtype, device='cuda') a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device) From 9be2de31478c2d63226ab389ceedec795a3ac85b Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 18 Apr 2024 17:02:40 -0700 Subject: [PATCH 16/17] include kwargs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 4 ++-- torchao/dtypes/nf4tensor.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 39818a9df..e3b25e3c3 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -214,8 +214,8 @@ def test_smoketest_linear(self, dtype: torch.dtype): def test_smoketest_linear_compile(self, dtype: torch.dtype): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16: self.skipTest("test requires SM capability of at least (8, 0).") - if version.parse(torch.__version__) <= version.parse("2.2.2"): - self.skipTest("test requires 2.3.0+ for tracing NF4Tensor") + if 
version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor") a = torch.randn(32, 32, dtype=dtype, device='cuda') a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 8f7afbee3..93d6e3be5 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -610,4 +610,12 @@ def decorator(func): @implements_torch_function(torch.Tensor.to) def function_to_dtype(*args, **kwargs): - return args[0].get_original_weight().to(args[1]) + if isinstance(args[0], NF4Tensor) and isinstance(args[1], torch.dtype): + # Tensor.to(dtype, non_blocking, copy, memory_format) + return args[0].get_original_weight().to(*args[1:], **kwargs) + else: + # Tensor.to(device, dtype, non_blocking, copy, memory_format) + # Tensor.to(other, non_blocking, copy) + raise NotImplementedError( + f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported" + ) From 29813931b8d11ea901d3ad59bb30920997970e51 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 18 Apr 2024 17:20:49 -0700 Subject: [PATCH 17/17] raise unimplemented Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/nf4tensor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 93d6e3be5..886eb6c0a 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -555,8 +555,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - if func in NF4_TORCH_FUNCTIONS: - return NF4_TORCH_FUNCTIONS[func](*args, **kwargs) + try: + if func in NF4_TORCH_FUNCTIONS: + return NF4_TORCH_FUNCTIONS[func](*args, **kwargs) + except NotImplementedError: + pass with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @@ -617,5 +620,5 @@ def function_to_dtype(*args, **kwargs): # Tensor.to(device, dtype, non_blocking, copy, memory_format) # Tensor.to(other, non_blocking, copy) raise NotImplementedError( - f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported" + f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch" )
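
The test refactor in [PATCH 11/17] replaces pytest.mark.parametrize with the parametrization helpers from torch.testing._internal.common_utils. Below is a minimal, self-contained sketch of that pattern; the class name ExampleParametrized and the round-trip assertion are illustrative only and are not part of the torchao test suite.

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class ExampleParametrized(TestCase):
    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
    def test_dtype_round_trip(self, dtype: torch.dtype):
        # Each entry in the dtype list becomes a separately generated test case.
        x = torch.randn(8).to(dtype)
        assert x.to(torch.float32).to(dtype).dtype == dtype

# Required when using @parametrize: expands the decorated tests on the class.
instantiate_parametrized_tests(ExampleParametrized)

if __name__ == "__main__":
    run_tests()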
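
[PATCH 15/17] and [PATCH 16/17] gate the compile smoke test on the installed PyTorch version. A small sketch of that check follows, assuming the packaging distribution is available (the patch itself imports it via `from packaging import version`); the printed message is illustrative only.

import torch
from packaging import version

# version.parse understands dev/rc suffixes such as "2.3.0a0+git...", which a
# naive string comparison would misorder.
if version.parse(torch.__version__) < version.parse("2.3.0"):
    print("skipping: tracing NF4Tensor under torch.compile needs torch >= 2.3.0")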
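
The last two patches register a torch-function override for Tensor.to in a per-subclass table and let the override raise NotImplementedError to fall back to the default behavior. Below is a minimal, self-contained sketch of that dispatch pattern on a toy subclass; ExampleTensor, example_to, and TORCH_FUNCTION_TABLE are illustrative names rather than torchao APIs, and the sketch assumes a PyTorch recent enough to provide torch._C.DisableTorchFunctionSubclass (which the patch itself uses).

import functools
import torch

TORCH_FUNCTION_TABLE = {}

def implements_torch_function(torch_function):
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        TORCH_FUNCTION_TABLE[torch_function] = func
        return func
    return decorator

class ExampleTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Try a registered override first; NotImplementedError means
        # "decline this call and use the default implementation instead".
        try:
            if func in TORCH_FUNCTION_TABLE:
                return TORCH_FUNCTION_TABLE[func](*args, **kwargs)
        except NotImplementedError:
            pass
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

@implements_torch_function(torch.Tensor.to)
def example_to(*args, **kwargs):
    # Handle only Tensor.to(dtype); every other signature defers to the default.
    if len(args) == 2 and isinstance(args[1], torch.dtype):
        with torch._C.DisableTorchFunctionSubclass():
            return torch.Tensor.to(*args, **kwargs)
    raise NotImplementedError("only Tensor.to(dtype) is handled by this override")

if __name__ == "__main__":
    t = torch.randn(4, 4).as_subclass(ExampleTensor)
    print(t.to(torch.float16).dtype)  # handled by the registered override
    print(t.to("cpu").device)         # falls through to the default path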