INTQ (8-bit) Asymmetric Quantization Support
Added Asymmetric UINT8 quantized convolutional layer support (à la TensorFlow).

 ### Added
 - added proper license header
 - added Asymmetric UINT8 Convolution, MaxPool, AvgPool, FC layers
 - added Asymmetric UINT8 support functions
 - added 64-bit '__HI_SMLAL()' support function
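
For reference, a minimal sketch of the TensorFlow-style asymmetric scheme these layers follow: every UINT8 value q stands for the real value scale * (q - zero_point), and results are mapped back to UINT8 with an integer multiplier/shift pair (the new m_zero/n_zero parameters are assumed to carry that pair). The helpers below are illustrative only and are not part of the committed sources.

    #include <stdint.h>

    /* Illustrative asymmetric UINT8 quantization helpers -- assumptions, not commit code. */
    static inline float asym_dequantize(uint8_t q, float scale, uint8_t zero_point)
    {
        return scale * (float)((int32_t)q - (int32_t)zero_point);
    }

    static inline uint8_t asym_quantize(float r, float scale, uint8_t zero_point)
    {
        int32_t q = (int32_t)(r / scale + (r >= 0.0f ? 0.5f : -0.5f)) + (int32_t)zero_point;
        if (q < 0)   q = 0;        /* clamp to the UINT8 range */
        if (q > 255) q = 255;
        return (uint8_t)q;
    }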
alessandrocapotondi committed Mar 12, 2019
1 parent d63b089 commit 47a65b1
Showing 17 changed files with 2,153 additions and 251 deletions.
247 changes: 204 additions & 43 deletions CMSIS/NN/Include/arm_nnfunctions.h
@@ -802,6 +802,37 @@ extern "C"
q15_t * pOut,
q15_t * vec_buffer);

/**
* @brief uint8 asymmetric opt fully-connected layer function
* @param[in] pV pointer to input vector
* @param[in] pM pointer to matrix weights
* @param[in] dim_vec length of the vector
* @param[in] num_of_rows number of rows in weight matrix
* @param[in] z_wt weights offset
* @param[in] z_in input offset
* @param[in] z_out output offset
* @param[in] m_zero m zero quantization param
* @param[in] n_zero n zero quantization param
* @param[in] bias pointer to bias
* @param[in,out] pOut pointer to output vector
* @param[in,out] vec_buffer pointer to buffer space for input
* @return The function returns <code>ARM_MATH_SUCCESS</code>
*
*/
arm_status arm_fully_connected_uint8_asym(const uint8_t * pV,
const uint8_t * pM,
const uint16_t dim_vec,
const uint16_t num_of_rows,
const uint8_t z_wt,
const uint8_t z_in,
const uint8_t z_out,
const int32_t m_zero,
const uint16_t n_zero,
const int32_t * bias,
uint8_t * pOut,
int16_t * vec_buffer);
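
A hedged usage sketch for the declaration above; the sizes, offsets and the m_zero/n_zero requantization pair are placeholder values chosen only to show the calling convention, and the scratch-buffer size is an assumption (one int16 entry per input element).

    uint8_t  in_vec[128];                      /* input activations, UINT8, length dim_vec     */
    uint8_t  weights[10 * 128];                /* weight matrix, num_of_rows x dim_vec         */
    int32_t  bias[10];                         /* per-row bias                                 */
    uint8_t  out_vec[10];                      /* output activations                           */
    int16_t  scratch[128];                     /* vec_buffer: input widened to int16 (assumed) */

    arm_status status = arm_fully_connected_uint8_asym(in_vec, weights,
                                                       128,           /* dim_vec           */
                                                       10,            /* num_of_rows       */
                                                       128, 128, 128, /* z_wt, z_in, z_out */
                                                       1073741824,    /* m_zero (example)  */
                                                       8,             /* n_zero (example)  */
                                                       bias, out_vec, scratch);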


/**
* @brief Matrix-Multiplication Kernels for Convolution
*
@@ -955,6 +986,30 @@ extern "C"
q7_t * bufferA,
q7_t * Im_out);

/**
* @brief Asymmetric UINT8 max pooling function
* @param[in,out] Im_in pointer to input tensor
* @param[in] dim_im_in input tensor dimension
* @param[in] ch_im_in number of input tensor channels
* @param[in] dim_kernel filter kernel size
* @param[in] padding padding sizes
* @param[in] stride convolution stride
* @param[in] dim_im_out output tensor dimension
* @param[in,out] bufferA pointer to buffer space for input
* @param[in,out] Im_out pointer to output tensor
* @return none.
*/
void arm_maxpool_asym_uint8_HWC(uint8_t * Im_in,
const uint16_t dim_im_in,
const uint16_t ch_im_in,
const uint16_t dim_kernel,
const uint16_t padding,
const uint16_t stride,
const uint16_t dim_im_out,
int16_t * bufferA,
uint8_t * Im_out);
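
An illustrative call for the max-pooling declaration above: a 32x32, 16-channel HWC input pooled with a 2x2 window and stride 2 into a 16x16 output. The scratch-buffer size is an assumption, since this hunk does not document it; the average-pooling variant further down takes the same argument list.

    uint8_t  img_in[32 * 32 * 16];
    uint8_t  img_out[16 * 16 * 16];
    int16_t  pool_scratch[32 * 16];            /* bufferA: size assumed, not specified here */

    arm_maxpool_asym_uint8_HWC(img_in,
                               32,             /* dim_im_in  */
                               16,             /* ch_im_in   */
                               2,              /* dim_kernel */
                               0,              /* padding    */
                               2,              /* stride     */
                               16,             /* dim_im_out */
                               pool_scratch,
                               img_out);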


/**
* @brief Q7 average pooling function
* @param[in] Im_in pointer to input tensor
@@ -979,6 +1034,28 @@ extern "C"
const uint16_t dim_im_out,
q7_t * bufferA,
q7_t * Im_out);
/**
* @brief Asymmetric UINT8 average pooling function
* @param[in,out] Im_in pointer to input tensor
* @param[in] dim_im_in input tensor dimension
* @param[in] ch_im_in number of input tensor channels
* @param[in] dim_kernel filter kernel size
* @param[in] padding padding sizes
* @param[in] stride convolution stride
* @param[in] dim_im_out output tensor dimension
* @param[in,out] bufferA pointer to buffer space for input
* @param[in,out] Im_out pointer to output tensor
* @return none.
*/
void arm_avepool_asym_uint8_HWC(uint8_t * Im_in,
const uint16_t dim_im_in,
const uint16_t ch_im_in,
const uint16_t dim_kernel,
const uint16_t padding,
const uint16_t stride,
const uint16_t dim_im_out,
int16_t * bufferA,
uint8_t * Im_out);

/**
* @defgroup Softmax Softmax Functions
@@ -1012,26 +1089,25 @@ extern "C"


/*
* Quantized Convolutional Layers
*
* INT-Q quantized layers
*/
arm_status
arm_convolve_HWC_BIN_fast(const uint32_t * Im_in,
const uint16_t dim_im_in,
const uint16_t ch_im_in,
const uint32_t * wt,
const uint16_t ch_im_out,
const uint16_t dim_kernel,
const uint16_t padding,
const uint16_t stride,
uint8_t * Im_out,
const uint16_t dim_im_out,
uint32_t * bufferA,
const int16_t * pThreshold,
int8_t * bufferB);

arm_status
arm_convolve_HWC_INT2_fast( const int8_t * Im_in,
arm_status arm_convolve_HWC_int1(
const uint32_t * Im_in,
const uint16_t dim_im_in,
const uint16_t ch_im_in,
const uint32_t * wt,
const uint16_t ch_im_out,
const uint16_t dim_kernel,
const uint16_t padding,
const uint16_t stride,
uint8_t * Im_out,
const uint16_t dim_im_out,
uint32_t * bufferA,
const int16_t * pThreshold,
int8_t * bufferB);

arm_status arm_convolve_HWC_int2(
const int8_t * Im_in,
const uint16_t dim_im_in,
const uint16_t ch_im_in,
const int8_t * wt,
@@ -1045,8 +1121,8 @@ extern "C"
const int16_t * pThreshold,
int8_t * bufferB);

arm_status
arm_convolve_HWC_INT4_fast( const int8_t * Im_in,
arm_status arm_convolve_HWC_int4(
const int8_t * Im_in,
const uint16_t dim_im_in,
const uint16_t ch_im_in,
const int8_t * wt,
@@ -1060,28 +1136,113 @@ extern "C"
const int16_t * pThreshold,
int8_t * bufferB);

int8_t *arm_nn_mat_mult_kernel_int2_int16_reordered(const int8_t * pA,
const int16_t * pInBuffer,
const uint16_t ch_im_out,
const uint16_t numCol_A,
const int16_t * pThreshold,
int8_t * pOut);

int8_t *arm_nn_mat_mult_kernel_int4_int16_reordered(const int8_t * pA,
const int16_t * pInBuffer,
const uint16_t ch_im_out,
const uint16_t numCol_A,
const int16_t * pThreshold,
int8_t * pOut);

uint32_t *arm_nn_mat_mult_kernel_BIN_reordered( const uint32_t * pA,
const uint32_t * pInBuffer,
const uint16_t ch_im_out,
const uint32_t numCol_A,
const int16_t * pThreshold,
uint32_t * pOut);


int8_t *arm_nn_mat_mult_kernel_int2_int16_reordered(
const int8_t * pA,
const int16_t * pInBuffer,
const uint16_t ch_im_out,
const uint16_t numCol_A,
const int16_t * pThreshold,
int8_t * pOut);

int8_t *arm_nn_mat_mult_kernel_int4_int16_reordered(
const int8_t * pA,
const int16_t * pInBuffer,
const uint16_t ch_im_out,
const uint16_t numCol_A,
const int16_t * pThreshold,
int8_t * pOut);

uint32_t *arm_nn_mat_mult_kernel_int1_reordered(
const uint32_t * pA,
const uint32_t * pInBuffer,
const uint16_t ch_im_out,
const uint32_t numCol_A,
const int16_t * pThreshold,
uint32_t * pOut);

arm_status arm_convolve_1x1_HWC_uint8_asym_fast_nonsquare(
const uint8_t * Im_in,
const uint16_t dim_im_in_x,
const uint16_t dim_im_in_y,
const uint16_t ch_im_in,
const uint8_t * wt,
const uint8_t z_wt,
const uint8_t z_in,
const uint8_t z_out,
const int16_t m_zero,
const uint16_t n_zero,
const uint16_t ch_im_out,
const uint16_t dim_kernel_x,
const uint16_t dim_kernel_y,
const uint16_t padding_x,
const uint16_t padding_y,
const uint16_t stride_x,
const uint16_t stride_y,
const int32_t * bias,
uint8_t * Im_out,
const uint16_t dim_im_out_x,
const uint16_t dim_im_out_y,
int16_t * bufferA,
uint8_t * bufferB);

arm_status arm_convolve_HWC_asym_uint8(
const uint8_t * Im_in,
const uint16_t dim_im_in,
const uint16_t ch_im_in,
const uint8_t * wt,
const uint8_t z_wt,
const uint8_t z_in,
const uint8_t z_out,
const int32_t m_zero,
const uint16_t n_zero,
const uint16_t ch_im_out,
const uint16_t dim_kernel,
const uint8_t left_padding,
const uint8_t right_padding,
const uint8_t top_padding,
const uint8_t bottom_padding,
const uint16_t stride,
const int32_t * bias,
uint8_t * Im_out,
const uint16_t dim_im_out,
int16_t * bufferA,
uint8_t * bufferB);
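
A hedged example of calling the asymmetric UINT8 convolution declared above: a 16x16x8 HWC input, 3x3 kernels, 16 output channels, padding 1 on every side and stride 1, giving a 16x16x16 output. The offsets and the m_zero/n_zero pair are placeholders, and the buffer sizes follow the usual CMSIS-NN im2col convention as an assumption, since this hunk does not state them. The depthwise-separable variant declared next takes the same arguments.

    uint8_t  conv_in[16 * 16 * 8];
    uint8_t  conv_wt[16 * 3 * 3 * 8];
    int32_t  conv_bias[16];
    uint8_t  conv_out[16 * 16 * 16];
    int16_t  im2col_buf[2 * 8 * 3 * 3];        /* bufferA: im2col scratch (assumed size)    */
    uint8_t  aux_buf[1];                       /* bufferB: size not documented in this hunk */

    arm_status status = arm_convolve_HWC_asym_uint8(conv_in, 16, 8,
                                                    conv_wt,
                                                    128, 128, 128,   /* z_wt, z_in, z_out */
                                                    1073741824, 8,   /* m_zero, n_zero    */
                                                    16,              /* ch_im_out         */
                                                    3,               /* dim_kernel        */
                                                    1, 1, 1, 1,      /* left/right/top/bottom padding */
                                                    1,               /* stride            */
                                                    conv_bias, conv_out, 16,
                                                    im2col_buf, aux_buf);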

arm_status arm_depthwise_separable_conv_HWC_asym_uint8(
const uint8_t * Im_in,
const uint16_t dim_im_in,
const uint16_t ch_im_in,
const uint8_t * wt,
const uint8_t z_wt,
const uint8_t z_in,
const uint8_t z_out,
const int32_t m_zero,
const uint16_t n_zero,
const uint16_t ch_im_out,
const uint16_t dim_kernel,
const uint8_t left_padding,
const uint8_t right_padding,
const uint8_t top_padding,
const uint8_t bottom_padding,
const uint16_t stride,
const int32_t * bias,
uint8_t * Im_out,
const uint16_t dim_im_out,
int16_t * bufferA,
uint8_t * bufferB);

uint8_t *arm_nn_mat_mult_kernel_asym_uint8_int16_reordered(
const uint8_t * pA,
const int16_t * pInBuffer,
const uint8_t z_a,
const uint8_t z_b,
const uint8_t z_out,
const int32_t m_zero,
const uint16_t n_zero,
const uint16_t ch_im_out,
const uint16_t numCol_A,
const int32_t * bias,
uint8_t * pOut);

#ifdef __cplusplus
}
73 changes: 68 additions & 5 deletions CMSIS/NN/Include/arm_nnsupportfunctions.h
@@ -94,12 +94,47 @@ void arm_q7_to_q15_no_shift(const q7_t * pSrc, q15_t * pDst, uint32_t block
* @param[in] *pSrc points to the Q7 input vector
* @param[out] *pDst points to the Q15 output vector
* @param[in] blockSize length of the input vector
 * @return none.
*/

void arm_q7_to_q15_reordered_no_shift(const q7_t * pSrc, q15_t * pDst, uint32_t blockSize);

/**
* @brief Converts the elements of the INT4 vector to reordered INT16 vector without left-shift
* @param[in] *pSrc points to the INT4 input vector
* @param[out] *pDst points to the INT16 output vector
* @param[in] blockSize length of the input vector
* @return none.
*/
void arm_int4_to_int16_reordered_no_shift(const int8_t * pSrc, int16_t * pDst, uint32_t blockSize);
/**
* @brief Converts the elements of the INT2 vector to reordered INT16 vector without left-shift
* @param[in] *pSrc points to the INT2 input vector
* @param[out] *pDst points to the INT16 output vector
* @param[in] blockSize length of the input vector
* @return none.
*/
void arm_int2_to_int16_reordered_no_shift(const int8_t * pSrc, int16_t * pDst, uint32_t blockSize);

/**
* @brief Converts the elements of the Asymmetric UINT8 vector to INT16 vector without left-shift
* @param[in] *pSrc points to the Asymmetric UINT8 input vector
 * @param[in] offset Asymmetric UINT8 offset
* @param[out] *pDst points to the INT16 output vector
* @param[in] blockSize length of the input vector
* @return none.
*/
void arm_asym_uint8_to_int16_no_shift(const uint8_t * pSrc, const uint8_t offset, int16_t * pDst, uint32_t blockSize);
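
A plain-C reference of what this conversion is assumed to compute, i.e. widening each element to INT16 while removing the asymmetric zero-point offset; the name asym_uint8_to_int16_ref is ours and the subtraction is an assumption, not taken from the committed implementation.

    static void asym_uint8_to_int16_ref(const uint8_t *pSrc, const uint8_t offset,
                                        int16_t *pDst, uint32_t blockSize)
    {
        uint32_t i;
        for (i = 0; i < blockSize; i++)
        {
            pDst[i] = (int16_t)((int32_t)pSrc[i] - (int32_t)offset);
        }
    }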

/**
* @brief Converts the elements of the Asymmetric UINT8 vector to reordered INT16 vector without left-shift
* @param[in] *pSrc points to the Asymmetric UINT8 input vector
 * @param[in] offset Asymmetric UINT8 offset
* @param[out] *pDst points to the INT16 output vector
* @param[in] blockSize length of the input vector
* @return none.
*/
void arm_asym_uint8_to_int16_reordered_no_shift(const uint8_t * pSrc, const uint8_t offset, int16_t * pDst, uint32_t blockSize);

#if defined (ARM_MATH_DSP)

/**
@@ -144,7 +179,7 @@ __STATIC_FORCEINLINE void *read_and_pad_reordered(void *source, q31_t * out1, q3
/**
* @brief read and expand one INT4 word into two INT16 words with reordering
*/
__STATIC_INLINE void *read_and_pad_reordered_INT4(void *source, int32_t * out1, int32_t * out2, int32_t * out3, int32_t * out4)
__STATIC_INLINE void *read_and_pad_reordered_int4(void *source, int32_t * out1, int32_t * out2, int32_t * out3, int32_t * out4)
{

#ifndef ARM_MATH_BIG_ENDIAN
@@ -168,7 +203,7 @@ __STATIC_INLINE void *read_and_pad_reordered_INT4(void *source, int32_t * out1,
/**
* @brief read and expand one INT2 word into two INT16 words with reordering
*/
__STATIC_INLINE void *read_and_pad_reordered_INT2( void *source, int32_t * out1, int32_t * out2, int32_t * out3, int32_t * out4,
__STATIC_INLINE void *read_and_pad_reordered_int2( void *source, int32_t * out1, int32_t * out2, int32_t * out3, int32_t * out4,
int32_t * out5, int32_t * out6, int32_t * out7, int32_t * out8)
{
q31_t inA = *__SIMD32(source)++;
@@ -196,6 +231,34 @@ __STATIC_INLINE void *read_and_pad_reordered_INT2( void *source, int32_t * out1
}


/**
 * @brief read and expand four UINT8 values into four INT16 values (packed two per 32-bit word) with reordering
 */
__STATIC_INLINE void *read_and_pad_reordered_uint8(void *source, int32_t * out1, int32_t * out2)
{
    int32_t inA = *__SIMD32(source)++;         /* read four packed UINT8 values              */
#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = __UXTB16(__ROR(inA, 8));           /* zero-extend bytes 1 and 3 into two INT16   */
    *out1 = __UXTB16(inA);                     /* zero-extend bytes 0 and 2 into two INT16   */
#else
    *out1 = __UXTB16(__ROR(inA, 8));
    *out2 = __UXTB16(inA);
#endif

return source;
}

/**
 * @brief 64-bit signed multiply returning the most-significant 32 bits of the product
 */
__STATIC_INLINE int32_t __HI_SMULL(int32_t a, int32_t b)
{
int hi = 0;
int lo = 0;
asm volatile ("SMULL %[lo_out], %[hi_out], %[a_operand], %[b_operand]"
: [lo_out] "=&r" (lo), [hi_out] "=&r" (hi)
: [a_operand] "r" (a), [b_operand] "r" (b)
);
return hi;
}
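
In portable C the intrinsic above is equivalent to taking the upper 32 bits of the full 64-bit signed product; an illustrative equivalent (the helper name is ours):

    /* Portable-C equivalent of __HI_SMULL (illustration only, not part of the commit). */
    static inline int32_t hi_smull_ref(int32_t a, int32_t b)
    {
        return (int32_t)(((int64_t)a * (int64_t)b) >> 32);
    }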

#endif

/**