@@ -9,6 +9,7 @@ genrule( | |||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D) | CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemm --type tensorop1688 $(@D) | ||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D) | CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations gemv --type simt $(@D) | ||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D) | CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type simt $(@D) | ||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations deconv --type tensorop8816 $(@D) | |||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D) | CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type simt $(@D) | ||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D) | CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8816 $(@D) | ||||
CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D) | CUTLASS_WITH_LONG_PATH=true python3 $$GEN --operations conv2d --type tensorop8832 $(@D) | ||||
@@ -337,7 +337,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay | |||||
else: | else: | ||||
swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx | swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx | ||||
else: | else: | ||||
swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx | |||||
if implicit_gemm_mode == ImplicitGemmMode.GemmTN: | |||||
swizzling_functor = SwizzlingFunctor.ConvDgradTrans | |||||
else: | |||||
swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx | |||||
# skip rule | # skip rule | ||||
def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool: | def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool: | ||||
@@ -36,6 +36,7 @@ if __name__ == "__main__": | |||||
write_op_list(f, "gemm", "tensorop884") | write_op_list(f, "gemm", "tensorop884") | ||||
write_op_list(f, "gemv", "simt") | write_op_list(f, "gemv", "simt") | ||||
write_op_list(f, "deconv", "simt") | write_op_list(f, "deconv", "simt") | ||||
write_op_list(f, "deconv", "tensorop8816") | |||||
write_op_list(f, "conv2d", "simt") | write_op_list(f, "conv2d", "simt") | ||||
write_op_list(f, "conv2d", "tensorop8816") | write_op_list(f, "conv2d", "tensorop8816") | ||||
write_op_list(f, "conv2d", "tensorop8832") | write_op_list(f, "conv2d", "tensorop8832") | ||||
@@ -445,6 +445,53 @@ def GenerateDeconv_Simt(args): | |||||
use_special_optimization) | use_special_optimization) | ||||
return operations | return operations | ||||
def GenerateDeconv_TensorOp_8816(args): | |||||
operations = [] | |||||
layouts = [ | |||||
(LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32), | |||||
(LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64), | |||||
(LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128), | |||||
] | |||||
math_instructions = [ | |||||
MathInstruction( \ | |||||
[8, 8, 16], \ | |||||
DataType.s8, DataType.s8, DataType.s32, \ | |||||
OpcodeClass.TensorOp, \ | |||||
MathOperation.multiply_add_saturate), | |||||
] | |||||
dst_layouts = [ | |||||
LayoutType.TensorNHWC, | |||||
] | |||||
dst_types = [ | |||||
DataType.s8, | |||||
] | |||||
use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling | |||||
min_cc = 75 | |||||
max_cc = 1024 | |||||
cuda_major = 10 | |||||
cuda_minor = 2 | |||||
for math_inst in math_instructions: | |||||
for layout in layouts: | |||||
for dst_type, dst_layout in zip(dst_types, dst_layouts): | |||||
tile_descriptions = [ | |||||
TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), | |||||
TileDescription([64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), | |||||
] | |||||
for tile in tile_descriptions: | |||||
dst_align = 32 if tile.threadblock_shape[1] == 16 else 64 | |||||
operations += GenerateConv2d(ConvKind.Dgrad, [tile], layout[0], layout[1], dst_layout, dst_type, | |||||
min_cc, layout[2], layout[2], dst_align, use_special_optimization, | |||||
ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor) | |||||
return operations | |||||
################################################################################ | ################################################################################ | ||||
# parameters | # parameters | ||||
# Edge - for tiles, the edges represent the length of one side | # Edge - for tiles, the edges represent the length of one side | ||||
@@ -820,9 +867,12 @@ def GenerateConv2dOperations(args): | |||||
return GenerateConv2d_TensorOp_8832(args) | return GenerateConv2d_TensorOp_8832(args) | ||||
def GenerateDeconvOperations(args): | def GenerateDeconvOperations(args): | ||||
assert args.type == "simt", "operation deconv only support" \ | |||||
"simt. (got:{})".format(args.type) | |||||
return GenerateDeconv_Simt(args) | |||||
if args.type == "simt": | |||||
return GenerateDeconv_Simt(args) | |||||
else: | |||||
assert args.type == "tensorop8816", "operation deconv only support" \ | |||||
"simt and tensorop8816. (got:{})".format(args.type) | |||||
return GenerateDeconv_TensorOp_8816(args) | |||||
def GenerateGemmOperations(args): | def GenerateGemmOperations(args): | ||||
if args.type == "tensorop884": | if args.type == "tensorop884": | ||||
@@ -280,6 +280,9 @@ class LayoutType(enum.Enum): | |||||
TensorC32RSK32 = enum_auto() | TensorC32RSK32 = enum_auto() | ||||
TensorC64RSK64 = enum_auto() | TensorC64RSK64 = enum_auto() | ||||
TensorK4RSC4 = enum_auto() | TensorK4RSC4 = enum_auto() | ||||
TensorCK4RS4 = enum_auto() | |||||
TensorCK8RS8 = enum_auto() | |||||
TensorCK16RS16 = enum_auto() | |||||
# | # | ||||
LayoutTag = { | LayoutTag = { | ||||
@@ -303,7 +306,10 @@ LayoutTag = { | |||||
LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', | LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', | ||||
LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', | LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', | ||||
LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', | LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', | ||||
LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>', | |||||
LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>', | |||||
LayoutType.TensorCK4RS4: 'cutlass::layout::TensorCKxRSx<4>', | |||||
LayoutType.TensorCK8RS8: 'cutlass::layout::TensorCKxRSx<8>', | |||||
LayoutType.TensorCK16RS16: 'cutlass::layout::TensorCKxRSx<16>', | |||||
} | } | ||||
# | # | ||||
@@ -342,6 +348,9 @@ ShortLayoutTypeNames = { | |||||
LayoutType.TensorC32RSK32: 'c32rsk32', | LayoutType.TensorC32RSK32: 'c32rsk32', | ||||
LayoutType.TensorC64RSK64: 'c64rsk64', | LayoutType.TensorC64RSK64: 'c64rsk64', | ||||
LayoutType.TensorK4RSC4: 'k4rsc4', | LayoutType.TensorK4RSC4: 'k4rsc4', | ||||
LayoutType.TensorCK4RS4: 'ck4rs4', | |||||
LayoutType.TensorCK8RS8: 'ck8rs8', | |||||
LayoutType.TensorCK16RS16: 'ck16rs16', | |||||
} | } | ||||
# | # | ||||
@@ -484,6 +493,7 @@ class SwizzlingFunctor(enum.Enum): | |||||
ConvFpropNCxHWx = enum_auto() | ConvFpropNCxHWx = enum_auto() | ||||
ConvFpropTrans = enum_auto() | ConvFpropTrans = enum_auto() | ||||
ConvDgradNCxHWx = enum_auto() | ConvDgradNCxHWx = enum_auto() | ||||
ConvDgradTrans = enum_auto() | |||||
# | # | ||||
SwizzlingFunctorTag = { | SwizzlingFunctorTag = { | ||||
@@ -494,6 +504,7 @@ SwizzlingFunctorTag = { | |||||
SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', | SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', | ||||
SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', | SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', | ||||
SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', | SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', | ||||
SwizzlingFunctor.ConvDgradTrans: 'cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle', | |||||
} | } | ||||
################################################################################################### | ################################################################################################### | ||||
@@ -464,6 +464,19 @@ cutlass_gen_list = [ | |||||
"cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | "cutlass_simt_s8_idgrad_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | ||||
"cutlass_simt_s8_idgrad_s2_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | "cutlass_simt_s8_idgrad_s2_id_s8_16x64x8_16x64x8_2_nc4hw4_k4rsc4.cu", | ||||
"all_deconv_simt_operations.cu", | "all_deconv_simt_operations.cu", | ||||
"cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck4rs4.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck4rs4.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck4rs4.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck4rs4.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck8rs8.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck8rs8.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck8rs8.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck8rs8.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_id_s8_128x32x32_64x32x32_1_nhwc_ck16rs16.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_s2_id_s8_128x32x32_64x32x32_1_nhwc_ck16rs16.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_id_s8_64x16x32_64x16x32_2_nhwc_ck16rs16.cu", | |||||
"cutlass_tensorop_s8_i8816dgrad_s2_id_s8_64x16x32_64x16x32_2_nhwc_ck16rs16.cu", | |||||
"all_deconv_tensorop8816_operations.cu", | |||||
"cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_1x1_id_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
"cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | "cutlass_simt_s8_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_c4rsk4.cu", | ||||
@@ -155,6 +155,7 @@ if(MGE_WITH_CUDA) | |||||
gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES) | gen_cutlass_kimpl(gemm tensorop1688 CUTLASS_SOURCES) | ||||
gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES) | gen_cutlass_kimpl(gemv simt CUTLASS_SOURCES) | ||||
gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES) | gen_cutlass_kimpl(deconv simt CUTLASS_SOURCES) | ||||
gen_cutlass_kimpl(deconv tensorop8816 CUTLASS_SOURCES) | |||||
gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES) | gen_cutlass_kimpl(conv2d simt CUTLASS_SOURCES) | ||||
gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES) | gen_cutlass_kimpl(conv2d tensorop8816 CUTLASS_SOURCES) | ||||
gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES) | gen_cutlass_kimpl(conv2d tensorop8832 CUTLASS_SOURCES) | ||||
@@ -36,6 +36,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | |||||
int8_algos.push_back(&algo); | int8_algos.push_back(&algo); | ||||
} | } | ||||
fill_int8_imma_algos(); | |||||
for (auto&& algo : int8_nhwc_imma) { | |||||
all_algos.push_back(&algo); | |||||
int8_algos.push_back(&algo); | |||||
} | |||||
int8_algos.push_back(&int8_nchw_dotprod); | int8_algos.push_back(&int8_nchw_dotprod); | ||||
all_algos.push_back(&int8_nchw_dotprod); | all_algos.push_back(&int8_nchw_dotprod); | ||||
@@ -40,7 +40,8 @@ public: | |||||
CUDA_BFLOAT16, | CUDA_BFLOAT16, | ||||
CUDA_GROUP_CONV_GENERAL, | CUDA_GROUP_CONV_GENERAL, | ||||
CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | ||||
CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8 | |||||
CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8, | |||||
CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8 | |||||
}; | }; | ||||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | ||||
@@ -299,11 +300,53 @@ private: | |||||
const void* get_available_op(const SizeArgs& args) const; | const void* get_available_op(const SizeArgs& args) const; | ||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm final | |||||
: public AlgoBase { | |||||
public: | |||||
struct AlgoParam { | |||||
int threadblock_m; | |||||
int threadblock_n; | |||||
int threadblock_k; | |||||
int warp_m; | |||||
int warp_n; | |||||
int warp_k; | |||||
int stage; | |||||
int access_size; | |||||
std::string to_string() { | |||||
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m, | |||||
threadblock_n, threadblock_k, warp_m, warp_n, | |||||
warp_k, stage, access_size); | |||||
} | |||||
}; | |||||
AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param) | |||||
: m_algo_param{algo_param}, | |||||
m_name{ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s", | |||||
m_algo_param.to_string().c_str())} {} | |||||
bool is_available(const SizeArgs& args) const override; | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
void exec(const ExecArgs& args) const override; | |||||
const char* name() const override { return m_name.c_str(); } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8) | |||||
private: | |||||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||||
const SizeArgs& args) const; | |||||
const void* get_available_op(const SizeArgs& args) const; | |||||
void reorder_filter(const ExecArgs& args, const int iterleaved, | |||||
int8_t* reordered_filter) const; | |||||
AlgoParam m_algo_param; | |||||
std::string m_name; | |||||
}; | |||||
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | ||||
// defined in cudnn.cpp | // defined in cudnn.cpp | ||||
void fill_cudnn_algos(); | void fill_cudnn_algos(); | ||||
// defined in implicit_gemm_int8_nchw4_dp4a.cpp | // defined in implicit_gemm_int8_nchw4_dp4a.cpp | ||||
void fill_int8_dp4a_algos(); | void fill_int8_dp4a_algos(); | ||||
// defined in implicit_gemm_int8_nhwc_imma.cpp | |||||
void fill_int8_imma_algos(); | |||||
AlgoBase::Mapper m_all_algos_map; | AlgoBase::Mapper m_all_algos_map; | ||||
@@ -318,6 +361,7 @@ public: | |||||
AlgoGroupConvGeneral group; | AlgoGroupConvGeneral group; | ||||
std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod; | std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod; | ||||
AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod; | AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod; | ||||
std::vector<AlgoInt8NHWCIMMAImplicitGemm> int8_nhwc_imma; | |||||
std::vector<AlgoBase*> | std::vector<AlgoBase*> | ||||
//! all algorithms | //! all algorithms | ||||
@@ -11,6 +11,7 @@ | |||||
*/ | */ | ||||
#include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | #include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | ||||
#include "src/cuda/transpose_utils.cuh" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -21,7 +22,6 @@ using namespace deconv; | |||||
namespace { | namespace { | ||||
// | |||||
__global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( | __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( | ||||
int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC, | int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC, | ||||
uint32_t IC, uint32_t FHFW) { | uint32_t IC, uint32_t FHFW) { | ||||
@@ -30,32 +30,55 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( | |||||
const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x; | const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x; | ||||
if (fhfw < FHFW && icb < IC / 4) { | if (fhfw < FHFW && icb < IC / 4) { | ||||
int src0 = *reinterpret_cast<const int*>( | |||||
src + (ocb * 4 + 0) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||||
int src1 = *reinterpret_cast<const int*>( | |||||
src + (ocb * 4 + 1) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||||
int src2 = *reinterpret_cast<const int*>( | |||||
src + (ocb * 4 + 2) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||||
int src3 = *reinterpret_cast<const int*>( | |||||
src + (ocb * 4 + 3) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||||
int src_value[4], dst_value[4]; | |||||
#pragma unroll | |||||
for (int i = 0; i < 4; i++) { | |||||
src_value[i] = *reinterpret_cast<const int*>( | |||||
src + (ocb * 4 + i) * IC * FHFW + (icb * FHFW + fhfw) * 4); | |||||
} | |||||
// transpose 4x4 | // transpose 4x4 | ||||
int dst01_lo = __byte_perm(src0, src1, 0x5140); | |||||
int dst01_hi = __byte_perm(src0, src1, 0x7362); | |||||
int dst23_lo = __byte_perm(src2, src3, 0x5140); | |||||
int dst23_hi = __byte_perm(src2, src3, 0x7362); | |||||
int dst0 = __byte_perm(dst01_lo, dst23_lo, 0x5410); | |||||
int dst1 = __byte_perm(dst01_lo, dst23_lo, 0x7632); | |||||
int dst2 = __byte_perm(dst01_hi, dst23_hi, 0x5410); | |||||
int dst3 = __byte_perm(dst01_hi, dst23_hi, 0x7632); | |||||
*reinterpret_cast<int*>( | |||||
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 0) * 4) = dst0; | |||||
*reinterpret_cast<int*>( | |||||
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 1) * 4) = dst1; | |||||
*reinterpret_cast<int*>( | |||||
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 2) * 4) = dst2; | |||||
*reinterpret_cast<int*>( | |||||
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + 3) * 4) = dst3; | |||||
transpose_int8_interleavedx4<4, int>(src_value, dst_value); | |||||
#pragma unroll | |||||
for (int i = 0; i < 4; i++) { | |||||
*reinterpret_cast<int*>( | |||||
dst + (ocb * FHFW * IC + fhfw * IC + icb * 4 + i) * 4) = | |||||
dst_value[i]; | |||||
} | |||||
} | |||||
} | |||||
template <uint32_t interleaved, typename vec_type> | |||||
__global__ void reorder_filter_nhwc_to_cnxhwx_kernel( | |||||
int8_t* __restrict__ dst, const int8_t* __restrict__ src, uint32_t OC, | |||||
uint32_t IC, uint32_t FHFW) { | |||||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||||
const int32_t ocb = lane / (FHFW * IC / 4); | |||||
const int32_t fhfw_icb = lane % (FHFW * IC / 4); | |||||
const int32_t fhfw = fhfw_icb / (IC / 4); | |||||
const int32_t icb = fhfw_icb % (IC / 4); | |||||
if (ocb < OC / interleaved && fhfw < FHFW) { | |||||
int src_value[interleaved]; | |||||
vec_type dst_value[4]; | |||||
#pragma unroll | |||||
for (int i = 0; i < interleaved; i++) { | |||||
src_value[i] = *reinterpret_cast<const int*>( | |||||
src + (ocb * interleaved + i) * FHFW * IC + fhfw * IC + | |||||
icb * 4); | |||||
} | |||||
transpose_int8_interleavedx4<interleaved, vec_type>(src_value, | |||||
dst_value); | |||||
#pragma unroll | |||||
for (int i = 0; i < 4; i++) { | |||||
*reinterpret_cast<vec_type*>(dst + (icb * 4 + i) * FHFW * OC + | |||||
(ocb * FHFW + fhfw) * interleaved) = | |||||
dst_value[i]; | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -73,4 +96,27 @@ void megdnn::cuda::deconv::reorder_filter_nc4hw4_to_n4hwc4( | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
void megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx( | |||||
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||||
uint32_t FW, uint32_t interleaved, cudaStream_t stream) { | |||||
int32_t vthreads = OC / interleaved * IC / 4 * FH * FW; | |||||
int32_t nr_threads = std::min(256, vthreads); | |||||
int32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||||
if (interleaved == 4) { | |||||
reorder_filter_nhwc_to_cnxhwx_kernel<4, int> | |||||
<<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC, | |||||
FH * FW); | |||||
} else if (interleaved == 8) { | |||||
reorder_filter_nhwc_to_cnxhwx_kernel<8, int2> | |||||
<<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC, | |||||
FH * FW); | |||||
} else { | |||||
reorder_filter_nhwc_to_cnxhwx_kernel<16, int4> | |||||
<<<nr_blocks, nr_threads, 0, stream>>>(dst, src, OC, IC, | |||||
FH * FW); | |||||
} | |||||
after_kernel_launch(); | |||||
} | |||||
// vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen |
@@ -20,6 +20,10 @@ void reorder_filter_nc4hw4_to_n4hwc4(int8_t* dst, const int8_t* src, | |||||
uint32_t OC, uint32_t IC, uint32_t FH, | uint32_t OC, uint32_t IC, uint32_t FH, | ||||
uint32_t FW, cudaStream_t stream); | uint32_t FW, cudaStream_t stream); | ||||
void reorder_filter_nhwc_to_cnxhwx(int8_t* dst, const int8_t* src, uint32_t OC, | |||||
uint32_t IC, uint32_t FH, uint32_t FW, | |||||
uint32_t interleaved, cudaStream_t stream); | |||||
} // namespace deconv | } // namespace deconv | ||||
} // namespace cuda | } // namespace cuda | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -0,0 +1,214 @@ | |||||
/** | |||||
* \file | |||||
* dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/convolution/backward_data/algo.h" | |||||
#include "src/cuda/convolution/backward_data/deconv_int8_helper.cuh" | |||||
#include "src/cuda/convolution_helper/parameter.cuh" | |||||
#include "src/cuda/cutlass/singleton.h" | |||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
const void* | |||||
ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_available_op( | |||||
const SizeArgs& args) const { | |||||
using namespace cutlass::library; | |||||
auto&& fm = args.filter_meta; | |||||
size_t sh = fm.stride[0], sw = fm.stride[1]; | |||||
cutlass::conv::SpecialOptimizeDesc special_optimization = | |||||
(sh == 2 && sw == 2) ? cutlass::conv::SpecialOptimizeDesc:: | |||||
DECONV_DOUBLE_UPSAMPLING | |||||
: cutlass::conv::SpecialOptimizeDesc::NONE; | |||||
LayoutTypeID filter_layout; | |||||
if (m_algo_param.access_size == 16) { | |||||
filter_layout = LayoutTypeID::kTensorCK16RS16; | |||||
} else if (m_algo_param.access_size == 8) { | |||||
filter_layout = LayoutTypeID::kTensorCK8RS8; | |||||
} else { | |||||
megdnn_assert(m_algo_param.access_size == 4, "invalid access_size: %d", | |||||
m_algo_param.access_size); | |||||
filter_layout = LayoutTypeID::kTensorCK4RS4; | |||||
} | |||||
ConvolutionKey key{ | |||||
cutlass::conv::Operator::kDgrad, | |||||
NumericTypeID::kS8, | |||||
LayoutTypeID::kTensorNHWC, | |||||
NumericTypeID::kS8, | |||||
filter_layout, | |||||
NumericTypeID::kS8, | |||||
LayoutTypeID::kTensorNHWC, | |||||
NumericTypeID::kS32, | |||||
LayoutTypeID::kTensorNHWC, | |||||
cutlass::conv::ConvType::kConvolution, | |||||
m_algo_param.threadblock_m, | |||||
m_algo_param.threadblock_n, | |||||
m_algo_param.threadblock_k, | |||||
m_algo_param.warp_m, | |||||
m_algo_param.warp_n, | |||||
m_algo_param.warp_k, | |||||
8, | |||||
8, | |||||
16, | |||||
cutlass::epilogue::EpilogueType::kBiasAddLinearCombinationClamp, | |||||
m_algo_param.stage, | |||||
special_optimization, | |||||
false}; | |||||
return (void*)Singleton::get().operation_table.find_op(key); | |||||
} | |||||
bool ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available( | |||||
const SizeArgs& args) const { | |||||
auto&& fm = args.filter_meta; | |||||
if (fm.format != Param::Format::NHWC) | |||||
return false; | |||||
if (!args.grad_layout->is_contiguous() || | |||||
!args.diff_layout->is_contiguous()) { | |||||
return false; | |||||
} | |||||
bool available = true; | |||||
auto src_dtype = args.diff_layout->dtype, | |||||
filter_dtype = args.filter_layout->dtype, | |||||
dst_dtype = args.grad_layout->dtype; | |||||
size_t co = args.diff_layout->operator[](3); | |||||
size_t ci = args.grad_layout->operator[](3); | |||||
available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
filter_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
dst_dtype.enumv() == DTypeEnum::QuantizedS8); | |||||
// TODO support group deconv int8 | |||||
available &= (fm.group == 1); | |||||
// mode must be cross correlation | |||||
available &= !fm.should_flip; | |||||
// mode must be 2D | |||||
available &= fm.spatial_ndim == 2; | |||||
// TODO: support dialtion | |||||
available &= (fm.dilation[0] == 1 && fm.dilation[1] == 1); | |||||
// FIXME: too large filter size is not supported now | |||||
size_t kMaxFilterPixels = | |||||
848 / (m_algo_param.warp_k / m_algo_param.access_size) - 1; | |||||
available &= fm.spatial[0] * fm.spatial[1] <= kMaxFilterPixels; | |||||
// ci should be aligned with 4, and co should be aligned with | |||||
// algo_param.access_size | |||||
available &= ((ci % 4 == 0) && (co % m_algo_param.access_size == 0)); | |||||
available &= (get_available_op(args) != nullptr); | |||||
// only support sm_75 or later, platform should have imma int8 support | |||||
available &= is_compute_capability_required(7, 5); | |||||
return available; | |||||
} | |||||
WorkspaceBundle | |||||
ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_workspace_bundle( | |||||
dt_byte* raw_ptr, const SizeArgs& args) const { | |||||
size_t ws_filter = args.filter_layout->span().dist_byte(); | |||||
return WorkspaceBundle{raw_ptr, {ws_filter}}; | |||||
} | |||||
size_t ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm:: | |||||
get_workspace_in_bytes(const SizeArgs& args) const { | |||||
return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||||
} | |||||
void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( | |||||
const ExecArgs& args) const { | |||||
auto&& param = args.opr->param(); | |||||
auto&& fm = args.filter_meta; | |||||
size_t n = args.diff_layout->operator[](0), | |||||
co = args.diff_layout->operator[](3), | |||||
ho = args.diff_layout->operator[](1), | |||||
wo = args.diff_layout->operator[](2); | |||||
size_t ci = args.grad_layout->operator[](3), | |||||
hi = args.grad_layout->operator[](1), | |||||
wi = args.grad_layout->operator[](2); | |||||
size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||||
size_t sh = fm.stride[0], sw = fm.stride[1]; | |||||
size_t ph = fm.padding[0], pw = fm.padding[1]; | |||||
size_t dh = param.dilate_h, dw = param.dilate_w; | |||||
auto&& stream = cuda_stream(args.opr->handle()); | |||||
int8_t* filter_ptr = nullptr; | |||||
// TODO: weight preprocess | |||||
{ | |||||
filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | |||||
// reformat filter from nc4hw4 to n4hwc4 | |||||
reorder_filter(args, m_algo_param.access_size, filter_ptr); | |||||
} | |||||
float diff_scale = | |||||
args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | |||||
filter_scale = | |||||
args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | |||||
grad_scale = | |||||
args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | |||||
// \note these constants of cutlass epilogue will be passed to struct | |||||
// `ConvolutionArguments` by pointer and interpreted as ElementCompute*, | |||||
// a different dtype here results in undefined epilogue behaviors | |||||
float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, | |||||
gamma = 0.f, delta = 0.f; | |||||
using namespace cutlass::library; | |||||
const Operation* op = (const Operation*)get_available_op(args); | |||||
// gcc prints warnings when size_t values are implicitly narrowed to int | |||||
cutlass::conv::Conv2dProblemSize problem_size{ | |||||
int(n), int(hi), int(wi), int(ci), | |||||
int(co), int(fh), int(fw), int(ho), | |||||
int(wo), int(ph), int(pw), int(sh), | |||||
int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; | |||||
cutlass::library::ConvolutionArguments conv_args{ | |||||
problem_size, args.diff_tensor->compatible_ptr<int8_t>(), | |||||
filter_ptr, nullptr, | |||||
nullptr, args.grad_tensor->compatible_ptr<int8_t>(), | |||||
&alpha, &beta, | |||||
&gamma, &delta, | |||||
nullptr, nullptr, | |||||
nullptr, nullptr}; | |||||
cutlass_check(op->run(&conv_args, nullptr, stream)); | |||||
after_kernel_launch(); | |||||
} | |||||
void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter( | |||||
const ExecArgs& args, const int interleaved, | |||||
int8_t* reordered_filter) const { | |||||
auto&& fm = args.filter_meta; | |||||
size_t co = args.diff_layout->operator[](3); | |||||
size_t ci = args.grad_layout->operator[](3); | |||||
size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||||
auto&& stream = cuda_stream(args.opr->handle()); | |||||
megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx( | |||||
reordered_filter, args.filter_tensor->compatible_ptr<int8_t>(), co, | |||||
ci, fh, fw, interleaved, stream); | |||||
} | |||||
void ConvolutionBackwardDataImpl::AlgoPack::fill_int8_imma_algos() { | |||||
using AlgoParam = AlgoInt8NHWCIMMAImplicitGemm::AlgoParam; | |||||
int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 4}); | |||||
int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 8}); | |||||
int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 16}); | |||||
int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 4}); | |||||
int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 8}); | |||||
int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 16}); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -99,6 +99,7 @@ public: | |||||
class AlgoBFloat16; | class AlgoBFloat16; | ||||
class AlgoInt8NCHW4DotProdImplicitGemm; | class AlgoInt8NCHW4DotProdImplicitGemm; | ||||
class AlgoInt8NCHWDotProdImplicitGemm; | class AlgoInt8NCHWDotProdImplicitGemm; | ||||
class AlgoInt8NHWCIMMAImplicitGemm; | |||||
class AlgoPack; | class AlgoPack; | ||||
@@ -60,6 +60,7 @@ void initialize_all_gemm_tensorop884_operations(Manifest& manifest); | |||||
void initialize_all_gemm_tensorop1688_operations(Manifest& manifest); | void initialize_all_gemm_tensorop1688_operations(Manifest& manifest); | ||||
void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); | void initialize_all_conv2d_tensorop8816_operations(Manifest& manifest); | ||||
void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); | void initialize_all_conv2d_tensorop8832_operations(Manifest& manifest); | ||||
void initialize_all_deconv_tensorop8816_operations(Manifest& manifest); | |||||
#endif | #endif | ||||
void initialize_all(Manifest& manifest) { | void initialize_all(Manifest& manifest) { | ||||
@@ -71,6 +72,7 @@ void initialize_all(Manifest& manifest) { | |||||
initialize_all_gemm_tensorop1688_operations(manifest); | initialize_all_gemm_tensorop1688_operations(manifest); | ||||
initialize_all_conv2d_tensorop8816_operations(manifest); | initialize_all_conv2d_tensorop8816_operations(manifest); | ||||
initialize_all_conv2d_tensorop8832_operations(manifest); | initialize_all_conv2d_tensorop8832_operations(manifest); | ||||
initialize_all_deconv_tensorop8816_operations(manifest); | |||||
#endif | #endif | ||||
} | } | ||||
@@ -100,6 +100,9 @@ enum class LayoutTypeID { | |||||
kTensorNC64HW64, | kTensorNC64HW64, | ||||
kTensorC64RSK64, | kTensorC64RSK64, | ||||
kTensorK4RSC4, | kTensorK4RSC4, | ||||
kTensorCK4RS4, | |||||
kTensorCK8RS8, | |||||
kTensorCK16RS16, | |||||
kInvalid | kInvalid | ||||
}; | }; | ||||
@@ -225,6 +228,7 @@ enum class ThreadblockSwizzleID { | |||||
kConvolutionFpropNCxHWx, | kConvolutionFpropNCxHWx, | ||||
kConvolutionFpropTrans, | kConvolutionFpropTrans, | ||||
kConvolutionDgradNCxHWx, | kConvolutionDgradNCxHWx, | ||||
kConvolutionDgradTrans, | |||||
kInvalid | kInvalid | ||||
}; | }; | ||||
@@ -340,6 +340,21 @@ struct LayoutMap<cutlass::layout::TensorKxRSCx<4>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4; | static LayoutTypeID const kId = LayoutTypeID::kTensorK4RSC4; | ||||
}; | }; | ||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorCKxRSx<4>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorCK4RS4; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorCKxRSx<8>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorCK8RS8; | |||||
}; | |||||
template <> | |||||
struct LayoutMap<cutlass::layout::TensorCKxRSx<16>> { | |||||
static LayoutTypeID const kId = LayoutTypeID::kTensorCK16RS16; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | ///////////////////////////////////////////////////////////////////////////////////////////////// | ||||
template <typename T> | template <typename T> | ||||
@@ -556,6 +571,13 @@ struct ThreadblockSwizzleMap< | |||||
ThreadblockSwizzleID::kConvolutionDgradNCxHWx; | ThreadblockSwizzleID::kConvolutionDgradNCxHWx; | ||||
}; | }; | ||||
template <> | |||||
struct ThreadblockSwizzleMap< | |||||
conv::threadblock::ConvolutionDgradTransThreadblockSwizzle> { | |||||
static ThreadblockSwizzleID const kId = | |||||
ThreadblockSwizzleID::kConvolutionDgradTrans; | |||||
}; | |||||
///////////////////////////////////////////////////////////////////////////////////////////////// | ///////////////////////////////////////////////////////////////////////////////////////////////// | ||||
template <typename Element, typename Layout> | template <typename Element, typename Layout> | ||||
@@ -533,7 +533,10 @@ static struct { | |||||
{LayoutTypeID::kTensorC16RSK16, "c16rsk16"}, | {LayoutTypeID::kTensorC16RSK16, "c16rsk16"}, | ||||
{LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, | {LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, | ||||
{LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, | {LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, | ||||
{LayoutTypeID::kTensorK4RSC4, "k4rsC4"}, | |||||
{LayoutTypeID::kTensorK4RSC4, "k4rsc4"}, | |||||
{LayoutTypeID::kTensorCK4RS4, "ck4rs4"}, | |||||
{LayoutTypeID::kTensorCK8RS8, "ck8rs8"}, | |||||
{LayoutTypeID::kTensorCK16RS16, "ck16rs16"}, | |||||
{LayoutTypeID::kUnknown, "*"}, | {LayoutTypeID::kUnknown, "*"}, | ||||
{LayoutTypeID::kInvalid, nullptr}}; | {LayoutTypeID::kInvalid, nullptr}}; | ||||
@@ -1499,6 +1502,8 @@ static struct { | |||||
ThreadblockSwizzleID::kConvolutionFpropTrans}, | ThreadblockSwizzleID::kConvolutionFpropTrans}, | ||||
{"convolution_dgrad_ncxhwx", "ConvolutionDgradNCxHWxThreadblockSwizzle", | {"convolution_dgrad_ncxhwx", "ConvolutionDgradNCxHWxThreadblockSwizzle", | ||||
ThreadblockSwizzleID::kConvolutionDgradNCxHWx}, | ThreadblockSwizzleID::kConvolutionDgradNCxHWx}, | ||||
{"convolution_dgrad_ncxhwx", "ConvolutionDgradTransThreadblockSwizzle", | |||||
ThreadblockSwizzleID::kConvolutionDgradTrans}, | |||||
}; | }; | ||||
/// Converts a ThreadblockSwizzleID enumerant to a string | /// Converts a ThreadblockSwizzleID enumerant to a string | ||||
@@ -0,0 +1,69 @@ | |||||
/** | |||||
* \file dnn/src/cuda/memory_utils.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#if MEGDNN_CC_CUDA | |||||
#pragma once | |||||
#include "src/cuda/utils.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl( | |||||
const int src0, const int src1, const int src2, const int src3, | |||||
int& dst0, int& dst1, int& dst2, int& dst3) { | |||||
int dst01_lo = __byte_perm(src0, src1, 0x5140); | |||||
int dst01_hi = __byte_perm(src0, src1, 0x7362); | |||||
int dst23_lo = __byte_perm(src2, src3, 0x5140); | |||||
int dst23_hi = __byte_perm(src2, src3, 0x7362); | |||||
dst0 = __byte_perm(dst01_lo, dst23_lo, 0x5410); | |||||
dst1 = __byte_perm(dst01_lo, dst23_lo, 0x7632); | |||||
dst2 = __byte_perm(dst01_hi, dst23_hi, 0x5410); | |||||
dst3 = __byte_perm(dst01_hi, dst23_hi, 0x7632); | |||||
} | |||||
template <uint32_t interleaved, typename vec_type> | |||||
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4( | |||||
const int src[interleaved], vec_type (&dst)[4]); | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>( | |||||
const int src[4], int (&dst)[4]) { | |||||
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1], | |||||
dst[2], dst[3]); | |||||
} | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>( | |||||
const int src[8], int2 (&dst)[4]) { | |||||
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, | |||||
dst[2].x, dst[3].x); | |||||
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, | |||||
dst[2].y, dst[3].y); | |||||
} | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>( | |||||
const int src[16], int4 (&dst)[4]) { | |||||
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, | |||||
dst[2].x, dst[3].x); | |||||
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, | |||||
dst[2].y, dst[3].y); | |||||
transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z, | |||||
dst[1].z, dst[2].z, dst[3].z); | |||||
transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w, | |||||
dst[1].w, dst[2].w, dst[3].w); | |||||
} | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -469,7 +469,6 @@ std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() { | |||||
return args; | return args; | ||||
} | } | ||||
std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() { | std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() { | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
param::Convolution cur_param; | param::Convolution cur_param; | ||||
@@ -511,6 +510,46 @@ std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() { | |||||
return args; | return args; | ||||
} | } | ||||
std::vector<TestArg> convolution::get_args_int8_nhwc_conv_bwd_data() { | |||||
std::vector<TestArg> args; | |||||
param::Convolution cur_param; | |||||
// clang-format off | |||||
for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) { | |||||
for (size_t b : {64, 16}) { | |||||
for (size_t ic : {16, 32}) { | |||||
for (size_t oc : {16, 32}) { | |||||
for (size_t h : {8}) { | |||||
for (size_t w : {8, 11}) { | |||||
for (size_t kernel_size : {3, 4, 5, 7}) { | |||||
for (int p : {0, static_cast<int>(kernel_size / 2)}) { | |||||
for (size_t s : {2}) { | |||||
if (kernel_size >= 7) { | |||||
b = std::min(b, 32_z); | |||||
} | |||||
size_t f = kernel_size; | |||||
cur_param.mode = mode; | |||||
cur_param.format = param::Convolution::Format::NHWC; | |||||
cur_param.sparse = param::Convolution::Sparse::DENSE; | |||||
cur_param.pad_h = cur_param.pad_w = p; | |||||
cur_param.stride_h = cur_param.stride_w = s; | |||||
//! bias channel | |||||
args.emplace_back(cur_param, TensorShape{b, h, w, ic}, | |||||
TensorShape{oc, f, f, ic}); | |||||
} } } } } } } } } | |||||
// clang-format on | |||||
cur_param.pad_h = cur_param.pad_w = 1; | |||||
cur_param.stride_h = cur_param.stride_w = 1; | |||||
args.emplace_back(cur_param, TensorShape{16, 8, 11, 16}, | |||||
TensorShape{16, 3, 3, 16}); | |||||
return args; | |||||
} | |||||
void convolution::test_conv_config_combinations( | void convolution::test_conv_config_combinations( | ||||
int k_size, Handle* handle, bool test_int8, bool test_backward, | int k_size, Handle* handle, bool test_int8, bool test_backward, | ||||
bool is_cuda, ConvEPSGetter eps_getter, bool use_io16xc32) { | bool is_cuda, ConvEPSGetter eps_getter, bool use_io16xc32) { | ||||
@@ -50,6 +50,7 @@ std::vector<TestArg> get_dilated_args(); | |||||
std::vector<TestArg> get_chanwise_args(); | std::vector<TestArg> get_chanwise_args(); | ||||
std::vector<TestArg> get_args_int8_nchw4_conv_bwd_data(); | std::vector<TestArg> get_args_int8_nchw4_conv_bwd_data(); | ||||
std::vector<TestArg> get_args_int8_nchw_conv_bwd_data(); | std::vector<TestArg> get_args_int8_nchw_conv_bwd_data(); | ||||
std::vector<TestArg> get_args_int8_nhwc_conv_bwd_data(); | |||||
//! \param stage 0 for fwd, 1 for bwd data, 2 for bwd filter | //! \param stage 0 for fwd, 1 for bwd data, 2 for bwd filter | ||||
using ConvEPSGetter = | using ConvEPSGetter = | ||||
@@ -386,6 +386,69 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) { | |||||
} | } | ||||
} | } | ||||
#if CUDA_VERSION >= 10020 | |||||
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NHWC_IMMA) { | |||||
if (!cuda::is_compute_capability_required(7, 5)) { | |||||
printf("Skip CUDA.CONVOLUTION_BACKWARD_DATA_INT8_NHWC_IMMA test as " | |||||
"current device doesn't support\n"); | |||||
return; | |||||
} | |||||
using namespace convolution; | |||||
std::vector<TestArg> args = get_args_int8_nhwc_conv_bwd_data(); | |||||
struct AlgoParam { | |||||
int threadblock_m; | |||||
int threadblock_n; | |||||
int threadblock_k; | |||||
int warp_m; | |||||
int warp_n; | |||||
int warp_k; | |||||
int stage; | |||||
int access_size; | |||||
std::string to_string() { | |||||
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m, | |||||
threadblock_n, threadblock_k, warp_m, warp_n, | |||||
warp_k, stage, access_size); | |||||
} | |||||
}; | |||||
std::vector<AlgoParam> all_params; | |||||
all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 4}); | |||||
all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 8}); | |||||
all_params.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 2, 16}); | |||||
all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 4}); | |||||
all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 8}); | |||||
all_params.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 16}); | |||||
for (auto algo_param : all_params) { | |||||
Checker<ConvolutionBackwardData> checker(handle_cuda()); | |||||
std::string algo_name(ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s", | |||||
algo_param.to_string().c_str())); | |||||
checker.set_before_exec_callback( | |||||
AlgoChecker<ConvolutionBackwardData>(algo_name.c_str())); | |||||
checker.set_epsilon(1 + 1e-3).set_max_avg_error(1e-1); | |||||
for (auto&& arg : args) { | |||||
UniformIntRNG rng(-3, 3); | |||||
auto src = TensorLayout(arg.src, dtype::QuantizedS8{1.2f}); | |||||
auto filter = TensorLayout(arg.filter, dtype::QuantizedS8{1.3f}); | |||||
TensorLayout dst; | |||||
dst.dtype = dtype::QuantizedS8{1.2f}; | |||||
{ | |||||
auto opr = handle_cuda()->create_operator<Convolution>(); | |||||
opr->param() = arg.param; | |||||
opr->deduce_layout(src, filter, dst); | |||||
} | |||||
checker.set_rng(0, &rng).set_rng(1, &rng).set_param(arg.param).exec( | |||||
TensorLayoutArray{filter, dst, src}); | |||||
} | |||||
} | |||||
} | |||||
#endif | |||||
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_FAILED_CUDNN7_5) { | TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_FAILED_CUDNN7_5) { | ||||
// BRAIN-481 failed on architectures 7.0, remove the following if statement, | // BRAIN-481 failed on architectures 7.0, remove the following if statement, | ||||
// when cudnn fixed the problem. | // when cudnn fixed the problem. | ||||