Skip to content

Commit

Permalink
Merge pull request bitsandbytes-foundation#14 from ROCm/fix_gemv_4bit
Browse files Browse the repository at this point in the history
Improve gemv 4-bit accuracy by forcing the hipcub::WarpReduce logical warp size to 32
  • Loading branch information
Lzy17 authored Mar 19, 2024
2 parents f30dc38 + 3dc14e8 commit f4ac9ac
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
4 changes: 2 additions & 2 deletions csrc/kernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index
template<typename T, int BLOCK_SIZE, int NUM_MAX>
__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
{
- typedef hipcub::WarpReduce<T> WarpReduce;
+ typedef hipcub::WarpReduce<T, 32> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/8 , 8, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename LoadT::TempStorage loadt;
Expand Down Expand Up @@ -3553,7 +3553,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1x32 * 32x4 -> 1x4 outputs per thread block
- typedef hipcub::WarpReduce<float> WarpReduce;
+ typedef hipcub::WarpReduce<float, 32> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];

const int warp_idx = threadIdx.x / 32;
Expand Down
1 change: 0 additions & 1 deletion tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2543,7 +2543,6 @@ def test_managed():
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True'])
- @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
Expand Down

0 comments on commit f4ac9ac

Please sign in to comment.