From 707719568b194530f45795711a1122a8dcfff9b5 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Tue, 12 Mar 2024 02:02:34 +0000 Subject: [PATCH] remove workspace in igemmlt --- csrc/ops.hip | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/csrc/ops.hip b/csrc/ops.hip index 27e479573..2693ffa63 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -536,11 +536,7 @@ template int igemmlt(hipblasLtHandl else has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); - //Set User Preference attributes - int64_t max_workspace_size = 32 * 1024 * 1024 * 4; - void* d_workspace; - //NEED HIP CHECK ERROR - //hipMalloc(&d_workspace, max_workspace_size); + const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel if(DTYPE_OUT == 32) { @@ -580,17 +576,14 @@ template int igemmlt(hipblasLtHandl heuristicResult, &returnedAlgoCount)); - auto toMalloc = max(heuristicResult[0].workspaceSize, max_workspace_size); - - //printf("\n\n1Got algosn: %d %d %d\n\n",returnedAlgoCount, heuristicResult[0].workspaceSize, toMalloc); - //NEED HIP CHECK ERROR - auto err = hipMalloc(&d_workspace, toMalloc); - //printf("Hipmalloc\n"); - //printf(hipError_to_string(err).c_str()); - //printf("\n"); - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, toMalloc, 0)); -//hipStreamSynchronize(0); - hipFree(d_workspace); + if (returnedAlgoCount == 0) + { + has_error = 1; + } + else + { + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + } } else { @@ -622,23 +615,19 @@ template int igemmlt(hipblasLtHandl heuristicResult, &returnedAlgoCount)); - //NEED HIP CHECK ERROR - hipMalloc(&d_workspace, heuristicResult[0].workspaceSize); if(!SCALE_ROWS) { float alpha = 1.0f, beta = 0.0f; - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, max_workspace_size, 0)); + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); } else { //has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, hipblasLt_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); float beta = 0.0f; - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, max_workspace_size, 0)); + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); } - - hipFree(d_workspace); }