Skip to content

Commit

Permalink
Merge pull request bitsandbytes-foundation#12 from ROCm/igemm_workspace
Browse files: browse the repository at this point in the history
Remove workspace allocation in igemmlt (hipblasLtMatmul is now called with a null workspace of size 0)
  • Loading branch information
pnunna93 authored Mar 12, 2024
2 parents 1b6dd48 + 7077195 commit fc9bf4d
Showing 1 changed file with 11 additions and 22 deletions.
33 changes: 11 additions & 22 deletions csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
else
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));

//Set User Preference attributes
int64_t max_workspace_size = 32 * 1024 * 1024 * 4;
void* d_workspace;
//NEED HIP CHECK ERROR
//hipMalloc(&d_workspace, max_workspace_size);
const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel

if(DTYPE_OUT == 32)
{
Expand Down Expand Up @@ -580,17 +576,14 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
heuristicResult,
&returnedAlgoCount));

auto toMalloc = max(heuristicResult[0].workspaceSize, max_workspace_size);

//printf("\n\n1Got algosn: %d %d %d\n\n",returnedAlgoCount, heuristicResult[0].workspaceSize, toMalloc);
//NEED HIP CHECK ERROR
auto err = hipMalloc(&d_workspace, toMalloc);
//printf("Hipmalloc\n");
//printf(hipError_to_string(err).c_str());
//printf("\n");
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, toMalloc, 0));
//hipStreamSynchronize(0);
hipFree(d_workspace);
if (returnedAlgoCount == 0)
{
has_error = 1;
}
else
{
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
}
}
else
{
Expand Down Expand Up @@ -622,23 +615,19 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
heuristicResult,
&returnedAlgoCount));

//NEED HIP CHECK ERROR
hipMalloc(&d_workspace, heuristicResult[0].workspaceSize);
if(!SCALE_ROWS)
{
float alpha = 1.0f, beta = 0.0f;

has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, max_workspace_size, 0));
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
}
else
{
//has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, hipblasLt_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
float beta = 0.0f;

has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, max_workspace_size, 0));
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
}

hipFree(d_workspace);
}


Expand Down

0 comments on commit fc9bf4d

Please sign in to comment.