Skip to content

Commit

Permalink
Merge pull request bitsandbytes-foundation#8 from ROCm/enable_transform_with_transpose
Browse files Browse the repository at this point in the history

Enable transform with transpose
  • Loading branch information
pnunna93 authored Feb 23, 2024
2 parents 2b77380 + 8c3476f commit 386e16c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 21 deletions.
8 changes: 2 additions & 6 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,7 @@ def get_transform_buffer(
state = (shape[::-1], to_order)

if to_order == "row" or to_order == "col":
if HIP_ENVIRONMENT and to_order == "col":
# row to col transformation transposes output shape, so change buffer allocation accordingly
return init_func(shape[::-1], dtype=dtype, device=device), state
else:
return init_func(shape, dtype=dtype, device=device), state
return init_func(shape, dtype=dtype, device=device), state
elif to_order == "col32":
# blocks of 32 columns (padded)
cols = 32 * ((cols + 31) // 32)
Expand Down Expand Up @@ -503,7 +499,7 @@ def nvidia_transform(
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(
state[0], A.dtype, A.device, to_order, state[1]
state[0], A.dtype, A.device, to_order, state[1], transpose
)
else:
new_state = (state[1], to_order)
Expand Down
22 changes: 19 additions & 3 deletions csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,12 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
hipblasLtOrder_t orderA = get_order<SRC>();
hipblasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
int ldOut;
if (TARGET==COL && transpose) {
ldOut = dim2;
} else {
ldOut = get_leading_dim<TARGET>(dim1, dim2);
}

hipblasLtMatrixLayout_t A_desc = NULL, out_desc = NULL, B_desc = NULL;
T B = T(0);
Expand All @@ -395,13 +400,21 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
{
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_8I, dim1, dim2, ldA));
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_8I, 0, 0, 0));
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim1, dim2, ldOut));
if (TARGET==COL && transpose) {
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim2, dim1, ldOut));
} else {
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim1, dim2, ldOut));
}
}
else if(DTYPE == 32)
{
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_32I, dim1, dim2, ldA));
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_32I, 0, 0, 0));
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim1, dim2, ldOut));
if (TARGET==COL && transpose) {
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim2, dim1, ldOut));
} else {
checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim1, dim2, ldOut));
}
}
else
{
Expand All @@ -424,6 +437,9 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
}

template void transform<int8_t, ROW, COL, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL, true, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL, false, 32>(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL, true, 32>(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, ROW, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL32, false, 8>(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL32, false, 32>(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
Expand Down
6 changes: 6 additions & 0 deletions csrc/pythonInterface.c
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(hipblasLtHandle_t lt
#endif

MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col, t, int8_t, ROW, COL, true, 8);
MAKE_FUNC_TRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32);
MAKE_FUNC_TRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32);
MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
Expand Down Expand Up @@ -406,6 +409,9 @@ extern "C"


MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col, t, int8_t, ROW, COL, true, 8)
MAKE_FUNC_CTRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32)
MAKE_FUNC_CTRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32)
MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
Expand Down
21 changes: 9 additions & 12 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,19 +719,16 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_close(C1, C3.float())

# Since ROCm supports row to col transformation only which is same as transpose,
# skipping this for HIP environment
if not HIP_ENVIRONMENT:
## transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
torch.int8
)
C1 = torch.matmul(A.float(), B.float())
## transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
torch.int8
)
C1 = torch.matmul(A.float(), B.float())

B2t, SBt = F.transform(B, "col_turing", transpose=True)
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_close(C1, C3.float())
B2t, SBt = F.transform(B, "col_turing", transpose=True)
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_close(C1, C3.float())


dim1 = [32]
Expand Down

0 comments on commit 386e16c

Please sign in to comment.