Browse Source

refactor(dnn/cuda): refactor cutlass kernel generator for deconv operation

GitOrigin-RevId: 88e962a912
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
b33217d8f0
11 changed files with 281 additions and 181 deletions
  1. +1
    -1
      dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl
  2. +56
    -0
      dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x128x16_1_nc4hw4.cu
  3. +56
    -0
      dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x64x16_2_nc4hw4.cu
  4. +56
    -0
      dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_16x64x8_16x64x8_2_nc4hw4.cu
  5. +56
    -0
      dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_32x128x32_32x64x32_2_nc4hw4.cu
  6. +56
    -0
      dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_64x128x32_64x32x32_2_nc4hw4.cu
  7. +0
    -36
      dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu
  8. +0
    -36
      dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x64x16_id.cu
  9. +0
    -36
      dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_16x64x8_16x64x8_id.cu
  10. +0
    -36
      dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_32x128x32_32x64x32_id.cu
  11. +0
    -36
      dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_64x128x32_64x32x32_id.cu

dnn/src/cuda/convolution/backward_data/int8/deconv_int8_implicit_gemm_cutlass_wrapper.cuinl → dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl
* dnn/src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

+ 56
- 0
dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x128x16_1_nc4hw4.cu View File

@@ -0,0 +1,56 @@

#if !MEGDNN_TEGRA_X1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl"


// kernel instance "cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x128x16_1_nc4hw4" generated by cutlass generator
using Deconvolution =
typename cutlass::conv::device::Deconvolution<
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int8_t,
cutlass::layout::TensorKxRSCx<4>,
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm61,
cutlass::gemm::GemmShape<16, 128, 16>,
cutlass::gemm::GemmShape<16, 128, 16>,
cutlass::gemm::GemmShape<1, 1, 4>,
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t,
4,
int32_t,
int32_t,
float
>,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
1,
4,
8,
true,
cutlass::arch::OpMultiplyAdd>;



template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);


#pragma GCC diagnostic pop
#endif

+ 56
- 0
dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x64x16_2_nc4hw4.cu View File

@@ -0,0 +1,56 @@

#if !MEGDNN_TEGRA_X1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl"


// kernel instance "cutlass_simt_s8_idgrad_identity_s8_16x128x16_16x64x16_2_nc4hw4" generated by cutlass generator
using Deconvolution =
typename cutlass::conv::device::Deconvolution<
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int8_t,
cutlass::layout::TensorKxRSCx<4>,
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm61,
cutlass::gemm::GemmShape<16, 128, 16>,
cutlass::gemm::GemmShape<16, 64, 16>,
cutlass::gemm::GemmShape<1, 1, 4>,
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t,
4,
int32_t,
int32_t,
float
>,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
2,
4,
4,
true,
cutlass::arch::OpMultiplyAdd>;



template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);


#pragma GCC diagnostic pop
#endif

+ 56
- 0
dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_16x64x8_16x64x8_2_nc4hw4.cu View File

@@ -0,0 +1,56 @@

#if !MEGDNN_TEGRA_X1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl"


// kernel instance "cutlass_simt_s8_idgrad_identity_s8_16x64x8_16x64x8_2_nc4hw4" generated by cutlass generator
using Deconvolution =
typename cutlass::conv::device::Deconvolution<
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int8_t,
cutlass::layout::TensorKxRSCx<4>,
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm61,
cutlass::gemm::GemmShape<16, 64, 8>,
cutlass::gemm::GemmShape<16, 64, 8>,
cutlass::gemm::GemmShape<1, 1, 4>,
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t,
4,
int32_t,
int32_t,
float
>,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
2,
4,
4,
true,
cutlass::arch::OpMultiplyAdd>;



template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);


#pragma GCC diagnostic pop
#endif

+ 56
- 0
dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_32x128x32_32x64x32_2_nc4hw4.cu View File

@@ -0,0 +1,56 @@

#if !MEGDNN_TEGRA_X1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl"


// kernel instance "cutlass_simt_s8_idgrad_identity_s8_32x128x32_32x64x32_2_nc4hw4" generated by cutlass generator
using Deconvolution =
typename cutlass::conv::device::Deconvolution<
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int8_t,
cutlass::layout::TensorKxRSCx<4>,
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm61,
cutlass::gemm::GemmShape<32, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>,
cutlass::gemm::GemmShape<1, 1, 4>,
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t,
4,
int32_t,
int32_t,
float
>,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
2,
4,
16,
true,
cutlass::arch::OpMultiplyAdd>;



template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);


#pragma GCC diagnostic pop
#endif

+ 56
- 0
dnn/src/cuda/convolution/backward_data/int8/kimpl/cutlass_simt_s8_idgrad_identity_s8_64x128x32_64x32x32_2_nc4hw4.cu View File

@@ -0,0 +1,56 @@

#if !MEGDNN_TEGRA_X1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/implicit_gemm_deconv_cutlass_wrapper.cuinl"


// kernel instance "cutlass_simt_s8_idgrad_identity_s8_64x128x32_64x32x32_2_nc4hw4" generated by cutlass generator
using Deconvolution =
typename cutlass::conv::device::Deconvolution<
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int8_t,
cutlass::layout::TensorKxRSCx<4>,
int8_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::layout::TensorNCxHWx<4>,
int32_t,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm61,
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<64, 32, 32>,
cutlass::gemm::GemmShape<1, 1, 4>,
cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t,
4,
int32_t,
int32_t,
float
>,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
2,
4,
16,
true,
cutlass::arch::OpMultiplyAdd>;



template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);


#pragma GCC diagnostic pop
#endif

+ 0
- 36
dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu View File

@@ -1,36 +0,0 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/int8/deconv_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorKxRSCx<4>;
using LayoutDst = cutlass::layout::TensorNCxHWx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<16, 128, 16>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t, 4, int32_t, int32_t, float>;
using Deconvolution = cutlass::conv::device::Deconvolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutDst, int32_t, LayoutDst, int32_t,
cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
1, 4, 8, true,
cutlass::arch::OpMultiplyAdd>;
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 0
- 36
dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x64x16_id.cu View File

@@ -1,36 +0,0 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/int8/deconv_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorKxRSCx<4>;
using LayoutDst = cutlass::layout::TensorNCxHWx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 128, 16>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 16>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t, 4, int32_t, int32_t, float>;
using Deconvolution = cutlass::conv::device::Deconvolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutDst, int32_t, LayoutDst, int32_t,
cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
2, 4, 4, true,
cutlass::arch::OpMultiplyAdd>;
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 0
- 36
dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_16x64x8_16x64x8_id.cu View File

@@ -1,36 +0,0 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/int8/deconv_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorKxRSCx<4>;
using LayoutDst = cutlass::layout::TensorNCxHWx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<16, 64, 8>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t, 4, int32_t, int32_t, float>;
using Deconvolution = cutlass::conv::device::Deconvolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutDst, int32_t, LayoutDst, int32_t,
cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
2, 4, 4, true,
cutlass::arch::OpMultiplyAdd>;
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 0
- 36
dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_32x128x32_32x64x32_id.cu View File

@@ -1,36 +0,0 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/int8/deconv_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorKxRSCx<4>;
using LayoutDst = cutlass::layout::TensorNCxHWx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t, 4, int32_t, int32_t, float>;
using Deconvolution = cutlass::conv::device::Deconvolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutDst, int32_t, LayoutDst, int32_t,
cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
2, 4, 16, true,
cutlass::arch::OpMultiplyAdd>;
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

+ 0
- 36
dnn/src/cuda/convolution/backward_data/int8/kimpl/deconv_int8_implicit_gemm_dp4a_ncdiv4hw4_64x128x32_64x32x32_id.cu View File

@@ -1,36 +0,0 @@
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/convolution/backward_data/int8/deconv_int8_implicit_gemm_cutlass_wrapper.cuinl"

using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorKxRSCx<4>;
using LayoutDst = cutlass::layout::TensorNCxHWx<4>;
using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
int8_t, 4, int32_t, int32_t, float>;
using Deconvolution = cutlass::conv::device::Deconvolution<
int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
LayoutDst, int32_t, LayoutDst, int32_t,
cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle,
2, 4, 16, true,
cutlass::arch::OpMultiplyAdd>;
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
const typename Deconvolution::ElementSrc* d_src,
const typename Deconvolution::ElementFilter* d_filter,
const typename Deconvolution::ElementBias* d_bias,
const typename Deconvolution::ElementDst* d_z,
typename Deconvolution::ElementDst* d_dst,
int* workspace,
typename Deconvolution::ConvolutionParameter const& conv_param,
typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif

Loading…
Cancel
Save