From ac0dd2b1d24cb49b906ea7e9c1a3771786c021e7 Mon Sep 17 00:00:00 2001
From: "Wei (Will) Feng" <134637289+weifengpy@users.noreply.github.com>
Date: Fri, 19 Apr 2024 10:13:45 -0700
Subject: [PATCH] [FSDP2][1/n] construct NF4Tensor from bf16/fp16/fp32 (#118)

---
 test/dtypes/test_nf4.py     | 116 +++++++++++++++++++++---------
 torchao/dtypes/nf4tensor.py |  55 +++++++++++++----
 2 files changed, 113 insertions(+), 58 deletions(-)

diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py
index c71fcd25b..e3b25e3c3 100644
--- a/test/dtypes/test_nf4.py
+++ b/test/dtypes/test_nf4.py
@@ -1,15 +1,22 @@
 import logging
 import unittest
+from packaging import version
 
 import torch
 from torch import nn
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
 import torch.nn.functional as F
 import io
 from collections import OrderedDict
 import torchao
 
+
 bnb_available = False
 
 try:
@@ -24,10 +31,10 @@
 )
 
 
-def _build_input_weight(embed_dim: int, device: torch.device):
+def _build_input_weight(embed_dim: int, device: torch.device, dtype: torch.dtype):
     torch.manual_seed(0)
     input_weight = torch.empty(
-        embed_dim, embed_dim, device=device, dtype=torch.bfloat16
+        embed_dim, embed_dim, device=device, dtype=dtype
     )
     input_weight.normal_(0, 1)
     return input_weight
@@ -44,7 +51,6 @@ def _build_bnb_linear(input_weight, device):
     bnb_linear.to(device)
     return bnb_linear
 
-
 class TestNF4Linear(TestCase):
     class TestMod(nn.Module):
         def __init__(self, tensor, block_size, scaler_block_size):
@@ -57,42 +63,46 @@ def save_state_dict_to_buffer(self, state_dict: OrderedDict):
         buffer.seek(0)
         return buffer
 
-    def test_register_nf4_as_param(self):
-        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_register_nf4_as_param(self, dtype: torch.dtype):
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))
 
         # Would raise if nn.Parameter registration fails, such as no detach()
         # impl when calling __torch_dispatch__
         param = torch.nn.Parameter(nf4_tensor, requires_grad=False)
         assert not param.requires_grad
 
-    def test_output_bf16(self):
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_output_dtype_match(self, dtype:torch.dtype):
         # Test to ensure W4 A16 produces A16
-        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
-        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
+        inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))
         out = linear_nf4(input=inp, weight=nf4_tensor)
-        assert out.dtype == torch.bfloat16
+        assert out.dtype == dtype
 
-    def test_backward_bf16(self):
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_backward_dtype_match(self, dtype:torch.dtype):
         # Test to ensure backward pass gives activation a bf16 gradient and no gradient
         # to the linear's weight, as it is frozen.
-        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
-        inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))
+        inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
         linear_nf4(inp, nf4_tensor).sum().backward()
-        assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
+        assert inp.grad is not None and inp.grad.dtype == dtype
         assert nf4_tensor.grad is None
 
     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_reconstruction_qlora_vs_bnb(self):
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
         # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
         torch.manual_seed(0)
         device = "cuda"
         embed_dim = 512
-        input_weight = _build_input_weight(embed_dim, device)
+        input_weight = _build_input_weight(embed_dim, device, dtype)
         nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)
         bnb_reconstruction = bnb_linear(
-            torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
+            torch.eye(embed_dim, embed_dim, dtype=dtype, device=device)
         )
         bnb_diff = (bnb_reconstruction.T - input_weight).abs().max()
         nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max()
@@ -104,7 +114,8 @@ def test_reconstruction_qlora_vs_bnb(self):
 
     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_nf4_bnb_linear(self):
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_nf4_bnb_linear(self, dtype: torch.dtype):
         """
         This test ensures that nf4_linear is "no worse" than BNB by ensuring the
         error compared to a bf16 linear is not more than BNB's implementation.
@@ -112,11 +123,11 @@ def test_nf4_bnb_linear(self):
         torch.manual_seed(0)
         dim = 512
         device = "cuda"
-        input_weight = _build_input_weight(dim, device)
+        input_weight = _build_input_weight(dim, device, dtype)
         nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)
 
-        inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
+        inp = torch.randn(2, 512, dtype=dtype, device="cuda")
 
         out_nf4 = linear_nf4(inp, nf4_weight).sum()
         out_bnb = bnb_linear(inp).sum()
@@ -128,21 +139,23 @@ def test_nf4_bnb_linear(self):
         assert err_bnb < 0.5 * dim
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
-    def test_load_from_bfloat16(self):
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_load_from_state_dicts(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
+        inpt_tensor = torch.rand(64, device='cuda', dtype=dtype)
         base_mod = self.TestMod(inpt_tensor, 32, 2)
 
-        bf16_dummy_dict = {"param": inpt_tensor}
-        base_mod.load_state_dict(bf16_dummy_dict)
+        dummy_dict = {"param": inpt_tensor}
+        base_mod.load_state_dict(dummy_dict)
 
         assert base_mod.param.block_size == 32
         assert base_mod.param.scaler_block_size == 2
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
-    def test_load_from_nf4_same_meta(self):
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
+        inpt_tensor = torch.rand(64, device='cuda', dtype=dtype)
         base_mod = self.TestMod(inpt_tensor, 32, 2)
         state_dict = base_mod.state_dict()
         saved_state_dict = self.save_state_dict_to_buffer(state_dict)
@@ -153,9 +166,10 @@ def test_load_from_nf4_same_meta(self):
         assert other_mod.param.scaler_block_size == 2
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
-    def test_load_from_nf4_diff_meta(self):
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        inpt_tensor = torch.rand(128, device='cuda', dtype=torch.bfloat16)
+        inpt_tensor = torch.rand(128, device='cuda', dtype=dtype)
         base_mod = self.TestMod(inpt_tensor, 32, 2)
         state_dict = base_mod.state_dict()
         saved_state_dict = self.save_state_dict_to_buffer(state_dict)
@@ -165,44 +179,50 @@ def test_load_from_nf4_diff_meta(self):
         other_mod.load_state_dict(torch.load(saved_state_dict))
         assert other_mod.param.block_size == 64
         assert other_mod.param.scaler_block_size == 1
 
-    def test_to_copy(self):
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_to_copy(self, dtype: torch.dtype):
         inpt_tensor = torch.rand(128, device='cpu')
         inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
-        inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
-        torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
+        nf4_to_dtype = inpt_tensor_nf4.to(dtype)
+        torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)
 
         if torch.cuda.is_available():
            inpt_tensor = torch.rand(128, device='cuda')
            inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
-            inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
-            torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
+            nf4_to_dtype = inpt_tensor_nf4.to(dtype)
+            torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)
 
-    def test_to_bfloat16(self):
-        inpt_tensor = torch.rand(128, dtype=torch.bfloat16)
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_to_dtype(self, dtype: torch.dtype):
+        inpt_tensor = torch.rand(128, dtype=dtype)
         inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
         assert type(inpt_tensor_nf4) != torch.Tensor
-        assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
-        assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16
+        assert type(inpt_tensor_nf4.to(dtype)) == torch.Tensor
+        assert inpt_tensor_nf4.to(dtype).dtype == dtype
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_smoketest_linear(self):
-        a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_smoketest_linear(self, dtype: torch.dtype):
+        a = torch.randn(32, 32, dtype=dtype, device='cuda')
         a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
         inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
         out1 = torch.nn.functional.linear(inp, a)
         out2 = torch.nn.functional.linear(inp, a_nf4)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_smoketest_linear_compile(self):
-        for dtype in [torch.bfloat16, torch.float16]:
-            if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
-                self.skipTest("test requires SM capability of at least (8, 0).")
-            a = torch.randn(32, 32, dtype=dtype, device='cuda')
-            a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
-            inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
-            out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
+    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+    def test_smoketest_linear_compile(self, dtype: torch.dtype):
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
+            self.skipTest("test requires SM capability of at least (8, 0).")
+        if version.parse(torch.__version__) < version.parse("2.3.0"):
+            self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor")
+        a = torch.randn(32, 32, dtype=dtype, device='cuda')
+        a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
+        inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
+        out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
 
+instantiate_parametrized_tests(TestNF4Linear)
 
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index ea45a6c0d..886eb6c0a 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -1,3 +1,4 @@
+import functools
 from dataclasses import dataclass
 from typing import Dict, Tuple
 
@@ -10,7 +11,7 @@ c10d_functional = torch.ops.c10d_functional
 
 
 
-from typing import Any
+from typing import Any, Tuple
 
 NF4_OPS_TABLE: Dict[Any, Any] = {}
 
@@ -64,7 +65,7 @@ def t_default(func, *args, **kwargs):
         a.size(),
         (a.stride(1), a.stride(0)),
         a.storage_offset(),
-        torch.bits2x4,
+        a.dtype,
         a.device,
         a.requires_grad,
     )
@@ -187,7 +188,7 @@ def __new__(
             tensor_meta.original_strides,
             tensor_meta.storage_offset,
             # Picked some floating dtype, but we need dtype extensibility
-            dtype=torch.float8_e5m2fnuz,
+            dtype=tensor_meta.dtype,
             device=tensor_meta.device,
             requires_grad=tensor_meta.requires_grad,
         )
@@ -224,11 +225,9 @@ def from_tensor(
         scaler_block_size: int,
     ):
         assert inpt_tensor.dim() <= 2
-        assert inpt_tensor.dtype == torch.bfloat16
         assert (
             inpt_tensor.numel() % block_size == 0
         ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
-        assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16"
         assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
         # assert not inpt_tensor.requires_grad, "Input tensor must not require grad"
@@ -254,7 +253,7 @@ def from_tensor(
                 1.0000,
             ],
             device=device,
-            dtype=torch.bfloat16,
+            dtype=inpt_tensor.dtype,
         )
         n_blocks = inpt_tensor.numel() // block_size
         # Double quantization
@@ -370,7 +369,7 @@ def dequantize_scalers(
         n_scaler_blocks = inpt_tensor.numel() // scaler_block_size
         inpt_tensor = inpt_tensor.view(n_scaler_blocks, scaler_block_size)
         dequantized = (inpt_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
-            torch.bfloat16
+            self.dtype
        ) + self.scaler_mean
         return dequantized
 
@@ -551,7 +550,19 @@ def allowed_subclasses(type):
 
 
     # Do not force the Float8Tensor type on the returned tensor
-    __torch_function__ = torch._C._disabled_torch_function_impl
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        try:
+            if func in NF4_TORCH_FUNCTIONS:
+                return NF4_TORCH_FUNCTIONS[func](*args, **kwargs)
+        except NotImplementedError:
+            pass
+
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
 
 
 class LinearNF4(torch.autograd.Function):
@@ -585,5 +596,29 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
 
 
 def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
-    tensor1 = tensor.to(torch.bfloat16)
-    return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size)
+    return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)
+
+
+NF4_TORCH_FUNCTIONS = {}
+
+
+def implements_torch_function(torch_function):
+    def decorator(func):
+        functools.update_wrapper(func, torch_function)
+        NF4_TORCH_FUNCTIONS[torch_function] = func
+        return func
+
+    return decorator
+
+
+@implements_torch_function(torch.Tensor.to)
+def function_to_dtype(*args, **kwargs):
+    if isinstance(args[0], NF4Tensor) and isinstance(args[1], torch.dtype):
+        # Tensor.to(dtype, non_blocking, copy, memory_format)
+        return args[0].get_original_weight().to(*args[1:], **kwargs)
+    else:
+        # Tensor.to(device, dtype, non_blocking, copy, memory_format)
+        # Tensor.to(other, non_blocking, copy)
+        raise NotImplementedError(
+            f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch"
+        )
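
Usage sketch (illustrative, not part of the patch): based on the tests above, to_nf4 now accepts bf16/fp16/fp32 inputs and keeps the input dtype instead of casting to bf16, linear_nf4 produces outputs in that same dtype, and NF4Tensor.to(dtype) dequantizes back to a plain torch.Tensor through the new __torch_function__ handler. The tensor shapes and block sizes below are arbitrary.

    import torch
    from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

    # fp32 weight; before this patch to_nf4 would have cast it to bf16 first.
    weight = torch.randn(512, 512, dtype=torch.float32)
    nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)

    # The NF4 linear keeps the activation/weight dtype (fp32 here).
    inp = torch.randn(2, 512, dtype=torch.float32)
    out = linear_nf4(input=inp, weight=nf4_weight)
    assert out.dtype == torch.float32

    # .to(dtype) is routed through NF4_TORCH_FUNCTIONS and returns a
    # dequantized plain torch.Tensor in the requested dtype.
    dequantized = nf4_weight.to(torch.float32)
    assert type(dequantized) is torch.Tensor and dequantized.dtype == torch.float32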