Browse Source

feat(dnn/cuda): add cudnn frontend api

GitOrigin-RevId: 9b18a57893
HuaHua404-patch-1
Megvii Engine Team 3 years ago
parent
commit
35e9cc9845
20 changed files with 1905 additions and 229 deletions
  1. +5
    -2
      dnn/CMakeLists.txt
  2. +10
    -9
      dnn/include/megdnn/algorithm_cache.h
  3. +2
    -0
      dnn/src/CMakeLists.txt
  4. +30
    -0
      dnn/src/cuda/conv_bias/algo.cpp
  5. +127
    -16
      dnn/src/cuda/conv_bias/algo.h
  6. +22
    -88
      dnn/src/cuda/conv_bias/cudnn_conv.cpp
  7. +87
    -0
      dnn/src/cuda/conv_bias/cudnn_conv_base.cpp
  8. +14
    -113
      dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
  9. +210
    -0
      dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_base.cpp
  10. +145
    -0
      dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_v8.cpp
  11. +98
    -0
      dnn/src/cuda/conv_bias/cudnn_conv_v8.cpp
  12. +53
    -1
      dnn/src/cuda/conv_bias/helper.cpp
  13. +9
    -0
      dnn/src/cuda/conv_bias/helper.h
  14. +11
    -0
      dnn/src/cuda/conv_bias/opr_impl.cpp
  15. +7
    -0
      dnn/src/cuda/conv_bias/opr_impl.h
  16. +685
    -0
      dnn/src/cuda/cudnn_wrapper_v8.cpp
  17. +70
    -0
      dnn/src/cuda/cudnn_wrapper_v8.h
  18. +15
    -0
      dnn/src/cuda/handle.cpp
  19. +304
    -0
      dnn/test/cuda/conv_v8.cpp
  20. +1
    -0
      third_party/prepare.sh

+ 5
- 2
dnn/CMakeLists.txt View File

@@ -54,9 +54,12 @@ if(MGE_WITH_CUDA)
add_library(cutlass INTERFACE)
target_include_directories(
cutlass
INTERFACE $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/include>)
add_library(cudnn-frontend INTERFACE)
target_include_directories(
cudnn-frontend
INTERFACE
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/include>
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/tools/util/include>)
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cudnn-frontend/include>)
endif()

if(MGE_WITH_TEST)


+ 10
- 9
dnn/include/megdnn/algorithm_cache.h View File

@@ -22,7 +22,16 @@ public:
bool operator==(const KeyStorage& k) const { return k1 == k.k1 && k2 == k.k2; }
};

struct Key {
struct Hash {
size_t operator()(const KeyStorage& k) const {
size_t h1 = k.k1;
size_t h2 = k.k2;
h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
return h1;
}
};

class Key {
Handle* m_handle;
uint32_t m_opr_type;
const TensorLayout* m_inp_layouts_ptr;
@@ -62,14 +71,6 @@ public:
MGE_WIN_DECLSPEC_FUC void clear();

private:
struct Hash {
size_t operator()(const KeyStorage& k) const {
size_t h1 = k.k1;
size_t h2 = k.k2;
h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
return h1;
}
};
std::unordered_map<KeyStorage, Result, Hash> m_heuristic_cache;
#if __DEPLOY_ON_XP_SP2__
size_t m_mtx;


+ 2
- 0
dnn/src/CMakeLists.txt View File

@@ -222,6 +222,8 @@ target_link_libraries(megdnn PUBLIC opr_param_defs)
if(MGE_WITH_CUDA)
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cutlass>)
target_include_directories(megdnn PRIVATE ${CUDNN_INCLUDE_DIR})

target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cudnn-frontend>)
endif()

if(MGE_WITH_ROCM)


+ 30
- 0
dnn/src/cuda/conv_bias/algo.cpp View File

@@ -14,6 +14,12 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&matmul8x8x32);
non_cudnn_algos.push_back(&batched_matmul);
non_cudnn_algos.push_back(&int1_simple);

#if CUDNN_VERSION > 8004
all_algos.push_back(&cudnn_conv_v8);
all_algos.push_back(&cudnn_conv_bias_activation_v8);
#endif

fill_cudnn_algos();
for (auto&& algo : cudnn_conv_bias_activations) {
all_algos.push_back(&algo);
@@ -169,6 +175,30 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const {
nonlinear_mode_str.c_str());
}

param::Convolution ConvBiasForwardImpl::AlgoBase::get_param_convolution(
const SizeArgs& args) const {
param::Convolution::Mode mode;
param::Convolution::Sparse sparse = args.filter_meta.group > 1
? param::Convolution::Sparse::GROUP
: param::Convolution::Sparse::DENSE;
if (args.filter_meta.should_flip) {
mode = param::Convolution::Mode::CONVOLUTION;
} else {
mode = param::Convolution::Mode::CROSS_CORRELATION;
}
return param::Convolution{
mode,
args.filter_meta.padding[0],
args.filter_meta.padding[1],
args.filter_meta.stride[0],
args.filter_meta.stride[1],
args.filter_meta.dilation[1],
args.filter_meta.dilation[0],
sparse,
args.filter_meta.format,
args.opr->param().compute_mode};
}

void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() {
for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
cudnn_conv_bias_activations.push_back(algo.first);


+ 127
- 16
dnn/src/cuda/conv_bias/algo.h View File

@@ -76,6 +76,8 @@ public:
CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
CUDA_SIMPLE_INT1,
CUDA_CUDNN_CONV_V8,
CUDA_CUDNN_CONVBIAS_V8,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

@@ -157,12 +159,40 @@ public:
}

virtual bool is_cudnn() const { return false; }

param::Convolution get_param_convolution(const SizeArgs& args) const;
};

class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase : public AlgoBase {
public:
AlgoCUDNNConvBiasActivationBase() = default;
virtual ~AlgoCUDNNConvBiasActivationBase() = default;

size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

bool is_cudnn() const override { return true; }

size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const override;
void exec_preprocess(const ExecArgs& args) const override;

protected:
virtual size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void cudnn_execute(
const ExecArgs& args, const Workspace& workspace, float alpha,
float beta) const = 0;

protected:
std::string m_name;
};

class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase {
class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final
: public AlgoCUDNNConvBiasActivationBase {
public:
AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
: AlgoCUDNNConvBiasActivationBase(), m_cudnn_enum(cudnn_enum) {
megdnn_assert(
CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_fwd_algos().end());
@@ -171,9 +201,6 @@ public:
"CUDNN:ConvBiasActivation:" + m_attr.name, {});
}

size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
param::Convolution get_param_convolution(const SizeArgs& args) const;
bool is_available(const SizeArgs&) const override;

const char* name() const override { return m_name.c_str(); }
@@ -191,8 +218,6 @@ public:

cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; }

bool is_cudnn() const override { return true; }

MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS)

std::string param() const override {
@@ -202,11 +227,46 @@ public:
}

private:
std::string m_name;
size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const override;
void cudnn_execute(
const ExecArgs& args, const Workspace& workspace, float alpha,
float beta) const override;

private:
cudnnConvolutionFwdAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
};

#if CUDNN_VERSION > 8004
class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8 final
: public AlgoCUDNNConvBiasActivationBase {
public:
AlgoCUDNNConvBiasActivationV8() : AlgoCUDNNConvBiasActivationBase() {
m_name = ConvBiasForward::algo_name<DefaultParam>(
"CUDNN:ConvBiasActivationV8", {});
}
~AlgoCUDNNConvBiasActivationV8() = default;

bool is_available(const SizeArgs& args) const override;

AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
}

const char* name() const override { return m_name.c_str(); }

MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS_V8)

std::string param() const override { return ""; }

private:
size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const override;
void cudnn_execute(
const ExecArgs& args, const Workspace& workspace, float alpha,
float beta) const override;
};
#endif

class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
@@ -284,9 +344,34 @@ private:
mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase {
class ConvBiasForwardImpl::AlgoCUDNNConvBase : public AlgoBase {
public:
AlgoCUDNNConvBase() = default;
virtual ~AlgoCUDNNConvBase() = default;

size_t get_workspace_in_bytes(const SizeArgs& args) const override {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}
void exec(const ExecArgs& args) const override;

bool is_cudnn() const override { return true; }

protected:
virtual size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void cudnn_execute(
const ExecArgs& args, const Workspace& workspace) const = 0;

private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;

protected:
std::string m_name;
};

class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoCUDNNConvBase {
public:
AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) {
AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum)
: AlgoCUDNNConvBase(), m_cudnn_enum(cudnn_enum) {
megdnn_assert(
CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_fwd_algos().end());
@@ -296,8 +381,6 @@ public:
}

bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

AlgoAttribute attribute() const override {
auto ret = static_cast<AlgoAttribute>(0);
@@ -314,8 +397,6 @@ public:

cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }

bool is_cudnn() const override { return true; }

MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV)

std::string param() const override {
@@ -325,12 +406,38 @@ public:
}

private:
std::string m_name;
size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const override;
void cudnn_execute(const ExecArgs& args, const Workspace& workspace) const override;

private:
cudnnConvolutionFwdAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
};

WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
#if CUDNN_VERSION > 8004
class ConvBiasForwardImpl::AlgoCUDNNConvV8 final : public AlgoCUDNNConvBase {
public:
AlgoCUDNNConvV8() : AlgoCUDNNConvBase() {
m_name = ConvBiasForward::algo_name<DefaultParam>("CUDNN:ConvolutionV8", {});
}

bool is_available(const SizeArgs& args) const override;

AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
}

const char* name() const override { return m_name.c_str(); }

MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV_V8)

std::string param() const override { return ""; }

private:
size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const override;
void cudnn_execute(const ExecArgs& args, const Workspace& workspace) const override;
};
#endif

//! compute small matmul in the kernel
class ConvBiasForwardImpl::AlgoInplaceMatmul final : public AlgoBase {
@@ -1140,6 +1247,10 @@ public:
AlgoGroupConvGeneral group;
AlgoBFloat16 bfloat16;
AlgoSimpleInt1 int1_simple;
#if CUDNN_VERSION > 8004
AlgoCUDNNConvV8 cudnn_conv_v8;
AlgoCUDNNConvBiasActivationV8 cudnn_conv_bias_activation_v8;
#endif

AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);



+ 22
- 88
dnn/src/cuda/conv_bias/cudnn_conv.cpp View File

@@ -56,99 +56,33 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(const SizeArgs& args) cons
return status == CUDNN_STATUS_SUCCESS;
}

WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
auto dst_layout = *args.dst_layout;
SmallVector<size_t> sizes;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
sizes.push_back(dst_layout.span().dist_byte());
}

if (args.z_layout->ndim > 0 &&
args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
auto z_layout = *args.z_layout;
z_layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype, z_layout.dtype);
sizes.push_back(z_layout.span().dist_byte());
}

SizeArgs conv_args = args;
conv_args.dst_layout = &dst_layout;

size_t ConvBiasForwardImpl::AlgoCUDNNConv::cudnn_get_workspace_in_bytes(
const SizeArgs& args) const {
CUDNNForwardDescs D;
conv_args.init_conv_desc(D);
args.init_conv_desc(D);

size_t conv_workspace_size;
auto status = cudnnGetConvolutionForwardWorkspaceSize(
conv_args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &conv_workspace_size);
megdnn_assert(
status == CUDNN_STATUS_SUCCESS,
"conv fwd get workspace failed: %s; info: %s", cudnnGetErrorString(status),
args.to_string().c_str());
sizes.insert(sizes.begin(), conv_workspace_size);
return {ptr, std::move(sizes)};
}

size_t ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_in_bytes(
const SizeArgs& args) const {
return get_workspace_bundle(nullptr, args).total_size_in_bytes();
cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
&conv_workspace_size));
return conv_workspace_size;
}

void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const {
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
TensorND conv_dst_tensor = *args.dst_tensor;
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
conv_dst_tensor = TensorND{bundle.get(1), args.dst_tensor->layout};
conv_dst_tensor.layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype,
conv_dst_tensor.layout.dtype);
}

ExecArgs conv_args = args;
conv_args.dst_tensor = &conv_dst_tensor;
conv_args.dst_layout = &conv_dst_tensor.layout;

{
CUDNNForwardDescs D;
conv_args.init_conv_desc(D);
auto conv_workspace = bundle.get_workspace(0);
float alpha = 1.0f, beta = 0.0f;
auto status = cudnnConvolutionForward(
conv_args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
conv_args.src_tensor->raw_ptr(), D.filter_desc.desc,
conv_args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum,
conv_workspace.raw_ptr, conv_workspace.size, &beta, D.dst_desc.desc,
conv_args.dst_tensor->raw_ptr());
megdnn_assert(
status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s",
cudnnGetErrorString(status), conv_args.to_string().c_str());
}

if (args.z_layout->ndim > 0) {
auto z_tensor = *args.z_tensor;
if (args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
z_tensor = TensorND{bundle.get(2), args.z_tensor->layout};
z_tensor.layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype,
z_tensor.layout.dtype);
auto typecvt = args.handle->create_operator<TypeCvt>();
typecvt->exec(*args.z_tensor, z_tensor);
}
auto add = args.handle->create_operator<ElemwiseForward>();
add->param().mode = Elemwise::Param::Mode::ADD;
add->exec({conv_dst_tensor, z_tensor}, conv_dst_tensor);
}

handle_bias_and_nonlinear(
args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
args.bias_tensor);
void ConvBiasForwardImpl::AlgoCUDNNConv::cudnn_execute(
const ExecArgs& args, const Workspace& workspace) const {
CUDNNForwardDescs D;
args.init_conv_desc(D);
float alpha = 1.0f, beta = 0.0f;
auto status = cudnnConvolutionForward(
args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
args.src_tensor->raw_ptr(), D.filter_desc.desc,
args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum,
workspace.raw_ptr, workspace.size, &beta, D.dst_desc.desc,
args.dst_tensor->raw_ptr());
megdnn_assert(
status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s",
cudnnGetErrorString(status), args.to_string().c_str());
}

// vim: syntax=cpp.doxygen

+ 87
- 0
dnn/src/cuda/conv_bias/cudnn_conv_base.cpp View File

@@ -0,0 +1,87 @@
/**
* \file dnn/src/cuda/conv_bias/cudnn_conv_base.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConvBase::get_workspace_bundle(
void* ptr, const SizeArgs& args) const {
auto dst_layout = *args.dst_layout;
SmallVector<size_t> sizes;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
sizes.push_back(dst_layout.span().dist_byte());
}

if (args.z_layout->ndim > 0 &&
args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
auto z_layout = *args.z_layout;
z_layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype, z_layout.dtype);
sizes.push_back(z_layout.span().dist_byte());
}

SizeArgs conv_args = args;
conv_args.dst_layout = &dst_layout;

size_t conv_workspace_size = cudnn_get_workspace_in_bytes(conv_args);

sizes.insert(sizes.begin(), conv_workspace_size);
return {ptr, std::move(sizes)};
}

void ConvBiasForwardImpl::AlgoCUDNNConvBase::exec(const ExecArgs& args) const {
auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
TensorND conv_dst_tensor = *args.dst_tensor;
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
conv_dst_tensor = TensorND{bundle.get(1), args.dst_tensor->layout};
conv_dst_tensor.layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype,
conv_dst_tensor.layout.dtype);
}

ExecArgs conv_args = args;
conv_args.dst_tensor = &conv_dst_tensor;
conv_args.dst_layout = &conv_dst_tensor.layout;

cudnn_execute(conv_args, bundle.get_workspace(0));

if (args.z_layout->ndim > 0) {
auto z_tensor = *args.z_tensor;
if (args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
z_tensor = TensorND{bundle.get(2), args.z_tensor->layout};
z_tensor.layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype,
z_tensor.layout.dtype);
auto typecvt = args.handle->create_operator<TypeCvt>();
typecvt->exec(*args.z_tensor, z_tensor);
}
auto add = args.handle->create_operator<ElemwiseForward>();
add->param().mode = Elemwise::Param::Mode::ADD;
add->exec({conv_dst_tensor, z_tensor}, conv_dst_tensor);
}

handle_bias_and_nonlinear(
args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
args.bias_tensor);
}

// vim: syntax=cpp.doxygen

+ 14
- 113
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp View File

@@ -124,10 +124,10 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
// forbits sigmoid for quantized
if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
return false;
MEGDNN_FALLTHRU // XXX: why?
case param::ConvBias::NonlineMode::IDENTITY
: if (args.src_layout->dtype.category() ==
DTypeCategory::QUANTIZED) break;
MEGDNN_FALLTHRU; // XXX: why?
case param::ConvBias::NonlineMode::IDENTITY:
if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
break;
if (m_cudnn_enum != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
// cudnn require algo to
// CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
@@ -149,7 +149,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
return status == CUDNN_STATUS_SUCCESS;
}

size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::cudnn_get_workspace_in_bytes(
const SizeArgs& args) const {
CUDNNForwardDescs D;

@@ -162,85 +162,18 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
status == CUDNN_STATUS_SUCCESS,
"conv fwd get workspace failed: %s; info: %s", cudnnGetErrorString(status),
args.to_string().c_str());
if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
// cudnn require bias to be float when executing CONFIG_INT
// convert bias to float if bias is not float at first
workspace_size += sizeof(float) * args.bias_layout->span().dist_elem();
}
return workspace_size;
}

void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec(
const ExecArgs& args) const {
void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::cudnn_execute(
const ExecArgs& args, const Workspace& workspace, float alpha,
float beta) const {
#if CUDNN_MAJOR < 7
megdnn_throw("ConvBias require cudnn 7.0 or higher");
#else
megdnn_assert(cudnnGetVersion() >= 7401);
CUDNNForwardDescs D;
args.init_conv_bias_desc(D);
float alpha = 1.0f, beta = 0.0f;
if (args.z_layout->ndim > 0)
beta = 1.0f;

auto get_scale = [](const DType& dtype) -> float {
megdnn_assert(dtype.category() == DTypeCategory::QUANTIZED);
switch (dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return dtype.param<_dt>().scale;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
default:
megdnn_assert_internal(0);
}
};

auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype,
dst_dtype = args.dst_layout->dtype;
megdnn_assert(
(src_dtype.category() == dst_dtype.category()) ||
(src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
dst_dtype.enumv() == DTypeEnum::Float32));
megdnn_assert(src_dtype.category() == filter_dtype.category());

if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) {
auto expected_bias_scale = get_scale(args.src_layout->dtype) *
get_scale(args.filter_layout->dtype);
alpha = expected_bias_scale;
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED)
alpha /= get_scale(args.dst_layout->dtype);
if (args.z_layout->ndim > 0 &&
args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) {
beta = get_scale(args.z_layout->dtype) / get_scale(args.dst_layout->dtype);
}
if (args.bias_layout->dtype.category() == DTypeCategory::QUANTIZED) {
megdnn_assert(
fabs(expected_bias_scale - get_scale(args.bias_layout->dtype)) <
1e-4);
}
}

auto workspace_ptr = args.workspace.raw_ptr;
auto workspace_size = args.workspace.size;
auto bias_ptr = args.bias_tensor->raw_ptr();
if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
auto cvt = args.handle->create_operator<TypeCvt>();
auto float_bias_layout = *args.bias_layout;
auto converted_bias_layout = *args.bias_layout;
converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
float_bias_layout.dtype = dtype::Float32();
auto bias_size_in_bytes = float_bias_layout.span().dist_byte();
megdnn_assert(args.workspace.size >= bias_size_in_bytes);
cvt->exec(
{args.bias_tensor->raw_ptr(), converted_bias_layout},
TensorND{workspace_ptr, float_bias_layout});

bias_ptr = workspace_ptr;
workspace_ptr += bias_size_in_bytes;
workspace_size -= bias_size_in_bytes;
}

cudnnStatus_t status;
if (args.z_layout->ndim == 0) {
@@ -248,55 +181,23 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec(
args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
args.src_tensor->raw_ptr(), D.filter_desc.desc,
args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum,
workspace_ptr, workspace_size, &beta, D.dst_desc.desc,
args.dst_tensor->raw_ptr(), D.bias_desc.desc, bias_ptr,
D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr());
workspace.raw_ptr, workspace.size, &beta, D.dst_desc.desc,
args.dst_tensor->raw_ptr(), D.bias_desc.desc,
args.bias_tensor->raw_ptr(), D.conv_desc.act_desc, D.dst_desc.desc,
args.dst_tensor->raw_ptr());
} else {
status = cudnnConvolutionBiasActivationForward(
args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
args.src_tensor->raw_ptr(), D.filter_desc.desc,
args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum,
workspace_ptr, workspace_size, &beta, D.z_desc.desc,
args.z_tensor->raw_ptr(), D.bias_desc.desc, bias_ptr,
workspace.raw_ptr, workspace.size, &beta, D.z_desc.desc,
args.z_tensor->raw_ptr(), D.bias_desc.desc, args.bias_tensor->raw_ptr(),
D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr());
}

megdnn_assert(
status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s, algo %s",
cudnnGetErrorString(status), args.to_string().c_str(), name());
// Noline
switch (args.nonlinear_mode) {
case param::ConvBias::NonlineMode::RELU:
break;
case param::ConvBias::NonlineMode::SIGMOID: {
megdnn_assert(
args.dst_layout->dtype.category() != DTypeCategory::QUANTIZED);
auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
elem_opr->param().mode = Elemwise::Param::Mode::SIGMOID;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
break;
}
case param::ConvBias::NonlineMode::IDENTITY:
break;
case param::ConvBias::NonlineMode::H_SWISH: {
megdnn_assert(
args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED ||
(args.dst_layout->dtype.category() == DTypeCategory::FLOAT &&
args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW));
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) {
auto&& elem_opr = args.handle->create_operator<ElemwiseMultiType>();
elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
} else {
auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
}
break;
}
default:
megdnn_throw("unsupported NonlineMode");
}
#endif
}



+ 210
- 0
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_base.cpp View File

@@ -0,0 +1,210 @@
/**
* \file dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_base.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megdnn/oprs/general.h"

#include "./algo.h"

#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/helper.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::get_workspace_in_bytes(
const SizeArgs& args) const {
auto workspace_size = cudnn_get_workspace_in_bytes(args);

auto&& param = args.opr->param();
if (args.preprocessed_filter == nullptr) {
if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
// cudnn require bias to be float when executing CONFIG_INT
// convert bias to float if bias is not float at first
workspace_size += sizeof(float) * args.bias_layout->span().dist_elem();
}
if (param.format == param::ConvBias::Format::NCHW32) {
workspace_size += args.filter_layout->span().dist_byte() +
args.bias_layout->span().dist_byte();
}
}
return workspace_size;
}

void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::exec(
const ExecArgs& args) const {
float alpha, beta;
std::tie(alpha, beta) = cudnn_get_conv_bias_act_scale_param(
args.src_tensor->layout, args.dst_tensor->layout,
args.filter_tensor->layout, args.bias_tensor->layout,
args.z_tensor->layout);

auto workspace_ptr = args.workspace.raw_ptr;
auto workspace_size = args.workspace.size;
auto bias_ptr = args.bias_tensor->raw_ptr();
TensorND filter_tensor;
TensorND bias_tensor;

auto&& param = args.opr->param();
if (args.preprocessed_filter != nullptr) {
bias_tensor = TensorND{
args.bias_tensor->layout,
args.preprocessed_filter->tensors[0].raw_ptr()};
if (param.format == Param::Format::NCHW32) {
megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
filter_tensor = TensorND{
args.filter_tensor->layout,
args.preprocessed_filter->tensors[1].raw_ptr()};
} else {
filter_tensor = *args.filter_tensor;
}
} else {
if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
auto cvt = args.handle->create_operator<TypeCvt>();
auto float_bias_layout = *args.bias_layout;
auto converted_bias_layout = *args.bias_layout;
converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
float_bias_layout.dtype = dtype::Float32();
auto bias_size_in_bytes = float_bias_layout.span().dist_byte();
megdnn_assert(args.workspace.size >= bias_size_in_bytes);
cvt->exec(
{args.bias_tensor->raw_ptr(), converted_bias_layout},
TensorND{workspace_ptr, float_bias_layout});

bias_ptr = workspace_ptr;
workspace_ptr += bias_size_in_bytes;
workspace_size -= bias_size_in_bytes;
}
if (param.format == Param::Format::NCHW32) {
size_t reorder_workspace_size =
args.filter_tensor->layout.span().dist_byte() +
args.bias_tensor->layout.span().dist_byte();
auto reorder_filter_ptr = workspace_ptr;
auto reorder_bias_ptr =
workspace_ptr + args.filter_tensor->layout.span().dist_byte();
cudnn_reorder_filer_and_bias_nchw32(
cudnn_handle(args.opr->handle()), args.filter_tensor->raw_ptr(),
args.filter_meta, bias_ptr, reorder_filter_ptr, reorder_bias_ptr);
filter_tensor = TensorND(args.filter_tensor->layout, reorder_filter_ptr);
bias_ptr = reorder_bias_ptr;
workspace_ptr += reorder_workspace_size;
workspace_size -= reorder_workspace_size;
} else {
filter_tensor = *args.filter_tensor;
}
}

bias_tensor = TensorND{args.bias_tensor->layout, bias_ptr};
ExecArgs exec_args{
const_cast<ConvBiasForwardImpl*>(args.opr),
*args.src_tensor,
filter_tensor,
bias_tensor,
*args.z_tensor,
*args.dst_tensor,
args.workspace};
Workspace cudnn_workspace{workspace_ptr, workspace_size};
cudnn_execute(exec_args, cudnn_workspace, alpha, beta);

// Noline
switch (args.nonlinear_mode) {
case param::ConvBias::NonlineMode::RELU:
break;
case param::ConvBias::NonlineMode::SIGMOID: {
megdnn_assert(
args.dst_layout->dtype.category() != DTypeCategory::QUANTIZED);
auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
elem_opr->param().mode = Elemwise::Param::Mode::SIGMOID;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
break;
}
case param::ConvBias::NonlineMode::IDENTITY:
break;
case param::ConvBias::NonlineMode::H_SWISH: {
megdnn_assert(
args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED ||
(args.dst_layout->dtype.category() == DTypeCategory::FLOAT &&
args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW));
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) {
auto&& elem_opr = args.handle->create_operator<ElemwiseMultiType>();
elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
} else {
auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
}
break;
}
default:
megdnn_throw("unsupported NonlineMode");
}
}

size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::
get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
auto&& param = args.opr->param();
if (param.format == Param::Format::NCHW32) {
return args.bias_layout->span().dist_byte();
}
return 0_z;
}

SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::
deduce_preprocessed_filter_layout(const SizeArgs& args) const {
auto&& param = args.opr->param();
if (param.format == Param::Format::NCHW32) {
return {args.bias_layout->collapse_contiguous(),
args.filter_layout->collapse_contiguous()};
} else {
return {args.bias_layout->collapse_contiguous()};
}
}

void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::exec_preprocess(
const ExecArgs& args) const {
float alpha, beta;
std::tie(alpha, beta) = cudnn_get_conv_bias_act_scale_param(
args.src_tensor->layout, args.dst_tensor->layout,
args.filter_tensor->layout, args.bias_tensor->layout,
args.z_tensor->layout);
MEGDNN_MARK_USED_VAR(beta);

auto workspace_ptr = args.workspace.raw_ptr;
auto workspace_size = args.workspace.size;
auto bias_ptr = workspace_size > 0 ? workspace_ptr
: args.preprocessed_filter->tensors[0].raw_ptr();
if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
auto cvt = args.handle->create_operator<TypeCvt>();
auto float_bias_layout = *args.bias_layout;
auto converted_bias_layout = *args.bias_layout;
converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
float_bias_layout.dtype = dtype::Float32();

cvt->exec(
{args.bias_tensor->raw_ptr(), converted_bias_layout},
TensorND{bias_ptr, float_bias_layout});
}
if (args.opr->param().format == Param::Format::NCHW32) {
auto reorder_filter_ptr = args.preprocessed_filter->tensors[1].raw_ptr();
auto reorder_bias_ptr = args.preprocessed_filter->tensors[0].raw_ptr();
cudnn_reorder_filer_and_bias_nchw32(
cudnn_handle(args.opr->handle()), args.filter_tensor->raw_ptr(),
args.filter_meta, bias_ptr, reorder_filter_ptr, reorder_bias_ptr);
}
}

// vim: syntax=cpp.doxygen

+ 145
- 0
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_v8.cpp View File

@@ -0,0 +1,145 @@
/**
* \file dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_v8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megdnn/oprs/general.h"

#include "./algo.h"

#include "src/common/conv_bias.h"
#include "src/cuda/cudnn_wrapper_v8.h"
#include "src/cuda/utils.h"

#if CUDNN_VERSION >= 8004
using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

namespace {
TensorLayout canonical_bias_layout(
const TensorLayout& bias_layout, const param::ConvBias::Format format) {
int64_t vector_count, vector_dimension;
std::tie(vector_count, vector_dimension) = get_vector_count_and_dimension(format);
size_t channel = bias_layout[vector_dimension] * vector_count;
if (bias_layout.dtype.category() != DTypeCategory::FLOAT) {
return TensorLayout{{1, channel, 1, 1}, dtype::Float32()};
}
return TensorLayout{{1, channel, 1, 1}, bias_layout.dtype};
}
} // namespace

bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8::is_available(
const SizeArgs& args) const {
auto&& param = args.opr->param();
if (param.format == param::ConvBias::Format::NCHW4_NCHW32 ||
param.format == param::ConvBias::Format::NCHW32_NCHW4 ||
param.format == param::ConvBias::Format::NCHW4_NCHW ||
param.format == param::ConvBias::Format::NCHW8 ||
param.format == param::ConvBias::Format::NCHW64 ||
param.format == param::ConvBias::Format::CHWN4)
return false;
if (param.format != Param::Format::NCHW && param.format != Param::Format::NHWC) {
if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
return false;
}
}
if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
return false;
if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm)
return false;
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}

if (args.bias_layout->ndim == 0 ||
!check_bias_share_in_channel(*(args.bias_layout), param.format)) {
return false;
}

// FIXME: cudnn cannot handle the case when the initial value of dst tensor
// contains nan and beta is zero, because the result of 0.f * nan is still
// nan
if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
args.dst_layout->dtype.enumv() == DTypeEnum::Float32 &&
param.format == param::ConvBias::Format::NCHW) {
return false;
}

if (param.format == param::ConvBias::Format::NCHW32) {
// sm version
auto&& device_prop = current_device_prop();
if (device_prop.major < 7 || (device_prop.major == 7 && device_prop.minor < 5))
return false;
}

switch (args.nonlinear_mode) {
case param::ConvBias::NonlineMode::RELU:
case param::ConvBias::NonlineMode::IDENTITY:
break;
case param::ConvBias::NonlineMode::SIGMOID:
// forbits sigmoid for quantized
if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
return false;
break;
case param::ConvBias::NonlineMode::H_SWISH:
if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
break;
return false;
default:
megdnn_throw("unsupported NonlineMode");
}

auto bias_layout =
canonical_bias_layout(*args.bias_layout, args.opr->param().format);
auto plan = get_heuristic_plan_from_opr(
static_cast<const ConvBiasForward*>(args.opr), *args.src_layout,
*args.dst_layout, *args.filter_layout, bias_layout, *args.z_layout,
args.filter_meta);
return plan != nullptr;
}

size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8::cudnn_get_workspace_in_bytes(
const SizeArgs& args) const {
auto bias_layout =
canonical_bias_layout(*args.bias_layout, args.opr->param().format);
auto plan = get_heuristic_plan_from_opr(
static_cast<const ConvBiasForward*>(args.opr), *args.src_layout,
*args.dst_layout, *args.filter_layout, bias_layout, *args.z_layout,
args.filter_meta);
megdnn_assert(
plan != nullptr, "algo(%s) cannot find execution from heuristics", name());
return plan->getWorkspaceSize();
}

void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8::cudnn_execute(
const ExecArgs& args, const Workspace& workspace, float alpha,
float beta) const {
auto&& bias_layout =
canonical_bias_layout(args.bias_tensor->layout, args.opr->param().format);
auto plan = get_heuristic_plan_from_opr(
static_cast<const ConvBiasForward*>(args.opr), args.src_tensor->layout,
args.dst_tensor->layout, args.filter_tensor->layout, bias_layout,
args.z_tensor->layout, args.filter_meta);
megdnn_assert(
plan != nullptr, "algo(%s) cannot find execution from heuristics", name());
auto&& handle = cudnn_handle(args.handle);
TensorND bias_tensor{args.bias_tensor->raw_ptr(), bias_layout};
run_conv_bias_act_with_plan(
handle, *plan, *args.src_tensor, *args.dst_tensor, *args.filter_tensor,
bias_tensor, *args.z_tensor, workspace);
}

#endif

// vim: syntax=cpp.doxygen

+ 98
- 0
dnn/src/cuda/conv_bias/cudnn_conv_v8.cpp View File

@@ -0,0 +1,98 @@
/**
* \file dnn/src/cuda/conv_bias/cudnn_conv_v8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/cudnn_wrapper_v8.h"
#include "src/cuda/utils.h"

#if CUDNN_VERSION >= 8004
using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

bool ConvBiasForwardImpl::AlgoCUDNNConvV8::is_available(const SizeArgs& args) const {
if (args.filter_meta.format != Param::Format::NCHW &&
args.filter_meta.format != Param::Format::NHWC) {
if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
return false;
}
}

if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
return false;
}

// FIXME: cudnn cannot handle the case when the initial value of dst tensor
// contains nan and beta is zero, because the result of 0.f * nan is still
// nan
if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
args.dst_layout->dtype.enumv() == DTypeEnum::Float32 &&
args.opr->param().format == param::ConvBias::Format::NCHW) {
return false;
}

auto dst_layout = *args.dst_layout;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
}
SizeArgs conv_args = args;
conv_args.dst_layout = &dst_layout;

if (!is_cudnn_supported(conv_args))
return false;

auto conv_opr = args.handle->create_operator<ConvolutionForward>();
conv_opr->param() = get_param_convolution(args);
ConvolutionForward::CanonizedFilterMeta fm;
fm.copy_from(args.filter_meta);
auto plan = get_heuristic_plan_from_opr(
conv_opr.get(), *conv_args.src_layout, *conv_args.dst_layout,
*conv_args.filter_layout, {}, {}, fm);
return plan != nullptr;
}

size_t ConvBiasForwardImpl::AlgoCUDNNConvV8::cudnn_get_workspace_in_bytes(
const SizeArgs& args) const {
auto conv_opr = args.handle->create_operator<ConvolutionForward>();
conv_opr->param() = get_param_convolution(args);
ConvolutionForward::CanonizedFilterMeta fm;
fm.copy_from(args.filter_meta);
auto plan = get_heuristic_plan_from_opr(
conv_opr.get(), *args.src_layout, *args.dst_layout, *args.filter_layout, {},
{}, fm);
megdnn_assert(
plan != nullptr, "algo(%s) cannot find execution from heuristics", name());
return plan->getWorkspaceSize();
}

void ConvBiasForwardImpl::AlgoCUDNNConvV8::cudnn_execute(
const ExecArgs& args, const Workspace& workspace) const {
auto conv_opr = args.handle->create_operator<ConvolutionForward>();
conv_opr->param() = get_param_convolution(args);
ConvolutionForward::CanonizedFilterMeta fm;
fm.copy_from(args.filter_meta);
auto plan = get_heuristic_plan_from_opr(
conv_opr.get(), args.src_tensor->layout, args.dst_tensor->layout,
args.filter_tensor->layout, {}, {}, fm);
megdnn_assert(
plan != nullptr, "algo(%s) cannot find execution from heuristics", name());
auto&& handle = cudnn_handle(args.handle);
run_single_conv_with_plan(
handle, *plan, *args.src_tensor, *args.dst_tensor, *args.filter_tensor,
workspace);
}
#endif

// vim: syntax=cpp.doxygen

+ 53
- 1
dnn/src/cuda/conv_bias/helper.cpp View File

@@ -197,8 +197,60 @@ void flip_filter(
ref_ptr.reset(workspace.raw_ptr);
}

} // namespace conv_bias
std::pair<float, float> cudnn_get_conv_bias_act_scale_param(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& w,
const TensorLayout& b, const TensorLayout& z) {
float alpha = 1.f, beta = 0.f;
if (z.ndim > 0)
beta = 1.f;

auto get_scale = [](const DType& dtype) -> float {
megdnn_assert(dtype.category() == DTypeCategory::QUANTIZED);
switch (dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return dtype.param<_dt>().scale;
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
default:
megdnn_assert_internal(0);
}
};

auto x_dtype = x.dtype, y_dtype = y.dtype, w_dtype = w.dtype;
megdnn_assert(
(x_dtype.category() == y_dtype.category()) ||
(x_dtype.enumv() == DTypeEnum::QuantizedS8 &&
y_dtype.enumv() == DTypeEnum::Float32));
megdnn_assert(x_dtype.category() == w_dtype.category());

if (x_dtype.category() == DTypeCategory::QUANTIZED) {
auto expected_bias_scale = get_scale(x_dtype) * get_scale(w_dtype);
alpha = expected_bias_scale;
if (y_dtype.category() == DTypeCategory::QUANTIZED)
alpha /= get_scale(y_dtype);
if (z.ndim > 0 && z.dtype.category() == DTypeCategory::QUANTIZED) {
beta = get_scale(z.dtype) / get_scale(y_dtype);
}
if (b.dtype.category() == DTypeCategory::QUANTIZED) {
megdnn_assert(fabs(expected_bias_scale - get_scale(b.dtype)) < 1e-4);
}
}
return {alpha, beta};
}

void cudnn_reorder_filer_and_bias_nchw32(
const cudnnHandle_t& handle, const void* filter_ptr,
const CanonizedFilterMeta& fm, const void* bias_ptr, void* reordered_filter_ptr,
void* reordered_bias_ptr) {
FilterDesc<param::ConvBias> filter_desc;
filter_desc.set(fm);
int reorder_bias = bias_ptr != nullptr;
cudnn_check(cudnnReorderFilterAndBias(
handle, filter_desc.desc, CUDNN_DEFAULT_REORDER, filter_ptr,
reordered_filter_ptr, reorder_bias, bias_ptr, reordered_bias_ptr));
}
} // namespace conv_bias
} // namespace cuda
} // namespace megdnn



+ 9
- 0
dnn/src/cuda/conv_bias/helper.h View File

@@ -113,6 +113,15 @@ struct CUDNNForwardDescs {
}
};

std::pair<float, float> cudnn_get_conv_bias_act_scale_param(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& w,
const TensorLayout& b, const TensorLayout& z);

void cudnn_reorder_filer_and_bias_nchw32(
const cudnnHandle_t& handle, const void* filter_ptr,
const CanonizedFilterMeta& fm, const void* bias_ptr, void* reordered_filter_ptr,
void* reordered_bias_ptr);

} // namespace conv_bias
} // namespace cuda
} // namespace megdnn


+ 11
- 0
dnn/src/cuda/conv_bias/opr_impl.cpp View File

@@ -47,6 +47,17 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
using namespace conv_bias;
AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
#if CUDNN_VERSION >= 8004
if (sm_algo_pack.cudnn_conv_v8.is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.cudnn_conv_v8;
}
if (sm_algo_pack.cudnn_conv_bias_activation_v8.is_available_attribute(
args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
return &sm_algo_pack.cudnn_conv_bias_activation_v8;
}
#endif

auto dst_layout = *args.dst_layout;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();


+ 7
- 0
dnn/src/cuda/conv_bias/opr_impl.h View File

@@ -1,6 +1,7 @@
#pragma once
#include "../elemwise/opr_impl.h"
#include "megdnn/oprs.h"
#include "src/cuda/cudnn_with_check.h"

namespace megdnn {
namespace cuda {
@@ -65,6 +66,12 @@ public:
// The following algorithms are suitable for channel wise convolution
class AlgoFloat32NCHWFMAImplicitBatchedGemm;
class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
class AlgoCUDNNConvBase;
class AlgoCUDNNConvBiasActivationBase;
#if CUDNN_VERSION > 8004
class AlgoCUDNNConvV8;
class AlgoCUDNNConvBiasActivationV8;
#endif

class AlgoPack;



+ 685
- 0
dnn/src/cuda/cudnn_wrapper_v8.cpp View File

@@ -0,0 +1,685 @@
/**
* \file dnn/src/cuda/cudnn_wrapper_v8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/cuda/cudnn_wrapper_v8.h"
#include "src/cuda/cudnn_wrapper.h"

#include "src/common/utils.h"
#include "src/cuda/utils.h"

#include "src/cuda/conv_bias/helper.h"

#include "cudnn_frontend_EngineConfigGenerator.h"

#include "megdnn/heuristic_cache.h"

using namespace megdnn;
using namespace cuda;

// helper functions for underlying descriptors
namespace {
cudnnDataType_t get_cudnn_data_type(DType type) {
switch (type.enumv()) {
case DTypeEnum::Float32:
return CUDNN_DATA_FLOAT;
case DTypeEnum::Float16:
return CUDNN_DATA_HALF;
case DTypeEnum::Int32:
case DTypeEnum::QuantizedS32:
return CUDNN_DATA_INT32;
case DTypeEnum::QuantizedS8:
case DTypeEnum::Int8:
return CUDNN_DATA_INT8;
default:
megdnn_throw("dtype must be float16/float32/int8/qint8/int32/qint32");
}
}

cudnnDataType_t get_compute_type(
DType type, param::Convolution::ComputeMode comp_mode) {
if (type.enumv() == DTypeEnum::Float32) {
return CUDNN_DATA_FLOAT;
} else if (type.enumv() == DTypeEnum::Float16) {
return get_compute_type_fp16(comp_mode);
} else if (
type.category() == DTypeCategory::INT ||
type.category() == DTypeCategory::QUANTIZED) {
return CUDNN_DATA_INT32;
} else {
megdnn_throw("unsupported compute type for convolution");
}
}

using Format = param::Convolution::Format;
using IntArrayRef = SmallVector<int64_t>;
std::pair<IntArrayRef, IntArrayRef> get_shape_and_stride(
const TensorLayout& layout, const Format format, int64_t nr_group) {
// DENSE: n, c, h, w
// n, k, p, q; ndim = 4
// GROUP: n, g, c, h, w
// n, g, k, p, q; ndim = 5
static constexpr size_t CUDNN_NDIM = 4;
size_t cudnn_ndim = CUDNN_NDIM;
if (nr_group > 1)
cudnn_ndim += 1;
IntArrayRef shape(cudnn_ndim);
IntArrayRef stride(cudnn_ndim);

if (format == Format::NCHW4 || format == Format::NCHW32)
megdnn_assert_eq_size_t(layout.ndim, 5_z);
else
megdnn_assert_eq_size_t(layout.ndim, 4_z);

size_t c_pos, spatial_pos;
if (format == Format::NCHW || format == Format::NCHW4 || format == Format::NCHW32) {
c_pos = 1;
spatial_pos = 2;
} else {
megdnn_assert(format == Format::NHWC);
c_pos = 3;
spatial_pos = 1;
}
int64_t vector_count, vector_dimension;
std::tie(vector_count, vector_dimension) = get_vector_count_and_dimension(format);

size_t out_c_pos = nr_group == 1 ? 1 : 2;
size_t out_spatial_pos = nr_group == 1 ? 2 : 3;
// For NCHW4 and NCHW32 we still compute standard strides here to input to cuDNN
// functions. We will manually scale by resizeFactor in the cpu ref.
shape[0] = layout[0];
if (nr_group > 1)
shape[1] = nr_group;
shape[out_c_pos] = layout[c_pos] / nr_group;
shape[out_spatial_pos] = layout[spatial_pos];
shape[out_spatial_pos + 1] = layout[spatial_pos + 1];
if (c_pos == 1) {
stride[cudnn_ndim - 1] = 1;
for (int i = cudnn_ndim - 2; i >= 0; --i) {
stride[i] = stride[i + 1] * shape[i + 1];
}
} else {
megdnn_assert(c_pos == 3); // Here we assume that the format is NHWC
stride[out_c_pos] = 1;
if (nr_group > 1)
stride[1] = shape[out_c_pos] * stride[out_c_pos];
stride[out_spatial_pos + 1] = stride[1] * shape[1];
stride[out_spatial_pos] =
stride[out_spatial_pos + 1] * shape[out_spatial_pos + 1];
stride[0] = stride[out_spatial_pos] * shape[out_spatial_pos];
}
return {shape, stride};
}

/* --------------- make cudnn-frontend tensor descriptor --------------- */
auto make_tensor_descriptor(
int64_t id, uint8_t alignment, const TensorLayout& layout, const Format format,
int64_t nr_group, bool is_virtual = false) {
int64_t vector_count, vector_dimension;
std::tie(vector_count, vector_dimension) = get_vector_count_and_dimension(format);
IntArrayRef shape, stride;
std::tie(shape, stride) = get_shape_and_stride(layout, format, nr_group);
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(stride.size(), stride.data())
.setId(id)
.setAlignment(alignment)
.setDataType(get_cudnn_data_type(layout.dtype))
.setVirtual(is_virtual)
.setVectorCountAndDimension(vector_count, vector_dimension)
.build();
}

/* --------------- make cudnn-frontend filter descriptor --------------- */
template <typename FilterMeta>
cudnn_frontend::Tensor make_filter_descriptor(uint8_t alignment, const FilterMeta& fm) {
// DENSE: k, c, r, s; ndim = 4
// GROUP: g, k, c, r, s; ndim = 5
// generate shape and stride
static constexpr size_t CUDNN_NDIM = 4;
size_t cudnn_ndim = CUDNN_NDIM;
if (fm.group > 1)
cudnn_ndim += 1;
IntArrayRef shape(cudnn_ndim), stride(cudnn_ndim);
auto format = fm.format;
int64_t vector_count, vector_dimension;
std::tie(vector_count, vector_dimension) = get_vector_count_and_dimension(format);

int64_t group = fm.group;
size_t out_ch_pos = group == 1 ? 0 : 1;
size_t in_ch_pos = group == 1 ? 1 : 2;
size_t filter_start = group == 1 ? 2 : 3;
if (group > 1)
shape[0] = group;
shape[out_ch_pos] = fm.ocpg;
shape[in_ch_pos] = fm.icpg / vector_count;
shape[filter_start] = fm.spatial[0];
shape[filter_start + 1] = fm.spatial[1];
if (format == Format::NCHW || format == Format::NCHW4 || format == Format::NCHW32) {
stride[cudnn_ndim - 1] = 1;
for (int i = cudnn_ndim - 2; i >= 0; --i) {
stride[i] = stride[i + 1] * shape[i + 1];
}
} else {
megdnn_assert(
format == Format::NHWC); // Here we assume that the format is NHWC
stride[in_ch_pos] = 1;
stride[filter_start + 1] = stride[in_ch_pos] * shape[in_ch_pos];
stride[filter_start] = stride[filter_start + 1] * shape[filter_start + 1];
stride[out_ch_pos] = stride[filter_start] * shape[filter_start];
if (group > 1)
stride[0] = stride[out_ch_pos] * shape[out_ch_pos];
}
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(stride.size(), stride.data())
.setId('w') // weight descriptor
.setAlignment(alignment)
.setDataType(get_cudnn_data_type(fm.dtype))
.setVectorCountAndDimension(vector_count, vector_dimension)
.build();
}

/* --------------- make cudnn-frontend conv descriptor --------------- */
template <typename Param>
cudnn_frontend::ConvDesc_v8 make_conv_descriptor(
cudnnDataType_t data_type, const Param& param) {
IntArrayRef padding = {param.pad_h, param.pad_w};
IntArrayRef stride = {param.stride_h, param.stride_w};
IntArrayRef dilation = {param.dilate_h, param.dilate_w};
uint64_t conv_dim = stride.size();
cudnnConvolutionMode_t mode;
switch (param.mode) {
case Param::Mode::CROSS_CORRELATION:
mode = CUDNN_CROSS_CORRELATION;
break;
case Param::Mode::CONVOLUTION:
mode = CUDNN_CONVOLUTION;
break;
default:
megdnn_throw("conv mode must be conv or xcorr.");
}
return cudnn_frontend::ConvDescBuilder()
.setDataType(data_type)
.setMathMode(mode)
.setNDims(conv_dim)
.setStrides(conv_dim, stride.data())
.setPrePadding(conv_dim, padding.data())
.setPostPadding(conv_dim, padding.data())
.setDilation(conv_dim, dilation.data())
.build();
}

/* --------------- make cudnn-frontend activation descriptor --------------- */
auto make_activation_descriptor(
DType data_type, const param::ConvBias::NonlineMode nonline_mode) {
cudnnPointwiseMode_t mode;
using NonlineMode = param::ConvBias::NonlineMode;
switch (nonline_mode) {
case NonlineMode::RELU:
mode = CUDNN_POINTWISE_RELU_FWD;
break;
case NonlineMode::SIGMOID:
mode = CUDNN_POINTWISE_SIGMOID_FWD;
break;
default:
megdnn_throw("unsupported non linear mode");
}
return cudnn_frontend::PointWiseDescBuilder()
.setMode(mode)
.setMathPrecision(get_cudnn_data_type(data_type))
.build();
}

// high-level api for convolution execution
struct StaticData {
using Key = megdnn::HeuristicCache::Key;
using KeyStorage = megdnn::HeuristicCache::KeyStorage;
using KeyHash = megdnn::HeuristicCache::Hash;
using Result = cudnn_frontend::ExecutionPlan;
using CudnnFrontendExecutionPlanCache =
std::unordered_map<KeyStorage, Result, KeyHash>;
CudnnFrontendExecutionPlanCache cache;
#if __DEPLOY_ON_XP_SP2__
size_t cache_mutex;
#else
std::mutex cache_mutex;
#endif
cudnnBackendHeurMode_t heur_mode = CUDNN_HEUR_MODE_INSTANT;
bool deterministic = true;
};

StaticData& static_data() {
static StaticData inst;
return inst;
}

template <typename Opr>
struct CudnnBackendOpTypeTrait;

template <>
struct CudnnBackendOpTypeTrait<ConvolutionForward> {
static constexpr cudnnBackendDescriptorType_t OPERATION =
CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
};

template <>
struct CudnnBackendOpTypeTrait<ConvolutionBackwardData> {
static constexpr cudnnBackendDescriptorType_t OPERATION =
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR;
};

template <>
struct CudnnBackendOpTypeTrait<ConvolutionBackwardFilter> {
static constexpr cudnnBackendDescriptorType_t OPERATION =
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR;
};

auto build_opgraph(
const cudnnHandle_t& handle, const cudnnBackendDescriptorType_t operation,
const cudnn_frontend::Tensor& x, const cudnn_frontend::Tensor& y,
const cudnn_frontend::Tensor& w, const cudnn_frontend::ConvDesc_v8& conv_desc) {
auto op = cudnn_frontend::OperationBuilder(operation)
.setxDesc(x)
.setyDesc(y)
.setwDesc(w)
.setcDesc(conv_desc)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
auto op_graph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(1, ops.data())
.build();
return op_graph;
}

auto build_opgraph_fused(
const cudnnHandle_t& handle, const cudnn_frontend::Tensor& x,
const cudnn_frontend::Tensor& y, const cudnn_frontend::Tensor& w,
const cudnn_frontend::Tensor& b, const cudnn_frontend::Tensor& z,
const cudnn_frontend::Tensor& after_add,
const cudnn_frontend::Tensor& after_bias,
const cudnn_frontend::Tensor& after_conv,
const cudnn_frontend::ConvDesc_v8& conv_desc,
const cudnn_frontend::PointWiseDesc_v8& act_desc, float alpha, float beta) {
const auto precision = CUDNN_DATA_FLOAT;

// add z
auto add_desc1 = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(precision)
.build();
// add bias
auto add_desc2 = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(precision)
.build();

// create conv node
auto conv_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(x)
.setyDesc(after_conv)
.setwDesc(w)
.setcDesc(conv_desc)
.build();

// create add z node
auto add_op1 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(z)
.setyDesc(after_add)
.setpwDesc(add_desc1)
.setAlpha(alpha)
.setAlpha2(beta)
.build();

// create add bias node
auto add_op2 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(add_op1.getOutputTensor())
.setbDesc(b)
.setyDesc(after_bias)
.setpwDesc(add_desc2)
.build();

// create act node
auto act_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(add_op2.getOutputTensor())
.setyDesc(y)
.setpwDesc(act_desc)
.build();

std::array<cudnn_frontend::Operation const*, 4> ops = {
&conv_op, &add_op1, &add_op2, &act_op};

auto op_graph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(ops.size(), ops.data())
.build();
return op_graph;
}

auto build_opgraph_fused_nonactivation(
const cudnnHandle_t& handle, const cudnn_frontend::Tensor& x,
const cudnn_frontend::Tensor& y, const cudnn_frontend::Tensor& w,
const cudnn_frontend::Tensor& b, const cudnn_frontend::Tensor& z,
const cudnn_frontend::Tensor& after_add,
const cudnn_frontend::Tensor& after_conv,
const cudnn_frontend::ConvDesc_v8& conv_desc, float alpha, float beta) {
const auto precision = CUDNN_DATA_FLOAT;

// add z
auto add_desc1 = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(precision)
.build();
// add bias
auto add_desc2 = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(precision)
.build();

// create conv node
auto conv_op = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(x)
.setyDesc(after_conv)
.setwDesc(w)
.setcDesc(conv_desc)
.build();

// create add z node
auto add_op1 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(z)
.setyDesc(after_add)
.setpwDesc(add_desc1)
.setAlpha(alpha)
.setAlpha2(beta)
.build();

// create add bias node
auto add_op2 = cudnn_frontend::OperationBuilder(
CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(add_op1.getOutputTensor())
.setbDesc(b)
.setyDesc(y)
.setpwDesc(add_desc2)
.build();

std::array<cudnn_frontend::Operation const*, 3> ops = {
&conv_op, &add_op1, &add_op2};

auto op_graph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(ops.size(), ops.data())
.build();
return op_graph;
}

void filter_engine_configs(
cudnn_frontend::EngineConfigList& from, cudnn_frontend::EngineConfigList& to,
bool deterministic) {
auto filter = [&deterministic](cudnnBackendDescriptor_t c) {
if (deterministic) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(
c)) {
return true;
}
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(
c)) {
return true;
}
return false;
};
cudnn_frontend::filter(from, to, filter);
}
}; // namespace

/* --------- get heuristic plan from megdnn opr -------- */
template <typename Opr>
cudnn_frontend::ExecutionPlan* megdnn::cuda::get_heuristic_plan_from_opr(
const Opr* opr, const TensorLayout& x, const TensorLayout& y,
const TensorLayout& w, const TensorLayout& b, const TensorLayout& z,
const typename Opr::CanonizedFilterMeta& fm) {
auto&& param = opr->param();
TensorLayoutArray layouts{x, y, w};
auto key = StaticData::Key{opr->handle(), opr->get_opr_type(),
layouts.data(), layouts.size(),
&param, sizeof(param)}
.build_key_storage();
auto& cache = static_data().cache;
{
MEGDNN_LOCK_GUARD(static_data().cache_mutex);
auto iter = cache.find(key);
if (iter != cache.end()) {
return &iter->second;
}
}

size_t aligned = 16;
uint8_t alignment = std::min(opr->handle()->alignment_requirement(), aligned);
auto&& handle = cudnn_handle(opr->handle());
auto&& x_desc = make_tensor_descriptor('x', alignment, x, fm.format, fm.group);
auto&& y_desc = make_tensor_descriptor('y', alignment, y, fm.format, fm.group);
auto&& w_desc = make_filter_descriptor(alignment, fm);
auto compute_type = get_compute_type(x.dtype, param.compute_mode);
auto&& conv_desc = make_conv_descriptor(compute_type, param);
constexpr auto operation = CudnnBackendOpTypeTrait<Opr>::OPERATION;
auto op_graph = build_opgraph(handle, operation, x_desc, y_desc, w_desc, conv_desc);
auto deterministic = static_data().deterministic;
auto heur_mode = static_data().heur_mode;
auto heurgen_method = [&deterministic,
&heur_mode](cudnn_frontend::OperationGraph& op_graph)
-> cudnn_frontend::EngineConfigList {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(heur_mode)
.build();
auto& engine_configs =
heuristics.getEngineConfig(heuristics.getEngineConfigCount());
cudnn_frontend::EngineConfigList filtered_configs;
filter_engine_configs(engine_configs, filtered_configs, deterministic);
return filtered_configs;
};

auto fallback_method = [&deterministic, &heur_mode,
&operation](cudnn_frontend::OperationGraph& op_graph)
-> cudnn_frontend::EngineConfigList {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(operation)
.build();
auto& fallback_list = fallback.getFallbackList();
cudnn_frontend::EngineConfigList filtered_configs;
filter_engine_configs(fallback_list, filtered_configs, deterministic);
return filtered_configs;
};

std::array<cudnn_frontend::GeneratorSource const, 2> sources = {
heurgen_method, fallback_method};

cudnn_frontend::EngineConfigGenerator generator(sources.size(), sources.data());
auto configs = generator.generate_engine_config(op_graph);

for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config)
.build();
auto workspace_size = plan.getWorkspaceSize();
MEGDNN_MARK_USED_VAR(workspace_size);
MEGDNN_LOCK_GUARD(static_data().cache_mutex);
auto insert = cache.insert(std::make_pair(key, std::move(plan)));
return &insert.first->second;
} catch (cudnn_frontend::cudnnException& e) {
continue;
}
}
return nullptr;
}

#define INST(_Opr) \
template cudnn_frontend::ExecutionPlan* megdnn::cuda::get_heuristic_plan_from_opr( \
const _Opr* opr, const TensorLayout& x, const TensorLayout& y, \
const TensorLayout& w, const TensorLayout& b, const TensorLayout& z, \
const typename _Opr::CanonizedFilterMeta& fm);

INST(ConvolutionForward);
INST(ConvolutionBackwardData);
INST(ConvolutionBackwardFilter);

/* --------- get heuristic plan from conv_bias opr -------- */
template <>
cudnn_frontend::ExecutionPlan* megdnn::cuda::get_heuristic_plan_from_opr(
const ConvBiasForward* opr, const TensorLayout& x, const TensorLayout& y,
const TensorLayout& w, const TensorLayout& b, const TensorLayout& z,
const typename ConvBiasForward::CanonizedFilterMeta& fm) {
auto&& param = opr->param();
TensorLayoutArray layouts{x, y, w, b, z};
auto key = StaticData::Key{opr->handle(), opr->get_opr_type(),
layouts.data(), layouts.size(),
&param, sizeof(param)}
.build_key_storage();
auto& cache = static_data().cache;
{
MEGDNN_LOCK_GUARD(static_data().cache_mutex);
auto iter = cache.find(key);
if (iter != cache.end()) {
return &iter->second;
}
}

size_t aligned = 16;
uint8_t alignment = std::min(opr->handle()->alignment_requirement(), aligned);
auto&& handle = cudnn_handle(opr->handle());
auto&& x_desc = make_tensor_descriptor('x', alignment, x, fm.format, fm.group);
auto&& y_desc = make_tensor_descriptor('y', alignment, y, fm.format, fm.group);
auto&& w_desc = make_filter_descriptor(alignment, fm);
auto&& z_desc = make_tensor_descriptor('z', alignment, y, fm.format, fm.group);
auto&& b_desc = make_tensor_descriptor('b', alignment, b, Format::NCHW, fm.group);
auto&& after_conv =
make_tensor_descriptor('C', alignment, y, fm.format, fm.group, true);
auto&& after_add =
make_tensor_descriptor('A', alignment, y, fm.format, fm.group, true);
auto&& after_bias =
make_tensor_descriptor('B', alignment, y, fm.format, fm.group, true);
auto compute_type = get_compute_type(x.dtype, param.compute_mode);
auto&& conv_desc = make_conv_descriptor(compute_type, param);
float alpha, beta;
std::tie(alpha, beta) =
conv_bias::cudnn_get_conv_bias_act_scale_param(x, y, w, b, z);
// Because the OperationGraph has no public copy constructor and default
// constructor, here we use a lambda function to bypass the compile error.
auto get_op_graph = [&]() {
if (param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY) {
return build_opgraph_fused_nonactivation(
handle, x_desc, y_desc, w_desc, b_desc, z_desc, after_add,
after_conv, conv_desc, alpha, beta);
} else {
auto&& act_desc =
make_activation_descriptor(dtype::Float32(), param.nonlineMode);
return build_opgraph_fused(
handle, x_desc, y_desc, w_desc, b_desc, z_desc, after_add,
after_bias, after_conv, conv_desc, act_desc, alpha, beta);
}
};
auto op_graph = get_op_graph();
auto deterministic = static_data().deterministic;
auto heur_mode = static_data().heur_mode;
auto heurgen_method = [&deterministic,
&heur_mode](cudnn_frontend::OperationGraph& op_graph)
-> cudnn_frontend::EngineConfigList {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(heur_mode)
.build();
auto& engine_configs =
heuristics.getEngineConfig(heuristics.getEngineConfigCount());
cudnn_frontend::EngineConfigList filtered_configs;
filter_engine_configs(engine_configs, filtered_configs, deterministic);
return filtered_configs;
};

std::array<cudnn_frontend::GeneratorSource const, 1> sources = {heurgen_method};

cudnn_frontend::EngineConfigGenerator generator(sources.size(), sources.data());
auto configs = generator.generate_engine_config(op_graph);

for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config)
.build();
auto workspace_size = plan.getWorkspaceSize();
MEGDNN_MARK_USED_VAR(workspace_size);
MEGDNN_LOCK_GUARD(static_data().cache_mutex);
auto insert = cache.insert(std::make_pair(key, std::move(plan)));
return &insert.first->second;
} catch (cudnn_frontend::cudnnException& e) {
continue;
}
}
return nullptr;
}

/* ------ impl for running a single conv ----- */
void megdnn::cuda::run_single_conv_with_plan(
const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan,
const TensorND& x, const TensorND& y, const TensorND& w,
const Workspace& workspace) {
size_t workspace_size = plan.getWorkspaceSize();
megdnn_assert(
workspace.size >= workspace_size,
"workspace does not meet the requirement of execution "
"plan(got:%zu,expected:%zu)",
workspace.size, workspace_size);
void* data_ptrs[] = {x.raw_ptr(), y.raw_ptr(), w.raw_ptr()};
int64_t uids[] = {'x', 'y', 'w'};
auto variant_pack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.raw_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
cudnn_check(cudnnBackendExecute(
handle, plan.get_raw_desc(), variant_pack.get_raw_desc()));
}

/* ------ impl for running a fused conv bias activation ----- */
void megdnn::cuda::run_conv_bias_act_with_plan(
const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan,
const TensorND& x, const TensorND& y, const TensorND& w, const TensorND& b,
const TensorND& z, const Workspace& workspace) {
size_t workspace_size = plan.getWorkspaceSize();
megdnn_assert(
workspace.size >= workspace_size,
"workspace does not meet the requirement of execution "
"plan(got:%zu,expected:%zu)",
workspace.size, workspace_size);
void* z_ptr = z.layout.ndim == 0 ? nullptr : z.raw_ptr();
void* data_ptrs[] = {x.raw_ptr(), y.raw_ptr(), w.raw_ptr(), z_ptr, b.raw_ptr()};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};
auto variant_pack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.raw_ptr)
.setDataPointers(5, data_ptrs)
.setUids(5, uids)
.build();
cudnn_check(cudnnBackendExecute(
handle, plan.get_raw_desc(), variant_pack.get_raw_desc()));
}

// vim: syntax=cpp.doxygen

+ 70
- 0
dnn/src/cuda/cudnn_wrapper_v8.h View File

@@ -0,0 +1,70 @@
/**
* \file dnn/src/cuda/cudnn_wrapper_v8.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "megdnn/basic_types.h"
#include "megdnn/oprs/nn.h"
#include "src/common/utils.h"

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-variable"
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma GCC diagnostic ignored "-Wreorder"
#include "cudnn_frontend.h"
#pragma GCC diagnostic pop

namespace megdnn {
namespace cuda {
static inline std::pair<int64_t, int64_t> get_vector_count_and_dimension(
const param::Convolution::Format format) {
using Format = param::Convolution::Format;
int64_t vector_count = 1;
int64_t vector_dimension = 1;
switch (format) {
case Format::NCHW:
break;
case Format::NHWC:
vector_dimension = 3;
break;
case Format::NCHW4:
vector_count = 4;
break;
case Format::NCHW32:
vector_count = 32;
break;
default:
megdnn_assert(
false, "unsupported format (got:%u) for cudnn",
static_cast<uint32_t>(format));
}
return {vector_count, vector_dimension};
}

template <typename Opr>
cudnn_frontend::ExecutionPlan* get_heuristic_plan_from_opr(
const Opr* opr, const TensorLayout& x, const TensorLayout& y,
const TensorLayout& w, const TensorLayout& b, const TensorLayout& z,
const typename Opr::CanonizedFilterMeta& fm);

void run_single_conv_with_plan(
const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan,
const TensorND& x, const TensorND& y, const TensorND& w,
const Workspace& workspace);

void run_conv_bias_act_with_plan(
const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan,
const TensorND& x, const TensorND& y, const TensorND& w, const TensorND& b,
const TensorND& z, const Workspace& workspace);

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 15
- 0
dnn/src/cuda/handle.cpp View File

@@ -58,6 +58,11 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
For example `export CUDA_CACHE_MAXSIZE=2147483647` and `export CUDA_CACHE_PATH=/data/.cuda_cache`)");
}
#endif
size_t free, tot;
cudaMemGetInfo(&free, &tot);
printf("before cudnn create, free: %.2f MB, tot: %.2f MB, allocated: %.2f MB\n",
free / 1024.0 / 1024.0, tot / 1024.0 / 1024.0,
(tot - free) / 1024.0 / 1024.0);
cudnn_check(cudnnCreate(&m_cudnn_handle));
cublas_check(cublasCreate(&m_cublas_handle));
#if CUDA_VERSION >= 10010
@@ -69,6 +74,11 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
cudnn_check(cudnnSetStream(m_cudnn_handle, stream()));
cublas_check(cublasSetStream(m_cublas_handle, stream()));

#if CUDNN_VERSION >= 8004
// cudnn_check(cudnnOpsInferVersionCheck());
// cudnn_check(cudnnCnnInferVersionCheck());
#endif

// Note that all cublas scalars (alpha, beta) and scalar results such as dot
// output resides at device side.
cublas_check(cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE));
@@ -82,6 +92,11 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
cudaMemcpyHostToDevice, stream()));
cuda_check(cudaStreamSynchronize(stream()));

cudaMemGetInfo(&free, &tot);
printf("after cudnn create, free: %.2f MB, tot: %.2f MB, allocated: %.2f MB\n",
free / 1024.0 / 1024.0, tot / 1024.0 / 1024.0,
(tot - free) / 1024.0 / 1024.0);

// check tk1
m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
m_cusolver_handle = nullptr;


+ 304
- 0
dnn/test/cuda/conv_v8.cpp View File

@@ -0,0 +1,304 @@
/**
* \file dnn/test/cuda/conv_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/dtype.h"
#include "test/cuda/fixture.h"

#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "src/cuda/handle.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/conv_bias.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
#include "test/cuda/utils.h"

using namespace megdnn;
using namespace test;
using namespace conv_bias;

#if CUDNN_VERSION >= 8004
TEST_F(CUDA, CONV_V8_FLOAT) {
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{
ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>(
"CUDNN:ConvolutionV8", {})
.c_str()}));

UniformFloatRNG rng(0.f, 1.f);
checker.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_rng(3, &rng)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(3, dtype::Float32());
param::ConvBias param;
param.pad_h = param.pad_w = 1;
param.stride_h = param.stride_w = 1;
param.format = param::ConvBias::Format::NCHW;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
checker.set_param(param).execs(
{{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {}, {}});
checker.set_param(param).execs(
{{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}});

// group
param.sparse = param::ConvBias::Sparse::GROUP;
checker.set_param(param).execs(
{{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {}, {}});
checker.set_param(param).execs(
{{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}});

// NHWC
param.format = param::ConvBias::Format::NHWC;
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}});
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}});
}

TEST_F(CUDA, CONV_V8_HALF) {
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{
ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>(
"CUDNN:ConvolutionV8", {})
.c_str()}));

UniformFloatRNG rng(0.f, 1.f);
checker.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_rng(3, &rng)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_dtype(3, dtype::Float16())
.set_dtype(4, dtype::Float16())
.set_epsilon(5e-2);
param::ConvBias param;
param.pad_h = param.pad_w = 1;
param.stride_h = param.stride_w = 1;
param.format = param::ConvBias::Format::NCHW;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.compute_mode = param::ConvBias::ComputeMode::FLOAT32;
checker.set_param(param).execs(
{{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {}, {}});
checker.set_param(param).execs(
{{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}});

// group
param.sparse = param::ConvBias::Sparse::GROUP;
checker.set_param(param).execs(
{{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {}, {}});
checker.set_param(param).execs(
{{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}});

// NHWC
param.format = param::ConvBias::Format::NHWC;
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}});
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}});
}

TEST_F(CUDA, CONV_BIAS_V8_FLOAT) {
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{
ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>(
"CUDNN:ConvBiasActivationV8", {})
.c_str()}));

UniformFloatRNG rng(0.f, 1.f);
UniformFloatRNG crng(0.f, 0.f);
checker.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_rng(3, &rng)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(3, dtype::Float32());
param::ConvBias param;
param.pad_h = param.pad_w = 1;
param.stride_h = param.stride_w = 1;
param.format = param::ConvBias::Format::NCHW;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
checker.set_param(param).execs(
{{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {}, {}});
checker.set_param(param).execs(
{{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}});

// group
param.sparse = param::ConvBias::Sparse::GROUP;
checker.set_param(param).execs(
{{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {}, {}});
checker.set_param(param).execs(
{{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}});

// NHWC
param.format = param::ConvBias::Format::NHWC;
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}});
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}});
}

TEST_F(CUDA, CONV_BIAS_V8_HALF) {
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{
ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>(
"CUDNN:ConvBiasActivationV8", {})
.c_str()}));

UniformFloatRNG rng(0.f, 1.f);
checker.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_rng(3, &rng)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_dtype(3, dtype::Float16())
.set_dtype(4, dtype::Float16())
.set_epsilon(5e-2);
param::ConvBias param;
param.pad_h = param.pad_w = 1;
param.stride_h = param.stride_w = 1;
param.format = param::ConvBias::Format::NCHW;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.compute_mode = param::ConvBias::ComputeMode::FLOAT32;
checker.set_param(param).execs(
{{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {}, {}});
checker.set_param(param).execs(
{{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}});

// group
param.sparse = param::ConvBias::Sparse::GROUP;
checker.set_param(param).execs(
{{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {}, {}});
checker.set_param(param).execs(
{{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}});

// NHWC
param.format = param::ConvBias::Format::NHWC;
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}});
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}});
}

TEST_F(CUDA, CONV_BIAS_V8_DP4A) {
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{
ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>(
"CUDNN:ConvBiasActivationV8", {})
.c_str()}));

UniformIntRNG rng{-3, 3};
UniformIntRNG bias_rng{-50, 50};
checker.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &bias_rng)
.set_rng(3, &rng)
.set_dtype(0, dtype::QuantizedS8{1.2f})
.set_dtype(1, dtype::QuantizedS8{1.3f})
.set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
.set_dtype(3, dtype::QuantizedS8{1.1f})
.set_dtype(4, dtype::QuantizedS8{1.0f})
.set_epsilon(1 + 1e-3);
param::ConvBias param;
param.pad_h = param.pad_w = 1;
param.stride_h = param.stride_w = 1;
param.format = param::ConvBias::Format::NCHW4;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
checker.set_param(param).execs(
{{1, 16, 7, 7, 4}, {64, 16, 3, 3, 4}, {1, 16, 1, 1, 4}, {}, {}});
checker.set_param(param).execs(
{{1, 16, 7, 7, 4},
{64, 16, 3, 3, 4},
{1, 16, 1, 1, 4},
{1, 16, 7, 7, 4},
{}});

param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
checker.set_param(param).execs(
{{1, 16, 7, 7, 4}, {64, 16, 3, 3, 4}, {1, 16, 1, 1, 4}, {}, {}});
checker.set_param(param).execs(
{{1, 16, 7, 7, 4},
{64, 16, 3, 3, 4},
{1, 16, 1, 1, 4},
{1, 16, 7, 7, 4},
{}});

param.format = param::ConvBias::Format::NHWC;
checker.set_param(param).execs(
{{1, 7, 7, 64}, {64, 3, 3, 64}, {1, 1, 1, 64}, {}, {}});
checker.set_param(param).execs(
{{1, 7, 7, 64}, {64, 3, 3, 64}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}});
param.sparse = param::ConvBias::Sparse::GROUP;
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}});
checker.set_param(param).execs(
{{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}});
}

TEST_F(CUDA, CONV_BIAS_V8_IMMA) {
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{
ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>(
"CUDNN:ConvBiasActivationV8", {})
.c_str()}));

UniformIntRNG rng{-3, 3};
UniformIntRNG bias_rng{-50, 50};
checker.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &bias_rng)
.set_rng(3, &rng)
.set_dtype(0, dtype::QuantizedS8{1.2f})
.set_dtype(1, dtype::QuantizedS8{1.3f})
.set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f})
.set_dtype(3, dtype::QuantizedS8{1.1f})
.set_dtype(4, dtype::QuantizedS8{1.0f})
.set_epsilon(1 + 1e-3);
param::ConvBias param;
param.pad_h = param.pad_w = 1;
param.stride_h = param.stride_w = 1;
param.format = param::ConvBias::Format::NCHW32;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
checker.set_param(param).execs(
{{1, 2, 7, 7, 32}, {64, 2, 3, 3, 32}, {1, 2, 1, 1, 32}, {}, {}});
checker.set_param(param).execs(
{{1, 2, 7, 7, 32},
{64, 2, 3, 3, 32},
{1, 2, 1, 1, 32},
{1, 2, 7, 7, 32},
{}});

param.nonlineMode = NonlineMode::RELU;
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = 0;

checker.set_param(param).execs(
{{2, 8, 12, 12, 32}, {512, 8, 1, 1, 32}, {1, 16, 1, 1, 32}, {}, {}});
}

#endif
// vim: syntax=cpp.doxygen

+ 1
- 0
third_party/prepare.sh View File

@@ -94,6 +94,7 @@ function git_submodule_update() {
git submodule sync
git submodule update -f --init midout
git submodule update -f --init flatbuffers
git submodule update -f --init cudnn-frontend
git submodule update -f --init Json
git submodule update -f --init gflags
git submodule update -f --init cpuinfo


Loading…
Cancel
Save