From 319436dd14c375aebc10522b5f1fba82748e0cc1 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 17 May 2021 19:42:43 +0800 Subject: [PATCH] feat(dnn/cuda): add cutlass impls for uint4 x int4 conv bias GitOrigin-RevId: cf4536855ac3faf5a929b1077dac91092b2f008f --- dnn/src/cuda/conv_bias/algo.cpp | 16 +- dnn/src/cuda/conv_bias/algo.h | 51 +++++ .../cuda/conv_bias/cutlass_convolution_wrapper.cu | 130 ++++++++++- .../cuda/conv_bias/cutlass_convolution_wrapper.cuh | 11 +- .../implicit_gemm_int4_int4_nchw64_imma.cpp | 3 +- .../implicit_gemm_uint4_int4_nchw64_imma.cpp | 253 +++++++++++++++++++++ ...v_bias_int4_implicit_gemm_cutlass_wrapper.cuinl | 66 +++++- ...mma_ncdiv64hw64_128x128x128_64x64x128_hswish.cu | 2 +- ...mm_imma_ncdiv64hw64_128x128x128_64x64x128_id.cu | 2 +- ..._imma_ncdiv64hw64_128x128x128_64x64x128_relu.cu | 2 +- ...mma_ncdiv64hw64_256x128x128_64x64x128_hswish.cu | 2 +- ...mm_imma_ncdiv64hw64_256x128x128_64x64x128_id.cu | 2 +- ..._imma_ncdiv64hw64_256x128x128_64x64x128_relu.cu | 2 +- ...mma_ncdiv64hw64_128x128x128_64x64x128_hswish.cu | 36 +++ ...mm_imma_ncdiv64hw64_128x128x128_64x64x128_id.cu | 36 +++ ..._imma_ncdiv64hw64_128x128x128_64x64x128_relu.cu | 36 +++ ...mma_ncdiv64hw64_256x128x128_64x64x128_hswish.cu | 36 +++ ...mm_imma_ncdiv64hw64_256x128x128_64x64x128_id.cu | 36 +++ ..._imma_ncdiv64hw64_256x128x128_64x64x128_relu.cu | 36 +++ ...v_bias_int8_implicit_gemm_cutlass_wrapper.cuinl | 7 +- ...mm_dp4a_ncdiv4hw4_128x128x32_64x32x32_hswish.cu | 4 +- ...t_gemm_dp4a_ncdiv4hw4_128x128x32_64x32x32_id.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_128x128x32_64x32x32_relu.cu | 4 +- ...emm_dp4a_ncdiv4hw4_128x32x32_64x32x32_hswish.cu | 4 +- ...it_gemm_dp4a_ncdiv4hw4_128x32x32_64x32x32_id.cu | 4 +- ..._gemm_dp4a_ncdiv4hw4_128x32x32_64x32x32_relu.cu | 4 +- ...emm_dp4a_ncdiv4hw4_128x64x32_64x32x32_hswish.cu | 4 +- ...it_gemm_dp4a_ncdiv4hw4_128x64x32_64x32x32_id.cu | 4 +- ..._gemm_dp4a_ncdiv4hw4_128x64x32_64x32x32_relu.cu | 4 +- 
...mm_dp4a_ncdiv4hw4_16x128x16_16x128x16_hswish.cu | 4 +- ...t_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_relu.cu | 4 +- ...t_gemm_dp4a_ncdiv4hw4_16x64x8_16x64x8_hswish.cu | 4 +- ...licit_gemm_dp4a_ncdiv4hw4_16x64x8_16x64x8_id.cu | 4 +- ...cit_gemm_dp4a_ncdiv4hw4_16x64x8_16x64x8_relu.cu | 4 +- ...p4a_ncdiv4hw4_1x1_128x128x32_64x32x32_hswish.cu | 4 +- ...mm_dp4a_ncdiv4hw4_1x1_128x128x32_64x32x32_id.cu | 4 +- ..._dp4a_ncdiv4hw4_1x1_128x128x32_64x32x32_relu.cu | 4 +- ...dp4a_ncdiv4hw4_1x1_128x32x32_64x32x32_hswish.cu | 4 +- ...emm_dp4a_ncdiv4hw4_1x1_128x32x32_64x32x32_id.cu | 4 +- ...m_dp4a_ncdiv4hw4_1x1_128x32x32_64x32x32_relu.cu | 4 +- ...dp4a_ncdiv4hw4_1x1_128x64x32_64x32x32_hswish.cu | 4 +- ...emm_dp4a_ncdiv4hw4_1x1_128x64x32_64x32x32_id.cu | 4 +- ...m_dp4a_ncdiv4hw4_1x1_128x64x32_64x32x32_relu.cu | 4 +- ...p4a_ncdiv4hw4_1x1_16x128x16_16x128x16_hswish.cu | 4 +- ...mm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_id.cu | 4 +- ..._dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_relu.cu | 4 +- ...mm_dp4a_ncdiv4hw4_1x1_16x64x8_16x64x8_hswish.cu | 4 +- ...t_gemm_dp4a_ncdiv4hw4_1x1_16x64x8_16x64x8_id.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_1x1_16x64x8_16x64x8_relu.cu | 4 +- ...dp4a_ncdiv4hw4_1x1_32x128x32_32x64x32_hswish.cu | 4 +- ...emm_dp4a_ncdiv4hw4_1x1_32x128x32_32x64x32_id.cu | 4 +- ...m_dp4a_ncdiv4hw4_1x1_32x128x32_32x64x32_relu.cu | 4 +- ..._dp4a_ncdiv4hw4_1x1_32x32x32_32x32x32_hswish.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_1x1_32x32x32_32x32x32_id.cu | 4 +- ...mm_dp4a_ncdiv4hw4_1x1_32x32x32_32x32x32_relu.cu | 4 +- ..._dp4a_ncdiv4hw4_1x1_32x64x32_32x64x32_hswish.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_1x1_32x64x32_32x64x32_id.cu | 4 +- ...mm_dp4a_ncdiv4hw4_1x1_32x64x32_32x64x32_relu.cu | 4 +- ...dp4a_ncdiv4hw4_1x1_64x128x32_64x32x32_hswish.cu | 4 +- ...emm_dp4a_ncdiv4hw4_1x1_64x128x32_64x32x32_id.cu | 4 +- ...m_dp4a_ncdiv4hw4_1x1_64x128x32_64x32x32_relu.cu | 4 +- ..._dp4a_ncdiv4hw4_1x1_64x32x32_64x32x32_hswish.cu | 4 +- 
...gemm_dp4a_ncdiv4hw4_1x1_64x32x32_64x32x32_id.cu | 4 +- ...mm_dp4a_ncdiv4hw4_1x1_64x32x32_64x32x32_relu.cu | 4 +- ..._dp4a_ncdiv4hw4_1x1_64x64x32_64x32x32_hswish.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_1x1_64x64x32_64x32x32_id.cu | 4 +- ...mm_dp4a_ncdiv4hw4_1x1_64x64x32_64x32x32_relu.cu | 4 +- ...emm_dp4a_ncdiv4hw4_32x128x32_32x64x32_hswish.cu | 4 +- ...it_gemm_dp4a_ncdiv4hw4_32x128x32_32x64x32_id.cu | 4 +- ..._gemm_dp4a_ncdiv4hw4_32x128x32_32x64x32_relu.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_32x32x32_32x32x32_hswish.cu | 4 +- ...cit_gemm_dp4a_ncdiv4hw4_32x32x32_32x32x32_id.cu | 4 +- ...t_gemm_dp4a_ncdiv4hw4_32x32x32_32x32x32_relu.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_32x64x32_32x64x32_hswish.cu | 4 +- ...cit_gemm_dp4a_ncdiv4hw4_32x64x32_32x64x32_id.cu | 4 +- ...t_gemm_dp4a_ncdiv4hw4_32x64x32_32x64x32_relu.cu | 4 +- ...emm_dp4a_ncdiv4hw4_64x128x32_64x32x32_hswish.cu | 4 +- ...it_gemm_dp4a_ncdiv4hw4_64x128x32_64x32x32_id.cu | 4 +- ..._gemm_dp4a_ncdiv4hw4_64x128x32_64x32x32_relu.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_64x32x32_64x32x32_hswish.cu | 4 +- ...cit_gemm_dp4a_ncdiv4hw4_64x32x32_64x32x32_id.cu | 4 +- ...t_gemm_dp4a_ncdiv4hw4_64x32x32_64x32x32_relu.cu | 4 +- ...gemm_dp4a_ncdiv4hw4_64x64x32_64x32x32_hswish.cu | 4 +- ...cit_gemm_dp4a_ncdiv4hw4_64x64x32_64x32x32_id.cu | 4 +- ...t_gemm_dp4a_ncdiv4hw4_64x64x32_64x32x32_relu.cu | 4 +- ...v4hw4_ncdiv32hw32_128x128x32_64x32x32_hswish.cu | 4 +- ...ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_id.cu | 4 +- ...div4hw4_ncdiv32hw32_128x128x32_64x32x32_relu.cu | 4 +- ...iv4hw4_ncdiv32hw32_128x32x32_64x32x32_hswish.cu | 4 +- ..._ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_id.cu | 4 +- ...cdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_relu.cu | 4 +- ...iv4hw4_ncdiv32hw32_128x64x32_64x32x32_hswish.cu | 4 +- ..._ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_id.cu | 4 +- ...cdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_relu.cu | 4 +- ...4_ncdiv32hw32_1x1_128x128x32_64x32x32_hswish.cu | 4 +- ...v4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_id.cu | 4 +- 
...hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_relu.cu | 4 +- ...w4_ncdiv32hw32_1x1_128x32x32_64x32x32_hswish.cu | 4 +- ...iv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_id.cu | 4 +- ...4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_relu.cu | 4 +- ...w4_ncdiv32hw32_1x1_128x64x32_64x32x32_hswish.cu | 4 +- ...iv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_id.cu | 4 +- ...4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_relu.cu | 4 +- ...w4_ncdiv32hw32_1x1_32x128x32_32x64x32_hswish.cu | 4 +- ...iv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_id.cu | 4 +- ...4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_relu.cu | 4 +- ...hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_hswish.cu | 4 +- ...div4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_id.cu | 4 +- ...v4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_relu.cu | 4 +- ...hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_hswish.cu | 4 +- ...div4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_id.cu | 4 +- ...v4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_relu.cu | 4 +- ...w4_ncdiv32hw32_1x1_64x128x32_64x32x32_hswish.cu | 4 +- ...iv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_id.cu | 4 +- ...4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_relu.cu | 4 +- ...hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_hswish.cu | 4 +- ...div4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_id.cu | 4 +- ...v4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_relu.cu | 4 +- ...hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_hswish.cu | 4 +- ...div4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_id.cu | 4 +- ...v4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_relu.cu | 4 +- ...iv4hw4_ncdiv32hw32_32x128x32_32x64x32_hswish.cu | 4 +- ..._ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_id.cu | 4 +- ...cdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_relu.cu | 4 +- ...div4hw4_ncdiv32hw32_32x32x32_32x32x32_hswish.cu | 4 +- ...a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_id.cu | 4 +- ...ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_relu.cu | 4 +- ...div4hw4_ncdiv32hw32_32x64x32_32x64x32_hswish.cu | 4 +- ...a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_id.cu | 4 +- ...ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_relu.cu | 4 +- 
...iv4hw4_ncdiv32hw32_64x128x32_64x32x32_hswish.cu | 4 +- ..._ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_id.cu | 4 +- ...cdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_relu.cu | 4 +- ...div4hw4_ncdiv32hw32_64x32x32_64x32x32_hswish.cu | 4 +- ...a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_id.cu | 4 +- ...ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_relu.cu | 4 +- ...div4hw4_ncdiv32hw32_64x64x32_64x32x32_hswish.cu | 4 +- ...a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_id.cu | 4 +- ...ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_relu.cu | 4 +- ...4a_ncdiv4hw4_nchw_128x128x32_64x32x32_hswish.cu | 4 +- ...m_dp4a_ncdiv4hw4_nchw_128x128x32_64x32x32_id.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_128x128x32_64x32x32_relu.cu | 4 +- ...p4a_ncdiv4hw4_nchw_128x32x32_64x32x32_hswish.cu | 4 +- ...mm_dp4a_ncdiv4hw4_nchw_128x32x32_64x32x32_id.cu | 4 +- ..._dp4a_ncdiv4hw4_nchw_128x32x32_64x32x32_relu.cu | 4 +- ...p4a_ncdiv4hw4_nchw_128x64x32_64x32x32_hswish.cu | 4 +- ...mm_dp4a_ncdiv4hw4_nchw_128x64x32_64x32x32_id.cu | 4 +- ..._dp4a_ncdiv4hw4_nchw_128x64x32_64x32x32_relu.cu | 4 +- ...4a_ncdiv4hw4_nchw_16x128x16_16x128x16_hswish.cu | 4 +- ...m_dp4a_ncdiv4hw4_nchw_16x128x16_16x128x16_id.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_16x128x16_16x128x16_relu.cu | 4 +- ...m_dp4a_ncdiv4hw4_nchw_16x64x8_16x64x8_hswish.cu | 4 +- ..._gemm_dp4a_ncdiv4hw4_nchw_16x64x8_16x64x8_id.cu | 4 +- ...emm_dp4a_ncdiv4hw4_nchw_16x64x8_16x64x8_relu.cu | 4 +- ...cdiv4hw4_nchw_1x1_128x128x32_64x32x32_hswish.cu | 4 +- ...4a_ncdiv4hw4_nchw_1x1_128x128x32_64x32x32_id.cu | 4 +- ..._ncdiv4hw4_nchw_1x1_128x128x32_64x32x32_relu.cu | 4 +- ...ncdiv4hw4_nchw_1x1_128x32x32_64x32x32_hswish.cu | 4 +- ...p4a_ncdiv4hw4_nchw_1x1_128x32x32_64x32x32_id.cu | 4 +- ...a_ncdiv4hw4_nchw_1x1_128x32x32_64x32x32_relu.cu | 4 +- ...ncdiv4hw4_nchw_1x1_128x64x32_64x32x32_hswish.cu | 4 +- ...p4a_ncdiv4hw4_nchw_1x1_128x64x32_64x32x32_id.cu | 4 +- ...a_ncdiv4hw4_nchw_1x1_128x64x32_64x32x32_relu.cu | 4 +- ...cdiv4hw4_nchw_1x1_16x128x16_16x128x16_hswish.cu | 4 +- 
...4a_ncdiv4hw4_nchw_1x1_16x128x16_16x128x16_id.cu | 4 +- ..._ncdiv4hw4_nchw_1x1_16x128x16_16x128x16_relu.cu | 4 +- ...4a_ncdiv4hw4_nchw_1x1_16x64x8_16x64x8_hswish.cu | 4 +- ...m_dp4a_ncdiv4hw4_nchw_1x1_16x64x8_16x64x8_id.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_1x1_16x64x8_16x64x8_relu.cu | 4 +- ...ncdiv4hw4_nchw_1x1_32x128x32_32x64x32_hswish.cu | 4 +- ...p4a_ncdiv4hw4_nchw_1x1_32x128x32_32x64x32_id.cu | 4 +- ...a_ncdiv4hw4_nchw_1x1_32x128x32_32x64x32_relu.cu | 4 +- ..._ncdiv4hw4_nchw_1x1_32x32x32_32x32x32_hswish.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_1x1_32x32x32_32x32x32_id.cu | 4 +- ...4a_ncdiv4hw4_nchw_1x1_32x32x32_32x32x32_relu.cu | 4 +- ..._ncdiv4hw4_nchw_1x1_32x64x32_32x64x32_hswish.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_1x1_32x64x32_32x64x32_id.cu | 4 +- ...4a_ncdiv4hw4_nchw_1x1_32x64x32_32x64x32_relu.cu | 4 +- ...ncdiv4hw4_nchw_1x1_64x128x32_64x32x32_hswish.cu | 4 +- ...p4a_ncdiv4hw4_nchw_1x1_64x128x32_64x32x32_id.cu | 4 +- ...a_ncdiv4hw4_nchw_1x1_64x128x32_64x32x32_relu.cu | 4 +- ..._ncdiv4hw4_nchw_1x1_64x32x32_64x32x32_hswish.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_1x1_64x32x32_64x32x32_id.cu | 4 +- ...4a_ncdiv4hw4_nchw_1x1_64x32x32_64x32x32_relu.cu | 4 +- ..._ncdiv4hw4_nchw_1x1_64x64x32_64x32x32_hswish.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_1x1_64x64x32_64x32x32_id.cu | 4 +- ...4a_ncdiv4hw4_nchw_1x1_64x64x32_64x32x32_relu.cu | 4 +- ...p4a_ncdiv4hw4_nchw_32x128x32_32x64x32_hswish.cu | 4 +- ...mm_dp4a_ncdiv4hw4_nchw_32x128x32_32x64x32_id.cu | 4 +- ..._dp4a_ncdiv4hw4_nchw_32x128x32_32x64x32_relu.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_32x32x32_32x32x32_hswish.cu | 4 +- ...emm_dp4a_ncdiv4hw4_nchw_32x32x32_32x32x32_id.cu | 4 +- ...m_dp4a_ncdiv4hw4_nchw_32x32x32_32x32x32_relu.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_32x64x32_32x64x32_hswish.cu | 4 +- ...emm_dp4a_ncdiv4hw4_nchw_32x64x32_32x64x32_id.cu | 4 +- ...m_dp4a_ncdiv4hw4_nchw_32x64x32_32x64x32_relu.cu | 4 +- ...p4a_ncdiv4hw4_nchw_64x128x32_64x32x32_hswish.cu | 4 +- ...mm_dp4a_ncdiv4hw4_nchw_64x128x32_64x32x32_id.cu | 4 +- 
..._dp4a_ncdiv4hw4_nchw_64x128x32_64x32x32_relu.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_64x32x32_64x32x32_hswish.cu | 4 +- ...emm_dp4a_ncdiv4hw4_nchw_64x32x32_64x32x32_id.cu | 4 +- ...m_dp4a_ncdiv4hw4_nchw_64x32x32_64x32x32_relu.cu | 4 +- ...dp4a_ncdiv4hw4_nchw_64x64x32_64x32x32_hswish.cu | 4 +- ...emm_dp4a_ncdiv4hw4_nchw_64x64x32_64x32x32_id.cu | 4 +- ...m_dp4a_ncdiv4hw4_nchw_64x64x32_64x32x32_relu.cu | 4 +- ..._imma_ncdiv32hw32_128x128x64_64x64x64_hswish.cu | 4 +- ...gemm_imma_ncdiv32hw32_128x128x64_64x64x64_id.cu | 4 +- ...mm_imma_ncdiv32hw32_128x128x64_64x64x64_relu.cu | 4 +- ..._imma_ncdiv32hw32_128x256x64_64x64x64_hswish.cu | 4 +- ...gemm_imma_ncdiv32hw32_128x256x64_64x64x64_id.cu | 4 +- ...mm_imma_ncdiv32hw32_128x256x64_64x64x64_relu.cu | 4 +- ...m_imma_ncdiv32hw32_128x64x64_64x32x64_hswish.cu | 4 +- ..._gemm_imma_ncdiv32hw32_128x64x64_64x32x64_id.cu | 4 +- ...emm_imma_ncdiv32hw32_128x64x64_64x32x64_relu.cu | 4 +- ...a_ncdiv32hw32_1x1_128x128x64_64x64x64_hswish.cu | 4 +- ..._imma_ncdiv32hw32_1x1_128x128x64_64x64x64_id.cu | 4 +- ...mma_ncdiv32hw32_1x1_128x128x64_64x64x64_relu.cu | 4 +- ...a_ncdiv32hw32_1x1_128x256x64_64x64x64_hswish.cu | 4 +- ..._imma_ncdiv32hw32_1x1_128x256x64_64x64x64_id.cu | 4 +- ...mma_ncdiv32hw32_1x1_128x256x64_64x64x64_relu.cu | 4 +- ...ma_ncdiv32hw32_1x1_128x64x64_64x32x64_hswish.cu | 4 +- ...m_imma_ncdiv32hw32_1x1_128x64x64_64x32x64_id.cu | 4 +- ...imma_ncdiv32hw32_1x1_128x64x64_64x32x64_relu.cu | 4 +- ...a_ncdiv32hw32_1x1_256x128x64_64x64x64_hswish.cu | 4 +- ..._imma_ncdiv32hw32_1x1_256x128x64_64x64x64_id.cu | 4 +- ...mma_ncdiv32hw32_1x1_256x128x64_64x64x64_relu.cu | 4 +- ...mma_ncdiv32hw32_1x1_32x64x64_32x16x64_hswish.cu | 4 +- ...mm_imma_ncdiv32hw32_1x1_32x64x64_32x16x64_id.cu | 4 +- ..._imma_ncdiv32hw32_1x1_32x64x64_32x16x64_relu.cu | 4 +- ...ma_ncdiv32hw32_1x1_64x128x64_32x64x64_hswish.cu | 4 +- ...m_imma_ncdiv32hw32_1x1_64x128x64_32x64x64_id.cu | 4 +- ...imma_ncdiv32hw32_1x1_64x128x64_32x64x64_relu.cu | 4 +- 
...mma_ncdiv32hw32_1x1_64x64x64_32x32x64_hswish.cu | 4 +- ...mm_imma_ncdiv32hw32_1x1_64x64x64_32x32x64_id.cu | 4 +- ..._imma_ncdiv32hw32_1x1_64x64x64_32x32x64_relu.cu | 4 +- ..._imma_ncdiv32hw32_256x128x64_64x64x64_hswish.cu | 4 +- ...gemm_imma_ncdiv32hw32_256x128x64_64x64x64_id.cu | 4 +- ...mm_imma_ncdiv32hw32_256x128x64_64x64x64_relu.cu | 4 +- ...mm_imma_ncdiv32hw32_32x64x64_32x16x64_hswish.cu | 4 +- ...t_gemm_imma_ncdiv32hw32_32x64x64_32x16x64_id.cu | 4 +- ...gemm_imma_ncdiv32hw32_32x64x64_32x16x64_relu.cu | 4 +- ...m_imma_ncdiv32hw32_64x128x64_32x64x64_hswish.cu | 4 +- ..._gemm_imma_ncdiv32hw32_64x128x64_32x64x64_id.cu | 4 +- ...emm_imma_ncdiv32hw32_64x128x64_32x64x64_relu.cu | 4 +- ...mm_imma_ncdiv32hw32_64x64x64_32x32x64_hswish.cu | 4 +- ...t_gemm_imma_ncdiv32hw32_64x64x64_32x32x64_id.cu | 4 +- ...gemm_imma_ncdiv32hw32_64x64x64_32x32x64_relu.cu | 4 +- ...v32hw32_ncdiv4hw4_128x128x64_64x64x64_hswish.cu | 4 +- ...ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_id.cu | 4 +- ...div32hw32_ncdiv4hw4_128x128x64_64x64x64_relu.cu | 4 +- ...v32hw32_ncdiv4hw4_128x256x64_64x64x64_hswish.cu | 4 +- ...ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_id.cu | 4 +- ...div32hw32_ncdiv4hw4_128x256x64_64x64x64_relu.cu | 4 +- ...iv32hw32_ncdiv4hw4_128x64x64_64x32x64_hswish.cu | 4 +- ..._ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_id.cu | 4 +- ...cdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_relu.cu | 4 +- ...w32_ncdiv4hw4_1x1_128x128x64_64x64x64_hswish.cu | 4 +- ...v32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_id.cu | 4 +- ...2hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_relu.cu | 4 +- ...w32_ncdiv4hw4_1x1_128x256x64_64x64x64_hswish.cu | 4 +- ...v32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_id.cu | 4 +- ...2hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_relu.cu | 4 +- ...hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_hswish.cu | 4 +- ...iv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_id.cu | 4 +- ...32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_relu.cu | 4 +- ...w32_ncdiv4hw4_1x1_256x128x64_64x64x64_hswish.cu | 4 +- 
...v32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_id.cu | 4 +- ...2hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_relu.cu | 4 +- ...2hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_hswish.cu | 4 +- ...div32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_id.cu | 4 +- ...v32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_relu.cu | 4 +- ...hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_hswish.cu | 4 +- ...iv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_id.cu | 4 +- ...32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_relu.cu | 4 +- ...2hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_hswish.cu | 4 +- ...div32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_id.cu | 4 +- ...v32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_relu.cu | 4 +- ...v32hw32_ncdiv4hw4_256x128x64_64x64x64_hswish.cu | 4 +- ...ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_id.cu | 4 +- ...div32hw32_ncdiv4hw4_256x128x64_64x64x64_relu.cu | 4 +- ...div32hw32_ncdiv4hw4_32x64x64_16x32x64_hswish.cu | 4 +- ...a_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_id.cu | 4 +- ...ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_relu.cu | 4 +- ...iv32hw32_ncdiv4hw4_64x128x64_32x64x64_hswish.cu | 4 +- ..._ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_id.cu | 4 +- ...cdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_relu.cu | 4 +- ...div32hw32_ncdiv4hw4_64x64x64_32x32x64_hswish.cu | 4 +- ...a_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_id.cu | 4 +- ...ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_relu.cu | 4 +- dnn/src/cuda/conv_bias/opr_impl.h | 1 + dnn/test/cuda/conv_bias_int8.cpp | 2 +- 292 files changed, 1292 insertions(+), 556 deletions(-) create mode 100644 dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp mode change 120000 => 100644 dnn/src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl create mode 100644 dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_hswish.cu create mode 100644 dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_id.cu create mode 100644 
dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_relu.cu create mode 100644 dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_hswish.cu create mode 100644 dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_id.cu create mode 100644 dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_relu.cu diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index ac7a55d8..a1baa9a3 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -87,6 +87,9 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { for (auto&& algo : int4_int4_nchw64_imma) { all_algos.push_back(&algo); } + for (auto&& algo : uint4_int4_nchw64_imma) { + all_algos.push_back(&algo); + } #endif #endif fill_dp4a_algos(); @@ -231,8 +234,17 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { { using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; - int4_int4_nchw64_imma.emplace_back(AlgoParam{128, 128, 128, 64, 64, 128}); - int4_int4_nchw64_imma.emplace_back(AlgoParam{256, 128, 128, 64, 64, 128}); + int4_int4_nchw64_imma.emplace_back( + AlgoParam{128, 128, 128, 64, 64, 128}); + int4_int4_nchw64_imma.emplace_back( + AlgoParam{256, 128, 128, 64, 64, 128}); + } + { + using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; + uint4_int4_nchw64_imma.emplace_back( + AlgoParam{128, 128, 128, 64, 64, 128}); + uint4_int4_nchw64_imma.emplace_back( + AlgoParam{256, 128, 128, 64, 64, 128}); } #endif } diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index 00d1d92f..aa74e622 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -62,6 +62,7 @@ public: CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8, CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8, CUDA_IMPLICIT_GEMM_IMMA_NCHW64_INT4_INT4, + 
CUDA_IMPLICIT_GEMM_IMMA_NCHW64_UINT4_INT4, CUDA_BFLOAT16, CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8, CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8, @@ -810,6 +811,55 @@ private: AlgoParam m_algo_param; std::string m_name; }; + +class ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm final + : public AlgoBase { +public: + struct AlgoParam { + int threadblock_m; + int threadblock_n; + int threadblock_k; + int warp_m; + int warp_n; + int warp_k; + }; + AlgoUInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) + : m_algo_param{algo_param} { + m_name = ConvBias::algo_name( + ssprintf("UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s", + to_string(m_algo_param).c_str()), + ConvBias::DirectParam{}); + } + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + const char* name() const override { return m_name.c_str(); } + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE; + } + static std::string to_string(AlgoParam algo_param); + size_t get_preprocess_workspace_in_bytes( + const SizeArgs& args) const override; + SmallVector deduce_preprocessed_filter_layout( + const SizeArgs& args) const override; + void exec_preprocess(const ExecArgs& args) const override; + MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW64_UINT4_INT4) + + std::string param() const override { + std::string ret; + serialize_write_pod(m_algo_param, ret); + return ret; + } + +private: + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, + const SizeArgs& args) const; + void reorder_filter_bias(const ExecArgs& args, void* reduce_filter, + void* reordered_filter, + void* reordered_bias) const; + AlgoParam m_algo_param; + std::string m_name; +}; #endif class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase { @@ -868,6 +918,7 @@ public: #if CUDA_VERSION >= 10020 std::vector int8_nchw32_imma; std::vector int4_int4_nchw64_imma; + std::vector 
uint4_int4_nchw64_imma; #endif std::vector> gconv_refhold; AlgoBFloat16 bfloat16; diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu index 82884e07..cb77f617 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu +++ b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu @@ -662,7 +662,7 @@ INST(true); INST(false); #undef INST -/* ====== cutlass kernel wrapper for int4 nchw64 layout ====== */ +/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ #if MEGDNN_TEGRA_X1 template @@ -783,4 +783,132 @@ void megdnn::cuda::cutlass_wrapper:: INST(true); #undef INST +/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */ + +#if MEGDNN_TEGRA_X1 +template +void megdnn::cuda::cutlass_wrapper:: + do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( + const uint8_t* /* d_src */, const int8_t* /* d_filter */, + const int32_t* /* d_bias */, const uint8_t* /* d_z */, + uint8_t* /* d_dst */, int* /* workspace */, + const convolution::ConvParam& /* param */, + uint32_t /* nonlinear_mode */, float /* alpha */, + float /* beta */, float /* gamma */, float /* delta */, + float /* theta */, float /* scale */, + uint8_t /* src_zero_point */, + const GemmCoord& /* threadblock_shape */, + const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} +#else +template +void megdnn::cuda::cutlass_wrapper:: + do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( + const uint8_t* d_src, const int8_t* d_filter, + const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, + int* workspace, const convolution::ConvParam& param, + uint32_t nonlinear_mode, float alpha, float beta, float gamma, + float delta, float theta, float scale, uint8_t src_zero_point, + const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, + cudaStream_t stream) { +#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ + threadblock_k_, warp_m_, warp_n_, \ + warp_k_) \ + 
if (threadblock_shape.m() == threadblock_m_ && \ + threadblock_shape.n() == threadblock_n_ && \ + threadblock_shape.k() == threadblock_k_ && \ + warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ + warp_shape.k() == warp_k_) { \ + using ThreadBlockShape = \ + cutlass::gemm::GemmShape; \ + using WarpShape = cutlass::gemm::GemmShape; \ + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ + using Convolution = cutlass::conv::device::Convolution< \ + cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ + cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ + ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ + cutlass::layout::TensorNCxHWx<64>, int32_t, \ + cutlass::conv::ConvType::kConvolution, \ + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ + cutlass::conv::threadblock:: \ + ConvolutionFpropNCxHWxThreadblockSwizzle, \ + 2, 32, 32, NeedLoadFromConstMem>; \ + typename Convolution::ConvolutionParameter conv_param( \ + param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ + param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ + param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ + return cutlass_convolution_wrapper( \ + reinterpret_cast(d_src), \ + reinterpret_cast(d_filter), d_bias, \ + reinterpret_cast(d_z), \ + reinterpret_cast(d_dst), workspace, \ + conv_param, epilogue, stream, {src_zero_point}); \ + } +#define DISPATCH_KERNEL \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ + megdnn_assert(false, \ + "unsupported threadblock shape (%dx%dx%d) and warp shape " \ + "(%dx%dx%d)", \ + threadblock_shape.m(), threadblock_shape.n(), \ + threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ + warp_shape.k()); + using ElementOutput = cutlass::uint4b_t; + using ElementAccumulator = int32_t; + using ElementBias = int32_t; + using ElementCompute = float; + using 
NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; + switch (nonlinear_mode) { + case NonlineMode::IDENTITY: { + using EpilogueOp = + cutlass::epilogue::thread::BiasAddLinearCombinationClamp< + ElementOutput, 16, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, + delta + theta}; + DISPATCH_KERNEL; + } + case NonlineMode::RELU: { + using EpilogueOp = cutlass::epilogue::thread:: + BiasAddLinearCombinationReluClamp< + ElementOutput, 16, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, + 0, delta, theta}; + DISPATCH_KERNEL; + } + case NonlineMode::H_SWISH: { + using EpilogueOp = cutlass::epilogue::thread:: + BiasAddLinearCombinationHSwishClamp< + ElementOutput, 16, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, + scale, delta, theta}; + DISPATCH_KERNEL; + } + default: + megdnn_assert(false, + "unsupported nonlinear mode for conv bias operator"); + } +#undef DISPATCH_KERNEL_WITH_TILE_SHAPE +#undef DISPATCH_KERNEL +} +#endif + +#define INST(need_load_from_const_mem) \ + template void megdnn::cuda::cutlass_wrapper:: \ + do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ + need_load_from_const_mem>( \ + const uint8_t* d_src, const int8_t* d_filter, \ + const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ + int* workspace, const convolution::ConvParam& param, \ + uint32_t nonlinear_mode, float alpha, float beta, \ + float gamma, float delta, float theta, float scale, \ + uint8_t src_zero_point, \ + const GemmCoord& threadblock_shape, \ + const GemmCoord& warp_shape, cudaStream_t stream); +INST(true); +#undef INST + // vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh index 3c9a3484..c97f2bc7 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh +++ 
b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh @@ -29,7 +29,7 @@ void cutlass_convolution_wrapper( typename Convolution::ElementDst* d_dst, int* workspace, typename Convolution::ConvolutionParameter const& conv_param, typename Convolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream); + cudaStream_t stream, typename Convolution::ExtraParam extra_param = {}); template void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( @@ -85,6 +85,15 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64( const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, cudaStream_t stream); +template +void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( + const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, + const uint8_t* d_z, uint8_t* d_dst, int* workspace, + const convolution::ConvParam& param, uint32_t nonlinear_mode, + float alpha, float beta, float gamma, float delta, float theta, + float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, + const GemmCoord& warp_shape, cudaStream_t stream); + } // namespace cutlass_wrapper } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp index 4a8a4cc6..50ead151 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma.cpp + * \file dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -77,7 +77,6 @@ ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_workspace_in_bytes( void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec( const ExecArgs& args) const { - using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; size_t n = args.src_layout->operator[](0), diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp new file mode 100644 index 00000000..6ca37306 --- /dev/null +++ b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp @@ -0,0 +1,253 @@ +/** + * \file dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
+ */ + +#include "./algo.h" +#include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" +#include "src/cuda/conv_bias/reduce_filter.cuh" +#include "src/cuda/convolution_helper/parameter.cuh" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace cuda; +using namespace convolution; + +#if CUDA_VERSION >= 10020 +bool ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::is_available( + const SizeArgs& args) const { + if (args.bias_layout->ndim <= 0) + return false; + + using Param = param::ConvBias; + using Format = Param::Format; + using Sparse = Param::Sparse; + using Mode = Param::Mode; + using NonlineMode = megdnn::param::ConvBias::NonlineMode; + + auto&& param = args.opr->param(); + + if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) + return false; + + if (param.format != Format::NCHW64 || param.sparse != Sparse::DENSE || + param.mode != Mode::CROSS_CORRELATION) + return false; + + if (param.nonlineMode != NonlineMode::IDENTITY && + param.nonlineMode != NonlineMode::RELU && + param.nonlineMode != NonlineMode::H_SWISH) + return false; + + if (args.src_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm || + args.filter_layout->dtype.enumv() != DTypeEnum::QuantizedS4 || + args.bias_layout->dtype.enumv() != DTypeEnum::QuantizedS32 || + args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm) + return false; + + if (!is_compute_capability_required(7, 5)) + return false; + + return true; +} + +WorkspaceBundle +ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_workspace_bundle( + dt_byte* raw_ptr, const SizeArgs& args) const { + if (args.preprocessed_filter) { + return WorkspaceBundle{raw_ptr, {}}; + } else { + size_t ws_filter = args.filter_layout->span().dist_byte(), + ws_bias = args.bias_layout->span().dist_byte(), + ws_reduce_filter = get_preprocess_workspace_in_bytes(args); + return WorkspaceBundle{raw_ptr, + {ws_filter + ws_bias + ws_reduce_filter}}; + } +} + +size_t 
ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm:: + get_workspace_in_bytes(const SizeArgs& args) const { + return get_workspace_bundle(nullptr, args).total_size_in_bytes(); +} + +void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::exec( + const ExecArgs& args) const { + auto&& param = args.opr->param(); + auto&& fm = args.filter_meta; + size_t n = args.src_layout->operator[](0), + ci = args.src_layout->operator[](1) * 64, + hi = args.src_layout->operator[](2), + wi = args.src_layout->operator[](3); + size_t co = args.dst_layout->operator[](1) * 64, + ho = args.dst_layout->operator[](2), + wo = args.dst_layout->operator[](3); + UNPACK_CONV_PARAMETER(fm, param); + MARK_USED_VAR + auto&& stream = cuda_stream(args.opr->handle()); + + void* filter_ptr = nullptr; + void* bias_ptr = nullptr; + if (args.preprocessed_filter) { + megdnn_assert(args.preprocessed_filter->tensors.size() == 2); + filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; + bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; + } else { + // reorder filter and bias + filter_ptr = reinterpret_cast(args.workspace.raw_ptr); + bias_ptr = + reinterpret_cast(args.workspace.raw_ptr + + args.filter_layout->span().dist_byte()); + void* reduce_filter_ptr = + reinterpret_cast(args.workspace.raw_ptr + + args.filter_layout->span().dist_byte() + + args.bias_layout->span().dist_byte()); + reorder_filter_bias(args, reduce_filter_ptr, filter_ptr, bias_ptr); + } + + ConvParam kern_param; + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, + kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, + kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, + kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, + kern_param.fw = fw; + + float src_scale = + args.src_layout->dtype.param().scale, + filter_scale = + args.filter_layout->dtype.param().scale, + bias_scale = + args.bias_layout->dtype.param().scale, + dst_scale = + args.dst_layout->dtype.param().scale; + + uint8_t 
src_zero = args.src_layout->dtype.param() + .zero_point, + dst_zero = args.dst_layout->dtype.param() + .zero_point; + + float alpha = src_scale * filter_scale / dst_scale; + float beta = bias_scale / dst_scale; + float gamma = 0.f; + float delta = 0.f; + float theta = dst_zero; + + uint8_t* z_dev_ptr = nullptr; + if (args.z_layout->ndim > 0) { + z_dev_ptr = reinterpret_cast(args.z_tensor->raw_ptr); + float z_scale = + args.z_layout->dtype.param().scale; + uint8_t z_zero = + args.z_layout->dtype.param().zero_point; + gamma = z_scale / dst_scale; + delta = -z_zero * gamma; + } + + uint32_t nonlinear_mode = static_cast(param.nonlineMode); + + cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< + true>( + reinterpret_cast(args.src_tensor->raw_ptr), + reinterpret_cast(filter_ptr), + reinterpret_cast(bias_ptr), z_dev_ptr, + reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, + kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, + dst_scale, src_zero, + cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k}, + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n, + m_algo_param.warp_k}, + stream); +} + +std::string ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::to_string( + AlgoParam algo_param) { + return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m, + algo_param.threadblock_n, algo_param.threadblock_k, + algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); +} + +size_t ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm:: + get_preprocess_workspace_in_bytes(const SizeArgs& args) const { + size_t co = args.filter_layout->operator[](0), + ci = args.filter_layout->operator[](1) * 64, + fh = args.filter_layout->operator[](2), + fw = args.filter_layout->operator[](3); + size_t ws_size_reduce_filter = co * sizeof(int32_t); + size_t A = co, B = ci * fh * fw / 8, C = 1; + ws_size_reduce_filter += do_dispatch_reduce_workspace_in_bytes(A, B, 
C); + return ws_size_reduce_filter; +} + +SmallVector ConvBiasForwardImpl:: + AlgoUInt4Int4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout( + const SizeArgs& args) const { + return {args.filter_layout->collapse_contiguous(), + args.bias_layout->collapse_contiguous()}; +} + +void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( + const ExecArgs& args) const { + megdnn_assert(args.preprocessed_filter->tensors.size() == 2); + reorder_filter_bias(args, args.workspace.raw_ptr, + args.preprocessed_filter->tensors[0].raw_ptr, + args.preprocessed_filter->tensors[1].raw_ptr); +} + +void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm:: + reorder_filter_bias(const ExecArgs& args, void* reduce_filter, + void* reordered_filter, + void* reordered_bias) const { + auto&& param = args.opr->param(); + auto&& fm = args.filter_meta; + size_t n = args.src_layout->operator[](0), + ci = args.src_layout->operator[](1) * 64, + hi = args.src_layout->operator[](2), + wi = args.src_layout->operator[](3); + size_t co = args.dst_layout->operator[](1) * 64, + ho = args.dst_layout->operator[](2), + wo = args.dst_layout->operator[](3); + UNPACK_CONV_PARAMETER(fm, param); + MARK_USED_VAR; + auto&& stream = cuda_stream(args.opr->handle()); + + // filter: KCRS64 => CRSK64 + TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()}; + src.init_contiguous_stride(); + TensorLayout dst = src; + dst.stride[0] = 64; + dst.stride[1] = co * fh * fw * 64; + dst.stride[2] = co * fw * 64; + dst.stride[3] = co * 64; + dst.stride[4] = 1; + TensorND ts_src, ts_dst; + ts_src.raw_ptr = args.filter_tensor->raw_ptr; + ts_src.layout = src; + ts_dst.raw_ptr = reordered_filter; + ts_dst.layout = dst; + auto&& transpose = args.opr->handle()->create_operator(); + transpose->exec(ts_src, ts_dst); + + // reduce filter and update bias + int32_t* workspace = reinterpret_cast(reordered_bias) + + args.bias_layout->span().dist_byte(); + int src_zero_point = + 
args.src_tensor->layout.dtype.param() + .zero_point; + do_dispatch_reduce_filter_and_update_bias_4bit( + reinterpret_cast(args.filter_tensor->raw_ptr), + args.bias_tensor->compatible_ptr(), co, ci * fh * fw / 8, + reinterpret_cast(reordered_bias), workspace, + src_zero_point, stream); +} +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl b/dnn/src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl deleted file mode 120000 index e1100ca5..00000000 --- a/dnn/src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl +++ /dev/null @@ -1 +0,0 @@ -../int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl \ No newline at end of file diff --git a/dnn/src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl b/dnn/src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl new file mode 100644 index 00000000..53da89de --- /dev/null +++ b/dnn/src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl @@ -0,0 +1,65 @@ +/** + * \file + * dnn/src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
+ */ +#include "cutlass/convolution/device/convolution.h" +#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" + +using namespace megdnn; +using namespace cuda; +using namespace cutlass_wrapper; + +template +void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( + const typename Convolution::ElementSrc* d_src, + const typename Convolution::ElementFilter* d_filter, + const typename Convolution::ElementBias* d_bias, + const typename Convolution::ElementDst* d_z, + typename Convolution::ElementDst* d_dst, int* workspace, + typename Convolution::ConvolutionParameter const& conv_param, + typename Convolution::EpilogueOutputOp::Params const& epilogue, + cudaStream_t stream, typename Convolution::ExtraParam extra_param) { + typename Convolution::TensorRefSrc tensor_src{ + const_cast(d_src), + Convolution::LayoutSrc::packed( + {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; + typename Convolution::TensorRefFilter tensor_filter{ + const_cast(d_filter), + Convolution::LayoutFilter::packed( + {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; + typename Convolution::TensorRefBias tensor_bias{ + const_cast(d_bias), + Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; + typename Convolution::TensorRefDst tensor_z{ + const_cast(d_z), + Convolution::LayoutDst::packed( + {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; + typename Convolution::TensorRefDst tensor_dst{ + d_dst, + Convolution::LayoutDst::packed( + {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; + typename Convolution::Arguments arguments{conv_param, + tensor_src.non_const_ref(), + tensor_filter.non_const_ref(), + tensor_bias.non_const_ref(), + tensor_z.non_const_ref(), + tensor_dst.non_const_ref(), + epilogue, + {}, + {}, + extra_param}; + Convolution conv_op; + cutlass_check(conv_op.initialize(arguments, workspace)); + cutlass_check(conv_op(stream)); + after_kernel_launch(); +} + +// vim: syntax=cuda.doxygen diff --git 
a/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_hswish.cu b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_hswish.cu index 9ab481d0..48e4a9d1 100644 --- a/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_hswish.cu +++ b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_hswish.cu @@ -31,6 +31,6 @@ template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper; +using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; +using LayoutDst = cutlass::layout::TensorNCxHWx<64>; +using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; +using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; +using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< + cutlass::uint4b_t, 16, int32_t, int32_t, float>; +using Convolution = cutlass::conv::device::Convolution< + cutlass::uint4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::uint4b_t, + LayoutDst, int32_t, LayoutDst, int32_t, + cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, + cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, + 2, 32, 32, true, + cutlass::arch::OpMultiplyAddSaturate>; +template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( + const typename Convolution::ElementSrc* d_src, + const typename Convolution::ElementFilter* d_filter, + const typename Convolution::ElementBias* d_bias, + const typename Convolution::ElementDst* d_z, + typename Convolution::ElementDst* d_dst, + int* workspace, + typename Convolution::ConvolutionParameter const& conv_param, + typename Convolution::EpilogueOutputOp::Params const& epilogue, + cudaStream_t 
stream, typename Convolution::ExtraParam extra_param); +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_id.cu b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_id.cu new file mode 100644 index 00000000..53b7468b --- /dev/null +++ b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_id.cu @@ -0,0 +1,36 @@ +#if !MEGDNN_TEGRA_X1 +// generated by gen_cuda_conv_bias_int4_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" + +using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; +using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; +using LayoutDst = cutlass::layout::TensorNCxHWx<64>; +using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; +using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; +using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< + cutlass::uint4b_t, 16, int32_t, int32_t, float>; +using Convolution = cutlass::conv::device::Convolution< + cutlass::uint4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::uint4b_t, + LayoutDst, int32_t, LayoutDst, int32_t, + cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, + cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, + 2, 32, 32, true, + cutlass::arch::OpMultiplyAddSaturate>; +template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( + const typename Convolution::ElementSrc* d_src, + const typename Convolution::ElementFilter* d_filter, + const 
typename Convolution::ElementBias* d_bias, + const typename Convolution::ElementDst* d_z, + typename Convolution::ElementDst* d_dst, + int* workspace, + typename Convolution::ConvolutionParameter const& conv_param, + typename Convolution::EpilogueOutputOp::Params const& epilogue, + cudaStream_t stream, typename Convolution::ExtraParam extra_param); +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_relu.cu b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_relu.cu new file mode 100644 index 00000000..84bcdacf --- /dev/null +++ b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_128x128x128_64x64x128_relu.cu @@ -0,0 +1,36 @@ +#if !MEGDNN_TEGRA_X1 +// generated by gen_cuda_conv_bias_int4_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" + +using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; +using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; +using LayoutDst = cutlass::layout::TensorNCxHWx<64>; +using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 128>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; +using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; +using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< + cutlass::uint4b_t, 16, int32_t, int32_t, float>; +using Convolution = cutlass::conv::device::Convolution< + cutlass::uint4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::uint4b_t, + LayoutDst, int32_t, LayoutDst, int32_t, + cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, + 
cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, + 2, 32, 32, true, + cutlass::arch::OpMultiplyAddSaturate>; +template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( + const typename Convolution::ElementSrc* d_src, + const typename Convolution::ElementFilter* d_filter, + const typename Convolution::ElementBias* d_bias, + const typename Convolution::ElementDst* d_z, + typename Convolution::ElementDst* d_dst, + int* workspace, + typename Convolution::ConvolutionParameter const& conv_param, + typename Convolution::EpilogueOutputOp::Params const& epilogue, + cudaStream_t stream, typename Convolution::ExtraParam extra_param); +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_hswish.cu b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_hswish.cu new file mode 100644 index 00000000..52d2af3f --- /dev/null +++ b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_hswish.cu @@ -0,0 +1,36 @@ +#if !MEGDNN_TEGRA_X1 +// generated by gen_cuda_conv_bias_int4_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" + +using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; +using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; +using LayoutDst = cutlass::layout::TensorNCxHWx<64>; +using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; +using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; +using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp< + cutlass::uint4b_t, 16, int32_t, int32_t, float>; +using Convolution = 
cutlass::conv::device::Convolution< + cutlass::uint4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::uint4b_t, + LayoutDst, int32_t, LayoutDst, int32_t, + cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, + cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, + 2, 32, 32, true, + cutlass::arch::OpMultiplyAddSaturate>; +template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( + const typename Convolution::ElementSrc* d_src, + const typename Convolution::ElementFilter* d_filter, + const typename Convolution::ElementBias* d_bias, + const typename Convolution::ElementDst* d_z, + typename Convolution::ElementDst* d_dst, + int* workspace, + typename Convolution::ConvolutionParameter const& conv_param, + typename Convolution::EpilogueOutputOp::Params const& epilogue, + cudaStream_t stream, typename Convolution::ExtraParam extra_param); +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_id.cu b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_id.cu new file mode 100644 index 00000000..e60c5c2b --- /dev/null +++ b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_id.cu @@ -0,0 +1,36 @@ +#if !MEGDNN_TEGRA_X1 +// generated by gen_cuda_conv_bias_int4_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" + +using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; +using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; +using LayoutDst = cutlass::layout::TensorNCxHWx<64>; +using ThreadBlockShape = 
cutlass::gemm::GemmShape<256, 128, 128>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; +using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; +using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp< + cutlass::uint4b_t, 16, int32_t, int32_t, float>; +using Convolution = cutlass::conv::device::Convolution< + cutlass::uint4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::uint4b_t, + LayoutDst, int32_t, LayoutDst, int32_t, + cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, + cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, + 2, 32, 32, true, + cutlass::arch::OpMultiplyAddSaturate>; +template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( + const typename Convolution::ElementSrc* d_src, + const typename Convolution::ElementFilter* d_filter, + const typename Convolution::ElementBias* d_bias, + const typename Convolution::ElementDst* d_z, + typename Convolution::ElementDst* d_dst, + int* workspace, + typename Convolution::ConvolutionParameter const& conv_param, + typename Convolution::EpilogueOutputOp::Params const& epilogue, + cudaStream_t stream, typename Convolution::ExtraParam extra_param); +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_relu.cu b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_relu.cu new file mode 100644 index 00000000..b8fb14c6 --- /dev/null +++ b/dnn/src/cuda/conv_bias/int4/kimpl/conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64_256x128x128_64x64x128_relu.cu @@ -0,0 +1,36 @@ +#if !MEGDNN_TEGRA_X1 +// generated by gen_cuda_conv_bias_int4_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored 
"-Wstrict-aliasing" +#include "src/cuda/conv_bias/int4/conv_bias_int4_implicit_gemm_cutlass_wrapper.cuinl" + +using LayoutSrc = cutlass::layout::TensorNCxHWx<64>; +using LayoutFilter = cutlass::layout::TensorCxRSKx<64>; +using LayoutDst = cutlass::layout::TensorNCxHWx<64>; +using ThreadBlockShape = cutlass::gemm::GemmShape<256, 128, 128>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; +using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; +using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp< + cutlass::uint4b_t, 16, int32_t, int32_t, float>; +using Convolution = cutlass::conv::device::Convolution< + cutlass::uint4b_t, LayoutSrc, cutlass::int4b_t, LayoutFilter, cutlass::uint4b_t, + LayoutDst, int32_t, LayoutDst, int32_t, + cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, + cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle, + 2, 32, 32, true, + cutlass::arch::OpMultiplyAddSaturate>; +template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( + const typename Convolution::ElementSrc* d_src, + const typename Convolution::ElementFilter* d_filter, + const typename Convolution::ElementBias* d_bias, + const typename Convolution::ElementDst* d_z, + typename Convolution::ElementDst* d_dst, + int* workspace, + typename Convolution::ConvolutionParameter const& conv_param, + typename Convolution::EpilogueOutputOp::Params const& epilogue, + cudaStream_t stream, typename Convolution::ExtraParam extra_param); +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl b/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl index cf20f616..9f09ce41 100644 --- a/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl +++ 
b/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl @@ -26,7 +26,7 @@ void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( typename Convolution::ElementDst* d_dst, int* workspace, typename Convolution::ConvolutionParameter const& conv_param, typename Convolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream) { + cudaStream_t stream, typename Convolution::ExtraParam extra_param) { typename Convolution::TensorRefSrc tensor_src{ const_cast(d_src), Convolution::LayoutSrc::packed( @@ -52,7 +52,10 @@ void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( tensor_bias.non_const_ref(), tensor_z.non_const_ref(), tensor_dst.non_const_ref(), - epilogue}; + epilogue, + {}, + {}, + extra_param}; Convolution conv_op; cutlass_check(conv_op.initialize(arguments, workspace)); cutlass_check(conv_op(stream)); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_128x128x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_128x128x32_64x32x32_hswish.cu index aede9980..0e75dbb0 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_128x128x32_64x32x32_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_128x128x32_64x32x32_hswish.cu @@ -1,5 +1,5 @@ #if !MEGDNN_TEGRA_X1 -// generated by gen_cuda_conv_bias_kern_impls.py +// generated by gen_cuda_conv_bias_int8_kern_impls.py // ignore warning of cutlass #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" @@ -31,6 +31,6 @@ template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper> checker( handle_cuda());