Skip to content

Commit

Permalink
Bfp16 perf impr: change epack extraction algo from onlyc to cyx and s…
Browse files Browse the repository at this point in the history
…upport tunable epack (#2475)

* add support of multi k into xdlops gemm
* refactor xdlops for multiple k
* Add code to test tunable epack and change epack extraction algo from c to c*y*x
* load lds to register
* For bfp16/fp16 fwd case, extract epack from c*y*x. Make epack tunable for bfp16/fp16 fwd case.
Separate out fwd and wrw kernels into different files
* Get tunable epack
* Use tuned epack in place of static epack
* Ensure the required LDS is computed correctly with PACKSize being tunable
* Address code review comments
Co-authored-by: Jing Zhang <[email protected]>
  • Loading branch information
TejashShah authored Apr 7, 2020
1 parent 481d6b9 commit cea6064
Show file tree
Hide file tree
Showing 9 changed files with 550 additions and 144 deletions.
6 changes: 4 additions & 2 deletions src/include/miopen/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ struct PerformanceImplicitGemmXdlops : Serializable<PerformanceImplicitGemmXdlop
int KPerBlock; // 2^n[32..128]
int EPerBlock; // 2^n[4..16]
int EBlocks; // 2*n[1..64]
int EPACKSize; // 2*n[1..4] // 1 - fp32; 2,4 - bfp16; 4 - fp16

int GemmMPerWave;
int GemmNPerWave;
Expand All @@ -750,10 +751,10 @@ struct PerformanceImplicitGemmXdlops : Serializable<PerformanceImplicitGemmXdlop

bool use_spare_set;

PerformanceImplicitGemmXdlops(int, int, int, int, int, int, int, int, int, int, bool);
PerformanceImplicitGemmXdlops(int, int, int, int, int, int, int, int, int, int, int, bool);

PerformanceImplicitGemmXdlops()
: PerformanceImplicitGemmXdlops(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, false)
: PerformanceImplicitGemmXdlops(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, false)
{
}

Expand All @@ -766,6 +767,7 @@ struct PerformanceImplicitGemmXdlops : Serializable<PerformanceImplicitGemmXdlop
f(self.KPerBlock, "KPerBlock");
f(self.EPerBlock, "EPerBlock");
f(self.EBlocks, "EBlocks");
f(self.EPACKSize, "EPACKSize");
f(self.GemmMPerWave, "GemmMPerWave");
f(self.GemmNPerWave, "GemmNPerWave");
f(self.InBlockCopyClusterLengths_E, "InBlockCopyClusterLengths_E");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_FWD_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_FWD_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "gridwise_gemm_xdlops_fp16_bfp16.hpp"

namespace ck {

// Primary template: builds the vectorized (KPACK-split) weight tensor
// descriptor for the given convolution direction. Only explicit
// specializations (e.g. ForwardData below) provide a definition.
template <ImplicitGemmDirection conv_dir, index_t GemmKPACK>
struct make_vectorized_WeiDesc_Xdlops;

// Forward-data specialization: reshapes the native weight tensor
// wei[K, C, Y, X] into the GEMM view wei[GemmK, GemmM, GemmKPACK],
// where GemmM = K and GemmK = (C * Y * X) / GemmKPACK. Note the KPACK
// vector is extracted from the merged C*Y*X dimension, not from C alone.
template <index_t GemmKPACK>
struct make_vectorized_WeiDesc_Xdlops<ImplicitGemmDirection::ForwardData, GemmKPACK>
{
    template <typename WeiDesc>
    __device__ constexpr auto get(WeiDesc&)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};

        constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
        constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        /* kpack comes from c*y*x, so the merged length must divide evenly */
        // NOTE: message fixed — the condition checks C*Y*X, not C alone.
        static_assert((C * Y * X) % GemmKPACK == 0,
                      "C*Y*X needs to be multiple of vectorized GemmKPACK");
        constexpr index_t GemmK = (C * Y * X) / GemmKPACK;

        // [K, C, Y, X] -> [K, C*Y*X]: fold C, Y, X into one dimension.
        constexpr auto wei_gemmm_gemmk_global_desc =
            transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3),
                                        make_tuple(PassThrough<K>{}, PassThrough<C * Y * X>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        // [K, C*Y*X] -> [K, GemmK, GemmKPACK]: split out the KPACK vector.
        constexpr auto wei_gemmm_gemmk_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmm_gemmk_global_desc,
            make_tuple(PassThrough<K>{}, UnMerge<Sequence<GemmK, GemmKPACK>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        // [K, GemmK, GemmKPACK] -> [GemmK, K, GemmKPACK]: reorder dims only.
        constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmm_gemmk_gemmkpack_global_desc,
            make_tuple(PassThrough<GemmK>{}, PassThrough<K>{}, PassThrough<GemmKPACK>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return wei_gemmk_gemmm_gemmkpack_global_desc;
    }
};

// B = merge(N, Ho, Wo)
//
// Gridwise fp16/bfp16 forward convolution (NCHW/KCYX/NKHW) as an implicit
// GEMM on XDLOPS with LDS double buffering:
//   A = weight  [GemmG, GemmK, GemmM=K,        GemmKPACK]
//   B = input   [GemmG, GemmK, GemmN=N*Ho*Wo,  GemmKPACK]
//   C = output  [GemmG, GemmM, GemmN]
// where GemmK = (C*Y*X) / GemmKPACK and GemmG = GemmKBlocks (K-split; the
// partial results are combined with AtomicAdd when GemmKBlocks > 1).
template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          class ConvDilations,
          class LeftPads,
          class RightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmKBlocks,
          index_t GemmKPACK,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class GemmABlockCopyThreadSliceLengths_GemmG_GemmK_GemmM_GemmKPACK,
          class GemmABlockCopyThreadClusterLengths_GemmG_GemmK_GemmM_GemmKPACK,
          class GemmABlockCopyThreadClusterArrangeOrder,
          class GemmABlockCopySrcAccessOrder,
          class GemmABlockCopyDstAccessOrder,
          index_t GemmABlockCopySrcDataPerRead_GemmKPACK,
          index_t GemmABlockCopyDstDataPerWrite_GemmKPACK,
          class GemmBBlockCopyThreadSliceLengths_GemmG_GemmK_GemmN_GemmKPACK,
          class GemmBBlockCopyThreadClusterLengths_GemmG_GemmK_GemmN_GemmKPACK,
          class GemmBBlockCopyThreadClusterArrangeOrder,
          class GemmBBlockCopySrcAccessOrder,
          class GemmBBlockCopyDstAccessOrder,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK,
          ImplicitGemmDirection conv_dir>
struct
    GridwiseConvolutionImplicitGemm_v4r4_gen_xdlops_fp16_bfp16_fwd_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const ABFloat* const __restrict__ p_in_global,
                        const ABFloat* const __restrict__ p_wei_global,
                        CFloat* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        // KPACK is extracted from the merged C*Y*X dimension (see GemmK
        // below and the weight descriptor helper), so divisibility must be
        // checked against C*Y*X, not C alone; requiring C % GemmKPACK == 0
        // would wrongly reject valid configs such as C=2, Y=2, X=1, KPACK=4.
        static_assert((C * Y * X) % GemmKPACK == 0,
                      "C*Y*X needs to be multiple of GemmKPACK");

        constexpr index_t GemmM = K;
        constexpr index_t GemmK = (C * Y * X) / GemmKPACK;
        constexpr index_t GemmN = N * Ho * Wo;

        // divide block work by [K, B]
        static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
                          GemmK % (GemmKBlocks * GemmKPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        // per-group (K-split) reduction length along GemmK
        constexpr index_t GemmKSub = GemmK / GemmKBlocks;

        // sanity-check for vectorized memory load
        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
                          (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor "
                      "will be violated");

        // input: pad H and W -> [N, C, Hip, Wip]
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        // embed the filter window: [N, C, Hip, Wip] -> [N, C, Y, Ho, X, Wo]
        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        // merge to GEMM view: gemmk = C*Y*X, gemmn = N*Ho*Wo
        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // split out the KPACK vector from the merged C*Y*X dimension
        constexpr auto in_gemmk_gemmkpack_gemmn_global_desc = transform_tensor_descriptor(
            in_gemmk_gemmn_global_desc,
            make_tuple(UnMerge<Sequence<GemmK, GemmKPACK>>{}, PassThrough<GemmN>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}));

        // reorder to [GemmK, GemmN, GemmKPACK]
        constexpr auto in_gemmk_gemmn_gemmkpack_global_desc = transform_tensor_descriptor(
            in_gemmk_gemmkpack_gemmn_global_desc,
            make_tuple(PassThrough<GemmK>{}, PassThrough<GemmN>{}, PassThrough<GemmKPACK>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // add the K-split group dimension: [GemmG, GemmKSub, GemmN, GemmKPACK]
        constexpr auto in_gemmg_gemmk_gemmn_gemmkpack_global_desc =
            transform_tensor_descriptor(in_gemmk_gemmn_gemmkpack_global_desc,
                                        make_tuple(UnMerge<Sequence<GemmKBlocks, GemmKSub>>{},
                                                   PassThrough<GemmN>{},
                                                   PassThrough<GemmKPACK>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));

        // weight tensor
        // global mem
        constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc =
            make_vectorized_WeiDesc_Xdlops<conv_dir, GemmKPACK>{}.get(wei_k_c_y_x_global_desc);

        constexpr auto wei_gemmg_gemmk_gemmm_gemmkpack_global_desc =
            transform_tensor_descriptor(wei_gemmk_gemmm_gemmkpack_global_desc,
                                        make_tuple(UnMerge<Sequence<GemmKBlocks, GemmKSub>>{},
                                                   PassThrough<GemmM>{},
                                                   PassThrough<GemmKPACK>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                        make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));

        // output: prepend a zero-stride group dim so every K-split group
        // addresses the same [N, K, Ho, Wo] data (accumulated atomically)
        constexpr auto out_g_n_k_ho_wo_global_desc =
            make_native_tensor_descriptor(Sequence<GemmKBlocks,
                                                   out_n_k_ho_wo_global_desc.GetLengths()[0],
                                                   out_n_k_ho_wo_global_desc.GetLengths()[1],
                                                   out_n_k_ho_wo_global_desc.GetLengths()[2],
                                                   out_n_k_ho_wo_global_desc.GetLengths()[3]>{},
                                          Sequence<0,
                                                   out_n_k_ho_wo_global_desc.GetStrides()[0],
                                                   out_n_k_ho_wo_global_desc.GetStrides()[1],
                                                   out_n_k_ho_wo_global_desc.GetStrides()[2],
                                                   out_n_k_ho_wo_global_desc.GetStrides()[3]>{});

        constexpr auto out_gemmg_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            out_g_n_k_ho_wo_global_desc,
            make_tuple(
                PassThrough<GemmKBlocks>{}, PassThrough<GemmM>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3, 4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // K-split groups write partial sums into the same output locations
        constexpr InMemoryDataOperation CGlobalMemoryDataOperation =
            GemmKBlocks > 1 ? InMemoryDataOperation::AtomicAdd : InMemoryDataOperation::Set;

        constexpr auto gridwise_gemm =
            GridwiseBatchedGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1<
                GridSize,
                BlockSize,
                ABFloat,
                AccFloat,
                CFloat,
                decltype(wei_gemmg_gemmk_gemmm_gemmkpack_global_desc),
                decltype(in_gemmg_gemmk_gemmn_gemmkpack_global_desc),
                decltype(out_gemmg_gemmm_gemmn_global_desc),
                GemmMPerBlock,
                GemmNPerBlock,
                GemmKPerBlock,
                GemmMPerWave,
                GemmNPerWave,
                GemmDataPerReadM,
                GemmDataPerReadN,
                GemmABlockCopyThreadSliceLengths_GemmG_GemmK_GemmM_GemmKPACK,
                GemmABlockCopyThreadClusterLengths_GemmG_GemmK_GemmM_GemmKPACK,
                GemmABlockCopyThreadClusterArrangeOrder,
                GemmABlockCopySrcAccessOrder,
                GemmABlockCopyDstAccessOrder,
                3, // KPACK dimension
                GemmABlockCopySrcDataPerRead_GemmKPACK,
                GemmABlockCopyDstDataPerWrite_GemmKPACK,
                GemmBBlockCopyThreadSliceLengths_GemmG_GemmK_GemmN_GemmKPACK,
                GemmBBlockCopyThreadClusterLengths_GemmG_GemmK_GemmN_GemmKPACK,
                GemmBBlockCopyThreadClusterArrangeOrder,
                GemmBBlockCopySrcAccessOrder,
                GemmBBlockCopyDstAccessOrder,
                2, // N dimension
                GemmBBlockCopySrcDataPerRead_GemmN,
                GemmBBlockCopyDstDataPerWrite_GemmKPACK,
                CGlobalMemoryDataOperation>{};

        // A = weight, B = input, C = output
        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
Loading

0 comments on commit cea6064

Please sign in to comment.