Skip to content

Commit

Permalink
[Kernel] Bug fix for small_gemm_transb (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
pujiang2018 authored Apr 16, 2024
1 parent 5349b3b commit 15451f2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/kernels/gemm_kernel_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void small_gemm_transb_1xn_dynk(const TA *A, const TB *B, float *C, int N, int K
// Each loop compute 'BC' elements in C
int i = 0;
for (; i + BC - 1 < N; i += BC) {
const TA *pA = A + i * ldb;
const TA *pA = A;
const TB *pB = B + i * ldb;

__m512 vc[BC];
Expand All @@ -356,7 +356,7 @@ void small_gemm_transb_1xn_dynk(const TA *A, const TB *B, float *C, int N, int K

// Remain elements
for (; i < N; ++i) {
const TA *pA = A + i * ldb;
const TA *pA = A;
const TB *pB = B + i * ldb;
__m512 vc = _mm512_set1_ps(0);

Expand Down
24 changes: 16 additions & 8 deletions tests/ut/gemm_kernel_ext_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@

#include "gtest/gtest.h"

template <typename TA, typename TB, typename TC>
static void small_gemm_tranb_ref(
const float *A, const float *B, float *C, int M, int N, int K, int lda, int ldb, int ldc) {
const TA *A, const TB *B, TC *C, int M, int N, int K, int lda, int ldb, int ldc) {
// Loop over the rows of A
for (int i = 0; i < M; i++) {
// Loop over the columns of B
for (int j = 0; j < N; j++) {
// Compute the dot product of row i of A with column j of B
float dot_product = 0;
for (int k = 0; k < K; k++) {
dot_product += A[i * lda + k] * B[j * ldb + k];
dot_product += (float)A[i * lda + k] * (float)B[j * ldb + k];
}
// Store the result in C[i][j]
C[i * ldc + j] = dot_product;
Expand All @@ -54,13 +55,14 @@ static void small_gemm_tranb_ref(
}

// Test function to compare reference and optimized implementations
template <typename TA = float, typename TB = float, typename TC = float>
void test_small_gemm_tranb(int M, int N, int K) {
float *A_ref = new float[M * K];
float *B_ref = new float[K * N];
float *C_ref = new float[M * N];
float *A_opt = new float[M * K];
float *B_opt = new float[K * N];
float *C_opt = new float[M * N];
TA *A_ref = new TA[M * K];
TB *B_ref = new TB[K * N];
TC *C_ref = new TC[M * N];
TA *A_opt = new TA[M * K];
TB *B_opt = new TB[K * N];
TC *C_opt = new TC[M * N];

// Generate random matrices A and B
std::random_device dev;
Expand Down Expand Up @@ -262,6 +264,12 @@ TEST(small_gemm_tranb, small_gemm_tranb_f32) {
test_bigger_kernel();
}

TEST(small_gemm_tranb, small_gemm_tranb_bf16fp16f32) {
test_small_gemm_tranb<bfloat16_t, float16_t, float>(1, 2, 16);
test_small_gemm_tranb<bfloat16_t, float16_t, float>(1, 4, 128);
test_small_gemm_tranb<bfloat16_t, float16_t, float>(1, 4, 256);
}

TEST(small_gemm_tranb, small_gemm_tranb_int8) {
test_small_gemm_tranb_int8(1, 100, 128);
test_small_gemm_tranb_int8(2, 101, 256);
Expand Down

0 comments on commit 15451f2

Please sign in to comment.