Skip to content

Commit

Permalink
Condense formatting
Browse files — browse the repository at this point in the history
  • Loading branch information
akx committed Mar 13, 2024
1 parent fd723b7 commit aba3a96
Show file tree
Hide file tree
Showing 15 changed files with 171 additions and 508 deletions.
20 changes: 5 additions & 15 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,12 @@ def backward(ctx, grad_output):
)
if not A.is_contiguous():
A = A.contiguous()
qA, S2 = F.vectorwise_quant(
A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
igrad_B = F.igemm(qA.t(), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B, S2.t(), S1, grad_output.dtype, quant_type
)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
else:
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(
A, dim=dims, quant_type=quant_type
)
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
Expand Down Expand Up @@ -201,9 +193,7 @@ def backward(ctx, grad_output):
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(
Expand Down
142 changes: 36 additions & 106 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,16 +446,12 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
if not hasattr(lib, name):
print(name)
raise ValueError(
f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
)
raise ValueError(f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}")
else:
return getattr(lib, name)


def get_transform_buffer(
shape, dtype, device, to_order, from_order="row", transpose=False
):
def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
# init_func = torch.empty
init_func = torch.zeros
dims = len(shape)
Expand Down Expand Up @@ -508,9 +504,7 @@ def nvidia_transform(
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(
state[0], A.dtype, A.device, to_order, state[1]
)
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
else:
new_state = (state[1], to_order)
func = get_transform_func(A.dtype, from_order, to_order, transpose)
Expand Down Expand Up @@ -1018,7 +1012,7 @@ def quantize_4bit(
del absmax
state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2)
else:
state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, )
state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type)

return out, state

Expand Down Expand Up @@ -1421,9 +1415,7 @@ def optimizer_update_8bit(
ct.c_int32(g.numel()),
)
else:
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
post_call(prev_device)


Expand Down Expand Up @@ -1458,9 +1450,7 @@ def optimizer_update_8bit_blockwise(
len(str2optimizer8bit_blockwise[optimizer_name])==3):
optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
else:
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
post_call(prev_device)

is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
Expand All @@ -1487,9 +1477,7 @@ def optimizer_update_8bit_blockwise(
)
post_call(prev_device)

def percentile_clipping(
grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
):
def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
"""Applies percentile clipping
grad: torch.Tensor
Expand Down Expand Up @@ -1531,9 +1519,7 @@ def percentile_clipping(
return current_gnorm, clip_value, gnorm_scale


def histogram_scatter_add_2d(
histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
):
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
assert source.dtype == torch.float32
Expand All @@ -1553,9 +1539,7 @@ def histogram_scatter_add_2d(
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
if not torch.cuda.is_initialized(): torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
raise TypeError(
f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
)
raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}")

sA = A.shape
sB = B.shape
Expand Down Expand Up @@ -1633,9 +1617,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
sout = (sA[0], sA[1], sB[1])

if not correct:
raise ValueError(
f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
)
raise ValueError(f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.")

return sout

Expand Down Expand Up @@ -1763,9 +1745,7 @@ def igemm(
# special case
assert len(sA) == 3
if not (sA[0] == sB[0] and sA[1] == sB[1]):
raise ValueError(
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
)
raise ValueError(f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}")

transposed_A = True
transposed_B = False
Expand Down Expand Up @@ -1796,9 +1776,7 @@ def batched_igemm(
transposed_B=False,
):
if not len(A.shape) == 3 or not len(B.shape) == 3:
raise ValueError(
f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
)
raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}")
sout = check_matmul(A, B, out, transposed_A, transposed_B)
if out is None:
out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
Expand Down Expand Up @@ -1892,13 +1870,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

if dimsA == 2 and out is None:
out, Sout = get_transform_buffer(
(shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
)
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row")
elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer(
(shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
)
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")

assert dimsB != 3, "len(B.shape)==3 not supported"
assert A.device.type == "cuda"
Expand Down Expand Up @@ -1942,22 +1916,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
is_on_gpu([A, B, out])
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
has_error = lib.cigemmlt_turing_8(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif formatB == "col_ampere":
if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
has_error = lib.cigemmlt_ampere_8(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)

if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
Expand Down Expand Up @@ -1990,19 +1956,11 @@ def mm_dequant(
if out is None:
out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
if new_row_stats is None:
new_row_stats = torch.empty(
out_shape[0], dtype=torch.float32, device=A.device
)
new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
if new_col_stats is None:
new_col_stats = torch.empty(
out_shape[1], dtype=torch.float32, device=A.device
)
assert (
new_row_stats.shape[0] == row_stats.shape[0]
), f"{new_row_stats.shape} vs {row_stats.shape}"
assert (
new_col_stats.shape[0] == col_stats.shape[0]
), f"{new_col_stats.shape} vs {col_stats.shape}"
new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"

prev_device = pre_call(A.device)
ptrA = get_ptr(A)
Expand All @@ -2022,9 +1980,7 @@ def mm_dequant(
return out


def get_colrow_absmax(
A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
):
def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
assert A.dtype == torch.float16
device = A.device

Expand All @@ -2037,18 +1993,12 @@ def get_colrow_absmax(
col_tiles = (cols + 255) // 256
tiled_rows = ((rows + 15) // 16) * 16
if row_stats is None:
row_stats = torch.empty(
(rows,), dtype=torch.float32, device=device
).fill_(-50000.0)
row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
if col_stats is None:
col_stats = torch.empty(
(cols,), dtype=torch.float32, device=device
).fill_(-50000.0)
col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)

if nnz_block_ptr is None and threshold > 0.0:
nnz_block_ptr = torch.zeros(
((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
)
nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device)

ptrA = get_ptr(A)
ptrRowStats = get_ptr(row_stats)
Expand Down Expand Up @@ -2122,14 +2072,10 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values):
# coo2csr: build a CSRSparseTensor from the COO tensor `cooA` by deriving a
# row-pointer vector from cooA.rowidx; colidx and values are passed through unchanged.
# NOTE(review): this span is a rendered diff hunk. The three-line forms of
# `rowptr = torch.zeros(` and `return CSRSparseTensor(` are the pre-commit text;
# the single-line forms directly after them are the post-commit text of the same statements.
def coo2csr(cooA):
# Distinct row indices found in cooA.rowidx, plus the number of entries for each.
values, counts = torch.unique(cooA.rowidx, return_counts=True)
# Shift the indices by one in place so each count is written at slot (row + 1).
values.add_(1)
rowptr = torch.zeros(
(cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
)
rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
# Write the per-row counts into the zero-initialised pointer vector.
rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
# In-place running sum converts the per-row counts into cumulative offsets.
rowptr.cumsum_(0)
return CSRSparseTensor(
cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
)
return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)


def coo2csc(cooA):
Expand All @@ -2138,14 +2084,10 @@ def coo2csc(cooA):
values = cooA.values[col2rowidx]
colvalues, counts = torch.unique(val, return_counts=True)
colvalues.add_(1)
colptr = torch.zeros(
(cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
)
colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
colptr.cumsum_(0)
return CSCSparseTensor(
cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
)
return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)


def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
Expand All @@ -2170,9 +2112,7 @@ def double_quant(
rows = A.shape[0]

if row_stats is None or col_stats is None:
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
A, threshold=threshold
)
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)

if out_col is None:
out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
Expand All @@ -2190,9 +2130,7 @@ def double_quant(
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
coo_tensor = coo_zeros(
A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device
)
coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
ptrRowIdx = get_ptr(coo_tensor.rowidx)
ptrColIdx = get_ptr(coo_tensor.colidx)
ptrVal = get_ptr(coo_tensor.values)
Expand Down Expand Up @@ -2297,9 +2235,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No

def spmm_coo(cooA, B, out=None):
if out is None:
out = torch.empty(
(cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
)
out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
Expand Down Expand Up @@ -2333,9 +2269,7 @@ def spmm_coo(cooA, B, out=None):

def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
if out is None:
out = torch.zeros(
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
)
out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
nnz = cooA.nnz
prev_device = pre_call(B.device)
assert cooA.rowidx.numel() == nnz
Expand Down Expand Up @@ -2443,9 +2377,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
dtype = x.dtype
x = x.float()
dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
x, dim=dim, keepdim=True
)
dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
dyna[dyna == 0] = 1
qx = 255.0 / dyna
minx = torch.amin(x, dim=dim, keepdim=True)
Expand Down Expand Up @@ -2553,9 +2485,7 @@ def extract_outliers(A, SA, idx):
assert formatA in ["col_turing", "col_ampere"]
assert A.device.type == "cuda"

out = torch.zeros(
(shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
)
out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)

idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0])
Expand Down
Loading

0 comments on commit aba3a96

Please sign in to comment.