Skip to content

Commit

Permalink
apply code-format changes
Browse files Browse the repository at this point in the history
  • Loading branch information
MollySophia authored and github-actions[bot] committed Sep 12, 2024
1 parent 48054cd commit 82e9171
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 49 deletions.
84 changes: 45 additions & 39 deletions src/layer/arm/amx_usability.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,41 @@
#define AMX_USABILITY_H

// From https:/corsix/amx/blob/main/aarch64.h
#define AMX_NOP_OP_IMM5(op, imm5) \
__asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" : : "i"(op), "i"(imm5) : "memory")

#define AMX_OP_GPR(op, gpr) \
__asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" : : "i"(op), "r"((uint64_t)(gpr)) : "memory")

#define AMX_LDX(gpr) AMX_OP_GPR( 0, gpr)
#define AMX_LDY(gpr) AMX_OP_GPR( 1, gpr)
#define AMX_STX(gpr) AMX_OP_GPR( 2, gpr)
#define AMX_STY(gpr) AMX_OP_GPR( 3, gpr)
#define AMX_LDZ(gpr) AMX_OP_GPR( 4, gpr)
#define AMX_STZ(gpr) AMX_OP_GPR( 5, gpr)
#define AMX_LDZI(gpr) AMX_OP_GPR( 6, gpr)
#define AMX_STZI(gpr) AMX_OP_GPR( 7, gpr)
#define AMX_EXTRX(gpr) AMX_OP_GPR( 8, gpr)
#define AMX_EXTRY(gpr) AMX_OP_GPR( 9, gpr)
#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr)
#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr)
#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr)
#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr)
#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr)
#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr)
#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr)
#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr)
#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr)
#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr)
#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr)
#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr)
#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t)&*(ptr)) + (((uint64_t)((row) + (flags) * 64)) << 56))
#define AMX_NOP_OP_IMM5(op, imm5) \
__asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" \
: \
: "i"(op), "i"(imm5) \
: "memory")

#define AMX_OP_GPR(op, gpr) \
__asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" \
: \
: "i"(op), "r"((uint64_t)(gpr)) \
: "memory")

#define AMX_LDX(gpr) AMX_OP_GPR(0, gpr)
#define AMX_LDY(gpr) AMX_OP_GPR(1, gpr)
#define AMX_STX(gpr) AMX_OP_GPR(2, gpr)
#define AMX_STY(gpr) AMX_OP_GPR(3, gpr)
#define AMX_LDZ(gpr) AMX_OP_GPR(4, gpr)
#define AMX_STZ(gpr) AMX_OP_GPR(5, gpr)
#define AMX_LDZI(gpr) AMX_OP_GPR(6, gpr)
#define AMX_STZI(gpr) AMX_OP_GPR(7, gpr)
#define AMX_EXTRX(gpr) AMX_OP_GPR(8, gpr)
#define AMX_EXTRY(gpr) AMX_OP_GPR(9, gpr)
#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr)
#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr)
#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr)
#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr)
#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr)
#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr)
#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr)
#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr)
#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr)
#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr)
#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr)
#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr)
#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t) & *(ptr)) + (((uint64_t)((row) + (flags)*64)) << 56))
void amx_set()
{
AMX_NOP_OP_IMM5(17, 0);
Expand All @@ -55,51 +61,51 @@ void amx_clr()
AMX_NOP_OP_IMM5(17, 1);
}

void amx_ldx(bool pair, unsigned int x_row, const void * ptr)
void amx_ldx(bool pair, unsigned int x_row, const void* ptr)
{
if (x_row >= 8)
return;

uint64_t oprand = (uint64_t)ptr + ((uint64_t)x_row << 56);
if (pair)
oprand |= 1ULL << 62;

AMX_LDX(oprand);
}

void amx_ldy(bool pair, unsigned int y_row, const void * ptr)
void amx_ldy(bool pair, unsigned int y_row, const void* ptr)
{
if (y_row >= 8)
return;

uint64_t oprand = (uint64_t)ptr + ((uint64_t)y_row << 56);
if (pair)
oprand |= 1ULL << 62;

AMX_LDY(oprand);
}

void amx_ldz(bool pair, unsigned int z_row, const void * ptr)
void amx_ldz(bool pair, unsigned int z_row, const void* ptr)
{
if (z_row >= 64)
return;

uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56);
if (pair)
oprand |= 1ULL << 62;

AMX_LDZ(oprand);
}

void amx_stz(bool pair, unsigned int z_row, const void * ptr)
void amx_stz(bool pair, unsigned int z_row, const void* ptr)
{
if (z_row >= 64)
return;

uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56);
if (pair)
oprand |= 1ULL << 62;

AMX_STZ(oprand);
}

Expand All @@ -116,7 +122,7 @@ void amx_fma16_masked(bool vector, unsigned int x_offset, unsigned int y_offset,
oprand |= ((uint64_t)y_mode & 0x3) << 37;
oprand |= ((uint64_t)x_mask & 0x1F) << 41;
oprand |= ((uint64_t)x_mode & 0x3) << 46;

AMX_FMA16(oprand);
}

Expand All @@ -138,7 +144,7 @@ void amx_fma32_masked(bool vector, unsigned int x_offset, unsigned int y_offset,
oprand |= ((uint64_t)y_mode & 0x3) << 37;
oprand |= ((uint64_t)x_mask & 0x1F) << 41;
oprand |= ((uint64_t)x_mode & 0x3) << 46;

AMX_FMA32(oprand);
}

Expand Down
23 changes: 13 additions & 10 deletions src/layer/arm/convolution_im2col_gemm_fp16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -3056,20 +3056,20 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
if (pC)
{
for (int r = 0; r < 12; r++)
amx_ldz(false, 2*r, pC);
amx_ldz(false, 2 * r, pC);
}
else
{
__fp16 sums[16];
memset(sums, 0, 16 * sizeof(__fp16));
for (int r = 0; r < 12; r++)
amx_ldz(false, 2*r, sums);
amx_ldz(false, 2 * r, sums);
}
}
else
{
for (int r = 0; r < 12; r++)
amx_ldz(false, 2*r, outptr + 8 * r);
amx_ldz(false, 2 * r, outptr + 8 * r);
}

int kk = 0;
Expand All @@ -3088,17 +3088,19 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
if (out_elempack == 8)
{
__fp16 tmp[96 + 24];
for (int r = 0; r < 12; r++) {
amx_stz(false, 2*r, tmp + r * 8);
for (int r = 0; r < 12; r++)
{
amx_stz(false, 2 * r, tmp + r * 8);
}
memcpy(outptr0, tmp, 96 * sizeof(__fp16));
outptr0 += 96;
}
if (out_elempack == 4)
{
__fp16 tmp[32];
for (int r = 0; r < 12; r++) {
amx_stz(false, 2*r, tmp);
for (int r = 0; r < 12; r++)
{
amx_stz(false, 2 * r, tmp);
float16x8_t _tmp = vld1q_f16(tmp);
vst1_f16(outptr0 + 4 * r, vget_low_f16(_tmp));
vst1_f16(outptr0 + out_hstep * 4 + 4 * r, vget_high_f16(_tmp));
Expand Down Expand Up @@ -3167,8 +3169,9 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
else
{
__fp16 tmp[32];
for (int r = 0; r < 12; r++) {
amx_stz(false, 2*r, tmp);
for (int r = 0; r < 12; r++)
{
amx_stz(false, 2 * r, tmp);
memcpy(outptr0 + 8 * r, tmp, 8 * sizeof(__fp16));
}
}
Expand Down Expand Up @@ -4915,7 +4918,7 @@ static int convolution_im2col_gemm_fp16sa(const Mat& bottom_blob, Mat& top_blob,
bool k_end = k + TILE_K >= K;

#if __aarch64__ && NCNN_APPLE_AMX
// #if 0
// #if 0
if (amx_supported)
{
convolution_gemm_transB_packed_tile_fp16sa_amx(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end);
Expand Down

0 comments on commit 82e9171

Please sign in to comment.