[FSDP2][1/n] construct NF4Tensor from bf16/fp16/fp32 (pytorch#118)
weifengpy authored Apr 19, 2024
1 parent b61eb2f commit ac0dd2b
Showing 2 changed files with 113 additions and 58 deletions.
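For context, a minimal sketch of the behavior this change enables: `to_nf4` no longer force-casts its input to bfloat16, so an NF4Tensor can be built from bf16, fp16, or fp32 and dequantizes back to the input dtype. The snippet below mirrors the updated tests; shapes and block sizes are illustrative.

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4, linear_nf4

for dtype in (torch.bfloat16, torch.float16, torch.float32):
    weight = torch.randn(512, 512, dtype=dtype)
    # No implicit cast to bfloat16 anymore; NF4 metadata follows the input dtype.
    nf4_weight = to_nf4(weight, block_size=64, scaler_block_size=256)

    # Dequantization (via the new torch.Tensor.to handler) returns the input dtype.
    restored = nf4_weight.to(dtype)
    assert restored.dtype == dtype

    # W4A16-style linear keeps the activation dtype.
    inp = torch.randn(2, 512, dtype=dtype)
    out = linear_nf4(input=inp, weight=nf4_weight)
    assert out.dtype == dtype
```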
116 changes: 68 additions & 48 deletions test/dtypes/test_nf4.py
@@ -1,15 +1,22 @@
import logging
import unittest
from packaging import version

import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
import torch.nn.functional as F
import io
from collections import OrderedDict
import torchao


bnb_available = False

try:
@@ -24,10 +31,10 @@
)


def _build_input_weight(embed_dim: int, device: torch.device):
def _build_input_weight(embed_dim: int, device: torch.device, dtype: torch.dtype):
torch.manual_seed(0)
input_weight = torch.empty(
embed_dim, embed_dim, device=device, dtype=torch.bfloat16
embed_dim, embed_dim, device=device, dtype=dtype
)
input_weight.normal_(0, 1)
return input_weight
@@ -44,7 +51,6 @@ def _build_bnb_linear(input_weight, device):
bnb_linear.to(device)
return bnb_linear


class TestNF4Linear(TestCase):
class TestMod(nn.Module):
def __init__(self, tensor, block_size, scaler_block_size):
@@ -57,42 +63,46 @@ def save_state_dict_to_buffer(self, state_dict: OrderedDict):
buffer.seek(0)
return buffer

def test_register_nf4_as_param(self):
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_register_nf4_as_param(self, dtype: torch.dtype):
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))

# Would raise if nn.Parameter registration fails, such as no detach()
# impl when calling __torch_dispatch__
param = torch.nn.Parameter(nf4_tensor, requires_grad=False)
assert not param.requires_grad

def test_output_bf16(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_output_dtype_match(self, dtype:torch.dtype):
# Test to ensure W4 A16 produces A16
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))
out = linear_nf4(input=inp, weight=nf4_tensor)
assert out.dtype == torch.bfloat16
assert out.dtype == dtype

def test_backward_bf16(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_backward_dtype_match(self, dtype:torch.dtype):
# Test to ensure backward pass gives activation a bf16 gradient and no gradient
# to the linear's weight, as it is frozen.
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
linear_nf4(inp, nf4_tensor).sum().backward()
assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
assert inp.grad is not None and inp.grad.dtype == dtype
assert nf4_tensor.grad is None

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_reconstruction_qlora_vs_bnb(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
# From https:/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
torch.manual_seed(0)
device = "cuda"
embed_dim = 512
input_weight = _build_input_weight(embed_dim, device)
input_weight = _build_input_weight(embed_dim, device, dtype)
nf4_weight = to_nf4(input_weight)
bnb_linear = _build_bnb_linear(input_weight, device)
bnb_reconstruction = bnb_linear(
torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
torch.eye(embed_dim, embed_dim, dtype=dtype, device=device)
)
bnb_diff = (bnb_reconstruction.T - input_weight).abs().max()
nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max()
@@ -104,19 +114,20 @@ def test_reconstruction_qlora_vs_bnb(self):

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_nf4_bnb_linear(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_nf4_bnb_linear(self, dtype: torch.dtype):
"""
This test ensures that nf4_linear is "no worse" than BNB by ensuring the
error compared to a bf16 linear is not more than BNB's implementation.
"""
torch.manual_seed(0)
dim = 512
device = "cuda"
input_weight = _build_input_weight(dim, device)
input_weight = _build_input_weight(dim, device, dtype)
nf4_weight = to_nf4(input_weight)
bnb_linear = _build_bnb_linear(input_weight, device)

inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
inp = torch.randn(2, 512, dtype=dtype, device="cuda")

out_nf4 = linear_nf4(inp, nf4_weight).sum()
out_bnb = bnb_linear(inp).sum()
@@ -128,21 +139,23 @@ def test_nf4_bnb_linear(self):
assert err_bnb < 0.5 * dim

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
def test_load_from_bfloat16(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_state_dicts(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
inpt_tensor = torch.rand(64, device='cuda', dtype=dtype)
base_mod = self.TestMod(inpt_tensor, 32, 2)

bf16_dummy_dict = {"param": inpt_tensor}
base_mod.load_state_dict(bf16_dummy_dict)
dummy_dict = {"param": inpt_tensor}
base_mod.load_state_dict(dummy_dict)

assert base_mod.param.block_size == 32
assert base_mod.param.scaler_block_size == 2

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
def test_load_from_nf4_same_meta(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
inpt_tensor = torch.rand(64, device='cuda', dtype=dtype)
base_mod = self.TestMod(inpt_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
@@ -153,9 +166,10 @@ def test_load_from_nf4_same_meta(self):
assert other_mod.param.scaler_block_size == 2

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
def test_load_from_nf4_diff_meta(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
inpt_tensor = torch.rand(128, device='cuda', dtype=torch.bfloat16)
inpt_tensor = torch.rand(128, device='cuda', dtype=dtype)
base_mod = self.TestMod(inpt_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
@@ -165,44 +179,50 @@ def test_load_from_nf4_diff_meta(self):
assert other_mod.param.block_size == 64
assert other_mod.param.scaler_block_size == 1

def test_to_copy(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_to_copy(self, dtype: torch.dtype):
inpt_tensor = torch.rand(128, device='cpu')
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
nf4_to_dtype = inpt_tensor_nf4.to(dtype)
torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

if torch.cuda.is_available():
inpt_tensor = torch.rand(128, device='cuda')
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
nf4_to_dtype = inpt_tensor_nf4.to(dtype)
torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

def test_to_bfloat16(self):
inpt_tensor = torch.rand(128, dtype=torch.bfloat16)
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_to_dtype(self, dtype: torch.dtype):
inpt_tensor = torch.rand(128, dtype=dtype)
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
assert type(inpt_tensor_nf4) != torch.Tensor
assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16
assert type(inpt_tensor_nf4.to(dtype)) == torch.Tensor
assert inpt_tensor_nf4.to(dtype).dtype == dtype

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_smoketest_linear(self):
a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_smoketest_linear(self, dtype: torch.dtype):
a = torch.randn(32, 32, dtype=dtype, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out1 = torch.nn.functional.linear(inp, a)
out2 = torch.nn.functional.linear(inp, a_nf4)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_smoketest_linear_compile(self):
for dtype in [torch.bfloat16, torch.float16]:
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
self.skipTest("test requires SM capability of at least (8, 0).")
a = torch.randn(32, 32, dtype=dtype, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_smoketest_linear_compile(self, dtype: torch.dtype):
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
self.skipTest("test requires SM capability of at least (8, 0).")
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor")
a = torch.randn(32, 32, dtype=dtype, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)


instantiate_parametrized_tests(TestNF4Linear)

if __name__ == "__main__":
unittest.main()
run_tests()
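The refactored tests lean on PyTorch's internal parametrization helpers instead of hand-written per-dtype tests. A minimal, illustrative sketch of that pattern follows; the class and test names here are made up.

```python
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class ExampleTest(TestCase):
    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
    def test_roundtrip_dtype(self, dtype: torch.dtype):
        # One concrete test method is generated per dtype value.
        x = torch.randn(8, dtype=dtype)
        self.assertEqual(x.to(torch.float32).to(dtype).dtype, dtype)

# Expands the @parametrize decorators into concrete test methods on the class.
instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    run_tests()
```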
55 changes: 45 additions & 10 deletions torchao/dtypes/nf4tensor.py
@@ -1,3 +1,4 @@
import functools
from dataclasses import dataclass
from typing import Dict, Tuple

@@ -10,7 +11,7 @@

c10d_functional = torch.ops.c10d_functional

from typing import Any
from typing import Any, Tuple

NF4_OPS_TABLE: Dict[Any, Any] = {}

@@ -64,7 +65,7 @@ def t_default(func, *args, **kwargs):
a.size(),
(a.stride(1), a.stride(0)),
a.storage_offset(),
torch.bits2x4,
a.dtype,
a.device,
a.requires_grad,
)
@@ -187,7 +188,7 @@ def __new__(
tensor_meta.original_strides,
tensor_meta.storage_offset,
# Picked some floating dtype, but we need dtype extensibility
dtype=torch.float8_e5m2fnuz,
dtype=tensor_meta.dtype,
device=tensor_meta.device,
requires_grad=tensor_meta.requires_grad,
)
@@ -224,11 +225,9 @@ def from_tensor(
scaler_block_size: int,
):
assert inpt_tensor.dim() <= 2
assert inpt_tensor.dtype == torch.bfloat16
assert (
inpt_tensor.numel() % block_size == 0
), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16"
assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!"
# I think I want do this
# assert not inpt_tensor.requires_grad, "Input tensor must not require grad"
@@ -254,7 +253,7 @@ def from_tensor(
1.0000,
],
device=device,
dtype=torch.bfloat16,
dtype=inpt_tensor.dtype,
)
n_blocks = inpt_tensor.numel() // block_size
# Double quantization
@@ -370,7 +369,7 @@ def dequantize_scalers(
n_scaler_blocks = inpt_tensor.numel() // scaler_block_size
inpt_tensor = inpt_tensor.view(n_scaler_blocks, scaler_block_size)
dequantized = (inpt_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
torch.bfloat16
self.dtype
) + self.scaler_mean
return dequantized

@@ -551,7 +550,19 @@ def allowed_subclasses(type):

# Do not force the Float8Tensor type on the returned tensor

__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}

try:
if func in NF4_TORCH_FUNCTIONS:
return NF4_TORCH_FUNCTIONS[func](*args, **kwargs)
except NotImplementedError:
pass

with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)


class LinearNF4(torch.autograd.Function):
@@ -585,5 +596,29 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:


def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
tensor1 = tensor.to(torch.bfloat16)
return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size)
return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)


NF4_TORCH_FUNCTIONS = {}


def implements_torch_function(torch_function):
def decorator(func):
functools.update_wrapper(func, torch_function)
NF4_TORCH_FUNCTIONS[torch_function] = func
return func

return decorator


@implements_torch_function(torch.Tensor.to)
def function_to_dtype(*args, **kwargs):
if isinstance(args[0], NF4Tensor) and isinstance(args[1], torch.dtype):
# Tensor.to(dtype, non_blocking, copy, memory_format)
return args[0].get_original_weight().to(*args[1:], **kwargs)
else:
# Tensor.to(device, dtype, non_blocking, copy, memory_format)
# Tensor.to(other, non_blocking, copy)
raise NotImplementedError(
f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch"
)
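Replacing `_disabled_torch_function_impl` with an explicit `__torch_function__` plus a handler registry follows a common tensor-subclass dispatch pattern. Below is a standalone, hedged sketch of the same idea; the subclass and helper names are illustrative, not part of torchao.

```python
import functools
from typing import Any, Callable, Dict

import torch

_HANDLED_FUNCTIONS: Dict[Any, Callable] = {}

def _implements(torch_function):
    """Register a handler for a torch function or Tensor method."""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        _HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class ExampleSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        try:
            # Route registered functions to their handlers.
            if func in _HANDLED_FUNCTIONS:
                return _HANDLED_FUNCTIONS[func](*args, **kwargs)
        except NotImplementedError:
            # A handler may opt out and defer to the default behavior.
            pass
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

@_implements(torch.Tensor.to)
def _to_dtype(*args, **kwargs):
    # Only handle Tensor.to(dtype, ...); anything else falls back to the default.
    if isinstance(args[0], ExampleSubclass) and isinstance(args[1], torch.dtype):
        plain = args[0].as_subclass(torch.Tensor)  # avoid re-entering dispatch
        return plain.to(*args[1:], **kwargs)
    raise NotImplementedError

# Usage:
x = torch.randn(4).as_subclass(ExampleSubclass)
y = x.to(torch.float16)   # handled by _to_dtype
z = x.to("cpu")           # handler defers, default .to runs
assert y.dtype == torch.float16
```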
