Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP2][1/n] construct NF4Tensor from bf16/fp16/fp32 #118

Merged
merged 22 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 68 additions & 48 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import logging
import unittest
from packaging import version

import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
import torch.nn.functional as F
import io
from collections import OrderedDict
import torchao


bnb_available = False

try:
Expand All @@ -24,10 +31,10 @@
)


def _build_input_weight(embed_dim: int, device: torch.device):
def _build_input_weight(embed_dim: int, device: torch.device, dtype: torch.dtype):
torch.manual_seed(0)
input_weight = torch.empty(
embed_dim, embed_dim, device=device, dtype=torch.bfloat16
embed_dim, embed_dim, device=device, dtype=dtype
)
input_weight.normal_(0, 1)
return input_weight
Expand All @@ -44,7 +51,6 @@ def _build_bnb_linear(input_weight, device):
bnb_linear.to(device)
return bnb_linear

weifengpy marked this conversation as resolved.
Show resolved Hide resolved

class TestNF4Linear(TestCase):
class TestMod(nn.Module):
def __init__(self, tensor, block_size, scaler_block_size):
Expand All @@ -57,42 +63,46 @@ def save_state_dict_to_buffer(self, state_dict: OrderedDict):
buffer.seek(0)
return buffer

def test_register_nf4_as_param(self):
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_register_nf4_as_param(self, dtype: torch.dtype):
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))

# Would raise if nn.Parameter registration fails, such as no detach()
# impl when calling __torch_dispatch__
param = torch.nn.Parameter(nf4_tensor, requires_grad=False)
assert not param.requires_grad

def test_output_bf16(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_output_dtype_match(self, dtype:torch.dtype):
# Test to ensure W4 A16 produces A16
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))
out = linear_nf4(input=inp, weight=nf4_tensor)
assert out.dtype == torch.bfloat16
assert out.dtype == dtype

def test_backward_bf16(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_backward_dtype_match(self, dtype:torch.dtype):
# Test to ensure backward pass gives activation a gradient of the input dtype and no gradient
# to the linear's weight, as it is frozen.
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=dtype))
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
linear_nf4(inp, nf4_tensor).sum().backward()
assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
assert inp.grad is not None and inp.grad.dtype == dtype
assert nf4_tensor.grad is None

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_reconstruction_qlora_vs_bnb(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
# From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
torch.manual_seed(0)
device = "cuda"
embed_dim = 512
input_weight = _build_input_weight(embed_dim, device)
input_weight = _build_input_weight(embed_dim, device, dtype)
nf4_weight = to_nf4(input_weight)
bnb_linear = _build_bnb_linear(input_weight, device)
bnb_reconstruction = bnb_linear(
torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
torch.eye(embed_dim, embed_dim, dtype=dtype, device=device)
)
bnb_diff = (bnb_reconstruction.T - input_weight).abs().max()
nugs_diff = (nf4_weight.get_original_weight() - input_weight).abs().max()
Expand All @@ -104,19 +114,20 @@ def test_reconstruction_qlora_vs_bnb(self):

@unittest.skipIf(not bnb_available, "Need bnb availble")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_nf4_bnb_linear(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_nf4_bnb_linear(self, dtype: torch.dtype):
"""
This test ensures that nf4_linear is "no worse" than BNB by ensuring the
error compared to a bf16 linear is not more than BNB's implementation.
"""
torch.manual_seed(0)
dim = 512
device = "cuda"
input_weight = _build_input_weight(dim, device)
input_weight = _build_input_weight(dim, device, dtype)
nf4_weight = to_nf4(input_weight)
bnb_linear = _build_bnb_linear(input_weight, device)

inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
inp = torch.randn(2, 512, dtype=dtype, device="cuda")

out_nf4 = linear_nf4(inp, nf4_weight).sum()
out_bnb = bnb_linear(inp).sum()
Expand All @@ -128,21 +139,23 @@ def test_nf4_bnb_linear(self):
assert err_bnb < 0.5 * dim

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
def test_load_from_bfloat16(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_state_dicts(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
inpt_tensor = torch.rand(64, device='cuda', dtype=dtype)
base_mod = self.TestMod(inpt_tensor, 32, 2)

bf16_dummy_dict = {"param": inpt_tensor}
base_mod.load_state_dict(bf16_dummy_dict)
dummy_dict = {"param": inpt_tensor}
base_mod.load_state_dict(dummy_dict)

assert base_mod.param.block_size == 32
assert base_mod.param.scaler_block_size == 2

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
def test_load_from_nf4_same_meta(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
inpt_tensor = torch.rand(64, device='cuda', dtype=dtype)
base_mod = self.TestMod(inpt_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
Expand All @@ -153,9 +166,10 @@ def test_load_from_nf4_same_meta(self):
assert other_mod.param.scaler_block_size == 2

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
def test_load_from_nf4_diff_meta(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
inpt_tensor = torch.rand(128, device='cuda', dtype=torch.bfloat16)
inpt_tensor = torch.rand(128, device='cuda', dtype=dtype)
base_mod = self.TestMod(inpt_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
Expand All @@ -165,44 +179,50 @@ def test_load_from_nf4_diff_meta(self):
assert other_mod.param.block_size == 64
assert other_mod.param.scaler_block_size == 1

def test_to_copy(self):
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_to_copy(self, dtype: torch.dtype):
inpt_tensor = torch.rand(128, device='cpu')
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
nf4_to_dtype = inpt_tensor_nf4.to(dtype)
torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

if torch.cuda.is_available():
inpt_tensor = torch.rand(128, device='cuda')
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
nf4_to_dtype = inpt_tensor_nf4.to(dtype)
torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

def test_to_bfloat16(self):
inpt_tensor = torch.rand(128, dtype=torch.bfloat16)
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_to_dtype(self, dtype: torch.dtype):
inpt_tensor = torch.rand(128, dtype=dtype)
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
assert type(inpt_tensor_nf4) != torch.Tensor
assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16
assert type(inpt_tensor_nf4.to(dtype)) == torch.Tensor
assert inpt_tensor_nf4.to(dtype).dtype == dtype

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_smoketest_linear(self):
a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_smoketest_linear(self, dtype: torch.dtype):
a = torch.randn(32, 32, dtype=dtype, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out1 = torch.nn.functional.linear(inp, a)
out2 = torch.nn.functional.linear(inp, a_nf4)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_smoketest_linear_compile(self):
for dtype in [torch.bfloat16, torch.float16]:
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
self.skipTest("test requires SM capability of at least (8, 0).")
Copy link
Contributor Author

@weifengpy weifengpy Apr 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_smoketest_linear_compile was always skipped before, because the loop exited via self.skipTest on torch.bfloat16 and never got a chance to test torch.float16. This version fixes that by using @parametrize, so each dtype runs (or is skipped) as its own test.

a = torch.randn(32, 32, dtype=dtype, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_smoketest_linear_compile(self, dtype: torch.dtype):
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
self.skipTest("test requires SM capability of at least (8, 0).")
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor")
Copy link
Contributor Author

@weifengpy weifengpy Apr 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Starting from 2.3.0, we can trace a tensor subclass whose inner tensors have different shapes than the outer wrapper class. Specifically, we use symbolic_context.inner_contexts instead of the symbolic_context from the outer wrapper class: https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/meta_utils.py#L649

a = torch.randn(32, 32, dtype=dtype, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)


instantiate_parametrized_tests(TestNF4Linear)

if __name__ == "__main__":
unittest.main()
run_tests()
55 changes: 45 additions & 10 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from dataclasses import dataclass
from typing import Dict, Tuple

Expand All @@ -10,7 +11,7 @@

c10d_functional = torch.ops.c10d_functional

from typing import Any
from typing import Any, Tuple

NF4_OPS_TABLE: Dict[Any, Any] = {}

Expand Down Expand Up @@ -64,7 +65,7 @@ def t_default(func, *args, **kwargs):
a.size(),
(a.stride(1), a.stride(0)),
a.storage_offset(),
torch.bits2x4,
a.dtype,
a.device,
a.requires_grad,
)
Expand Down Expand Up @@ -187,7 +188,7 @@ def __new__(
tensor_meta.original_strides,
tensor_meta.storage_offset,
# Picked some floating dtype, but we need dtype extensibility
dtype=torch.float8_e5m2fnuz,
dtype=tensor_meta.dtype,
device=tensor_meta.device,
requires_grad=tensor_meta.requires_grad,
)
Expand Down Expand Up @@ -224,11 +225,9 @@ def from_tensor(
scaler_block_size: int,
):
assert inpt_tensor.dim() <= 2
assert inpt_tensor.dtype == torch.bfloat16
assert (
inpt_tensor.numel() % block_size == 0
), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
assert inpt_tensor.dtype == torch.bfloat16, "Input tensor must be bfloat16"
assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!"
# I think I want do this
# assert not inpt_tensor.requires_grad, "Input tensor must not require grad"
Expand All @@ -254,7 +253,7 @@ def from_tensor(
1.0000,
],
device=device,
dtype=torch.bfloat16,
dtype=inpt_tensor.dtype,
)
n_blocks = inpt_tensor.numel() // block_size
# Double quantization
Expand Down Expand Up @@ -370,7 +369,7 @@ def dequantize_scalers(
n_scaler_blocks = inpt_tensor.numel() // scaler_block_size
inpt_tensor = inpt_tensor.view(n_scaler_blocks, scaler_block_size)
dequantized = (inpt_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
torch.bfloat16
self.dtype
) + self.scaler_mean
return dequantized

Expand Down Expand Up @@ -551,7 +550,19 @@ def allowed_subclasses(type):

# Do not force the NF4Tensor type on the returned tensor

__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Run a registered NF4 override for ``func`` if there is one, else call ``func`` directly.

    Args:
        func: The torch-level callable being invoked.
        types: Unused here; part of the ``__torch_function__`` signature.
        args: Positional arguments of the original call.
        kwargs: Keyword arguments of the original call; ``None`` means none.

    Returns:
        The override's result when ``func`` has an entry in
        ``NF4_TORCH_FUNCTIONS`` and that entry does not raise
        ``NotImplementedError``; otherwise the result of calling ``func``
        inside ``torch._C.DisableTorchFunctionSubclass()``.
    """
    call_kwargs = {} if kwargs is None else kwargs

    override = NF4_TORCH_FUNCTIONS.get(func)
    if override is not None:
        try:
            return override(*args, **call_kwargs)
        except NotImplementedError:
            # The override declined this call form; use the default path below.
            pass

    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **call_kwargs)


class LinearNF4(torch.autograd.Function):
Expand Down Expand Up @@ -585,5 +596,29 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:


def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
tensor1 = tensor.to(torch.bfloat16)
return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size)
return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)


# Registry: torch-level callable -> the NF4-specific function that overrides it.
NF4_TORCH_FUNCTIONS = {}


def implements_torch_function(torch_function):
    """Build a decorator that registers a function as the NF4 override of ``torch_function``.

    The decorated function is stored in ``NF4_TORCH_FUNCTIONS`` under
    ``torch_function`` and takes on that callable's wrapper metadata
    (name, docstring, ``__wrapped__``, ...). The same function object is
    returned, so it stays usable under its own module-level name.
    """

    def register(override):
        # Copy metadata from the overridden callable onto the override in place.
        functools.wraps(torch_function)(override)
        NF4_TORCH_FUNCTIONS[torch_function] = override
        return override

    return register


@implements_torch_function(torch.Tensor.to)
def function_to_dtype(*args, **kwargs):
    """Override of ``Tensor.to`` for the ``nf4_tensor.to(dtype, ...)`` call form.

    Args:
        *args: Positional arguments of the ``Tensor.to`` call; ``args[0]`` is
            the tensor the method was called on.
        **kwargs: Keyword arguments of the ``Tensor.to`` call, forwarded as-is.

    Returns:
        ``args[0].get_original_weight()`` converted with the remaining
        positional and keyword arguments, when ``args[0]`` is an NF4Tensor and
        the first argument after it is a ``torch.dtype``.

    Raises:
        NotImplementedError: For every other call form, including calls that
            pass no positional argument after the tensor (for example
            ``t.to(dtype=...)``). ``NF4Tensor.__torch_function__`` catches this
            and falls back to the regular call path.
    """
    # Check the length first: keyword-only calls have no args[1], and an
    # IndexError here would escape the NotImplementedError fallback.
    if (
        len(args) > 1
        and isinstance(args[0], NF4Tensor)
        and isinstance(args[1], torch.dtype)
    ):
        # Tensor.to(dtype, non_blocking, copy, memory_format)
        return args[0].get_original_weight().to(*args[1:], **kwargs)

    # Tensor.to(device, dtype, non_blocking, copy, memory_format)
    # Tensor.to(other, non_blocking, copy)
    # Tensor.to(dtype=..., device=..., ...) with keyword arguments only
    raise NotImplementedError(
        f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch"
    )
Loading