@@ -18,4 +18,39 @@
 #include "megdnn/oprs/utils.h"
 #include "megdnn/oprs/linalg.h"
+
+template <typename Opr>
+struct OprArityTrait;
+
+template <typename Opr, int _arity_in, int _arity_out>
+struct OprArityTraitTmpl {
+    static constexpr int arity_in = _arity_in;
+    static constexpr int arity_out = _arity_out;
+    static constexpr int arity = arity_in + arity_out;
+};
+
+#define INST_ARITY(_Opr, _in, _out) \
+    template <>                     \
+    struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {};
+
+INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1);
+INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1);
+INST_ARITY(megdnn::Convolution3DForward, 2, 1);
+INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1);
+INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1);
+INST_ARITY(megdnn::LocalShareForward, 2, 1);
+INST_ARITY(megdnn::LocalShareBackwardData, 2, 1);
+INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1);
+INST_ARITY(megdnn::Convolution, 2, 1);
+INST_ARITY(megdnn::DeformableConvForward, 4, 1);
+INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1);
+INST_ARITY(megdnn::BatchConvBiasForward, 4, 1);
+INST_ARITY(megdnn::ConvBias, 4, 1);
+INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3);
+INST_ARITY(megdnn::MatrixMul, 2, 1);
+INST_ARITY(megdnn::BatchedMatrixMul, 2, 1);
+
+#undef INST_ARITY
+
 // vim: syntax=cpp.doxygen
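A note on intended usage (not part of the patch): `OprArityTrait` exposes an operator's input/output tensor counts as compile-time constants, so proxy and benchmark code can size its tensor arrays statically instead of switching on the opr type. A minimal sketch, assuming the opr headers included above:

```cpp
#include <array>

// Compile-time checks against the arities registered above.
static_assert(OprArityTrait<megdnn::Convolution>::arity_in == 2,
              "conv takes src + filter");
static_assert(OprArityTrait<megdnn::ConvBias>::arity == 5,
              "conv_bias: src + filter + bias + z -> dst");

// Typical consumer: a layout array whose size is fixed by the trait.
template <typename Opr>
using LayoutArray =
        std::array<megdnn::TensorLayout, OprArityTrait<Opr>::arity>;
```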
@@ -47,6 +47,9 @@ namespace megdnn {
         return algo_pack().all_algos_map().at(desc); \
     }
 
+#define MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) \
+    cb(AlgoAttribute::ACCURACY_DEPEND_ON_BATCH)
+
 /**
  * \brief construct algo from AlgorithmDesc
  */
@@ -323,6 +323,34 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
     }
 }
 
+bool check_bias_share_in_channel(const TensorLayout& bias,
+                                 const param::ConvBias::Format format) {
+    bool share_in_channel = false;
+    if (format == param::ConvBias::Format::NCHW ||
+        format == param::ConvBias::Format::NCHW4_NCHW) {
+        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    } else if (format == param::ConvBias::Format::NHWC) {
+        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
+                            bias[2] == 1);
+    } else if (format == param::ConvBias::Format::NCHW4 ||
+               format == param::ConvBias::Format::NCHW8 ||
+               format == param::ConvBias::Format::NCHW32 ||
+               format == param::ConvBias::Format::NCHW4_NCHW32 ||
+               format == param::ConvBias::Format::NCHW32_NCHW4) {
+        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    } else if (format == param::ConvBias::Format::NHWCD4) {
+        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 &&
+                            bias[3] == 1);
+    } else {
+        megdnn_assert(format == param::ConvBias::Format::CHWN4);
+        share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    }
+    return share_in_channel;
+}
+
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
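For reference, a hypothetical call site (not part of the patch): a bias is "shared in channel" when every dimension except the channel one is 1, with the channel position determined by the format.

```cpp
using Format = megdnn::param::ConvBias::Format;

// NCHW: a per-channel bias of shape (1, C, 1, 1) is shared in channel.
megdnn::TensorLayout per_channel({1, 16, 1, 1}, megdnn::dtype::Float32());
bool shared = megdnn::check_bias_share_in_channel(per_channel, Format::NCHW);
// shared == true

// A full per-element bias of shape (N, C, H, W) is not.
megdnn::TensorLayout full({8, 16, 32, 32}, megdnn::dtype::Float32());
shared = megdnn::check_bias_share_in_channel(full, Format::NCHW);
// shared == false
```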
@@ -21,6 +21,9 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
                                const TensorND* conv_dst_tensor,
                                const TensorND* dst_tensor,
                                const TensorND* bias_tensor);
+
+bool check_bias_share_in_channel(const TensorLayout& bias,
+                                 const param::ConvBias::Format format);
+
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
@@ -9,7 +9,7 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
-#include "src/common/utils.h"
+#include "src/common/conv_bias.h"
 #include "src/cuda/batch_conv_bias/algo.h"
 #include "src/cuda/batch_conv_bias/batch_conv_bias.cuh"
 #include "src/cuda/batch_conv_bias/opr_impl.h"
@@ -106,7 +106,7 @@ bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm::is_available(
     using Mode = Param::Mode;
     bool available = true;
     auto&& param = args.opr->param();
-    if (!conv_bias::check_bias_share_in_channel(args.bias_layout, param.format))
+    if (!check_bias_share_in_channel(args.bias_layout, param.format))
         return false;
     if (param.format != Format::NCHW4)
         return false;
@@ -10,7 +10,7 @@
  */
 
 #include "megdnn/oprs/general.h"
-#include "src/common/utils.h"
+#include "src/common/conv_bias.h"
 #include "src/cuda/batch_conv_bias/algo.h"
 #include "src/cuda/batch_conv_bias/batch_conv_bias.cuh"
 #include "src/cuda/batch_conv_bias/opr_impl.h"
@@ -86,7 +86,7 @@ bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp::
     using Mode = Param::Mode;
     bool available = true;
     auto&& param = args.opr->param();
-    if (!conv_bias::check_bias_share_in_channel(args.bias_layout, param.format))
+    if (!check_bias_share_in_channel(args.bias_layout, param.format))
         return false;
     if (param.format != Format::NCHW4)
         return false;
@@ -115,7 +115,8 @@ public:
     size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override;
     void exec(const ExecArgs& args) const final;
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
     const char* name() const override { return "CUBLAS"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
@@ -128,7 +129,8 @@ public:
     size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override;
     void exec(const ExecArgs& args) const final;
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
     const char* name() const override { return "CUBLAS_LT"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT)
@@ -173,6 +173,9 @@ public:
         if (m_attr.is_reproducible) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
+        if (m_attr.accuracy_depend_on_batch) {
+            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+        }
         return ret;
     }
@@ -280,6 +283,9 @@ public:
         if (m_attr.is_reproducible) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
+        if (m_attr.accuracy_depend_on_batch) {
+            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+        }
         return ret;
     }
@@ -352,7 +358,8 @@ public:
                            const OperatorBase* opr) const override;
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
 
 private:
@@ -406,7 +413,8 @@ public:
                            const OperatorBase* opr) const override;
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
     MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)
@@ -428,7 +436,14 @@ public:
     const char* name() const override { return m_name.c_str(); }
     AlgoAttribute attribute() const override {
-        auto ret = static_cast<AlgoAttribute>(0);
+        auto ret = AlgoAttribute::DEFAULT;
+#define cb(attr)                                   \
+    if (m_impl->contain_attribute_all(attr)) {     \
+        ret |= attr;                               \
+    }
+        MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb)
+#undef cb
         if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
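The `cb` block above is a standard X-macro: each attribute registered in MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE is forwarded from the wrapped implementation to the wrapper algo. With the single entry this patch registers, the preprocessor output is roughly:

```cpp
// Illustrative expansion only; produced by the preprocessor, not hand-written.
if (m_impl->contain_attribute_all(AlgoAttribute::ACCURACY_DEPEND_ON_BATCH)) {
    ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
}
```

Adding another inheritable attribute later only requires a new `cb(...)` entry in the macro, not a change to every wrapper.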
@@ -16,6 +16,7 @@
 #include "src/cuda/conv_bias/helper.h"
 #include "src/cuda/cudnn_wrapper.h"
 #include "src/cuda/utils.h"
+#include "src/common/conv_bias.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -29,7 +30,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
     }
 
     if (args.bias_layout->ndim == 0 ||
-        !conv_bias::check_bias_share_in_channel(*(args.bias_layout),
+        !check_bias_share_in_channel(*(args.bias_layout),
                                      args.opr->param().format)) {
         return false;
     }
@@ -168,34 +168,6 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
     return supported;
 }
 
-bool check_bias_share_in_channel(const TensorLayout& bias,
-                                 const param::ConvBias::Format format) {
-    bool share_in_channel = false;
-    if (format == param::ConvBias::Format::NCHW ||
-        format == param::ConvBias::Format::NCHW4_NCHW) {
-        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
-                            bias[3] == 1);
-    } else if (format == param::ConvBias::Format::NHWC) {
-        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
-                            bias[2] == 1);
-    } else if (format == param::ConvBias::Format::NCHW4 ||
-               format == param::ConvBias::Format::NCHW8 ||
-               format == param::ConvBias::Format::NCHW32 ||
-               format == param::ConvBias::Format::NCHW4_NCHW32 ||
-               format == param::ConvBias::Format::NCHW32_NCHW4) {
-        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 &&
-                            bias[3] == 1);
-    } else if (format == param::ConvBias::Format::NHWCD4) {
-        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 &&
-                            bias[3] == 1);
-    } else {
-        megdnn_assert(format == param::ConvBias::Format::CHWN4);
-        share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 &&
-                            bias[3] == 1);
-    }
-    return share_in_channel;
-}
-
 SmallVector<size_t> matmul_get_workspace_bundle(
         const BiasForwardSizeArgs& args) {
     auto dtype = args.src_layout->dtype;
@@ -126,9 +126,6 @@ namespace conv_bias {
         }
     };
 
-bool check_bias_share_in_channel(const TensorLayout& bias,
-                                 const param::ConvBias::Format format);
-
 } // namespace conv_bias
 } // namespace cuda
 } // namespace megdnn
@@ -15,6 +15,7 @@
 #include "src/cuda/convolution_helper/layout.cuh"
 #include "src/cuda/convolution_helper/parameter.cuh"
 #include "src/cuda/utils.h"
+#include "src/common/conv_bias.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -83,7 +84,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::is_available(
     bool available = true;
     auto&& param = args.opr->param();
     auto&& fm = args.filter_meta;
-    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
+    if (!check_bias_share_in_channel(*(args.bias_layout),
                                      param.format))
         return false;
     if (param.format != Format::CHWN4)
@@ -15,6 +15,7 @@
 #include "src/cuda/convolution_helper/layout.cuh"
 #include "src/cuda/convolution_helper/parameter.cuh"
 #include "src/cuda/utils.h"
+#include "src/common/conv_bias.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -71,7 +72,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::is_available(
     bool available = true;
     auto&& param = args.opr->param();
     auto&& fm = args.filter_meta;
-    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
+    if (!check_bias_share_in_channel(*(args.bias_layout),
                                      param.format))
        return false;
     if (param.format != Format::CHWN4)
@@ -15,6 +15,7 @@
 #include "src/cuda/convolution_helper/layout.cuh"
 #include "src/cuda/convolution_helper/parameter.cuh"
 #include "src/cuda/utils.h"
+#include "src/common/conv_bias.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -118,7 +119,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::
     bool available = true;
     auto&& param = args.opr->param();
     auto&& fm = args.filter_meta;
-    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
+    if (!check_bias_share_in_channel(*(args.bias_layout),
                                      param.format))
         return false;
     if (param.format != Format::CHWN4)
@@ -15,6 +15,7 @@
 #include "src/cuda/convolution_helper/layout.cuh"
 #include "src/cuda/convolution_helper/parameter.cuh"
 #include "src/cuda/utils.h"
+#include "src/common/conv_bias.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -118,7 +119,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::
     bool available = true;
     auto&& param = args.opr->param();
     auto&& fm = args.filter_meta;
-    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
+    if (!check_bias_share_in_channel(*(args.bias_layout),
                                      param.format))
         return false;
     if (param.format != Format::CHWN4)
@@ -14,6 +14,7 @@
 #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
 #include "src/cuda/convolution_helper/parameter.cuh"
 #include "src/cuda/utils.h"
+#include "src/common/conv_bias.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -32,7 +33,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available(
     bool available = true;
     auto&& param = args.opr->param();
     auto&& fm = args.filter_meta;
-    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
+    if (!check_bias_share_in_channel(*(args.bias_layout),
                                      param.format))
         return false;
     if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4)
@@ -13,6 +13,7 @@
 #include "src/cuda/utils.h"
 #include "src/cuda/convolution_helper/parameter.cuh"
 #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
+#include "src/common/conv_bias.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -29,7 +30,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available(
     bool available = true;
     auto&& param = args.opr->param();
     auto&& fm = args.filter_meta;
-    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
+    if (!check_bias_share_in_channel(*(args.bias_layout),
                                      param.format))
         return false;
     if (param.format == Format::NCHW4_NCHW32) {
@@ -12,6 +12,7 @@
 #include "./algo.h"
 #include "src/cuda/utils.h"
 #include "src/cuda/convolution_helper/bias_visitor.cuh"
+#include "src/common/conv_bias.h"
 
 using namespace megdnn;
 using namespace cuda;
@@ -29,7 +30,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::is_available(
     bool available = true;
     auto&& param = args.opr->param();
     auto&& fm = args.filter_meta;
-    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
+    if (!check_bias_share_in_channel(*(args.bias_layout),
                                      param.format))
         return false;
     if (param.format != Format::NCHW4)
@@ -127,6 +127,9 @@ public:
         if (m_attr.is_reproducible) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
+        if (m_attr.accuracy_depend_on_batch) {
+            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+        }
         return ret;
     }
     cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }
@@ -158,7 +161,8 @@ public:
     const char* name() const override { return "MATMUL"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
 };
@@ -123,6 +123,9 @@ public:
         if (m_attr.is_reproducible) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
+        if (m_attr.accuracy_depend_on_batch) {
+            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+        }
         return ret;
     }
@@ -155,7 +158,8 @@ public:
     const char* name() const override { return "MATMUL"; }
     MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
 };
@@ -119,6 +119,9 @@ public:
         if (m_attr.is_reproducible) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
+        if (m_attr.accuracy_depend_on_batch) {
+            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+        }
         return ret;
     }
@@ -112,6 +112,9 @@ public:
         if (m_attr.is_reproducible) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
+        if (m_attr.accuracy_depend_on_batch) {
+            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+        }
         return ret;
     }
@@ -106,7 +106,8 @@ public:
     const char* name() const override { return "1x1x1"; }
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
     MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1)
 };
@@ -126,10 +127,17 @@ public:
     const char* name() const override { return m_name.c_str(); }
     AlgoAttribute attribute() const override {
-        auto ret = static_cast<AlgoAttribute>(0);
+        auto ret = AlgoAttribute::DEFAULT;
         if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
+#define cb(attr)                                   \
+    if (m_impl->contain_attribute_all(attr)) {     \
+        ret |= attr;                               \
+    }
+        MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb)
+#undef cb
         return ret;
     }
     static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
@@ -157,6 +165,9 @@ public:
         if (m_attr.is_reproducible) {
             ret |= AlgoAttribute::REPRODUCIBLE;
         }
+        if (m_attr.accuracy_depend_on_batch) {
+            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+        }
         return ret;
     }
@@ -470,9 +470,9 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) {
 #define V(v) V1(v)
 #define DEF_NAME(NAME) \
     #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL)
-#define DEF_ALGO(NAME, PROD)               \
-    {                                      \
-        NAME, { DEF_NAME(NAME), PROD }     \
+#define DEF_ALGO(NAME, PROD1, PROD2)           \
+    {                                          \
+        NAME, { DEF_NAME(NAME), PROD1, PROD2 } \
     }
 
 #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
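With the second flag, each table entry now records whether the algorithm's accuracy depends on the batch size, in addition to reproducibility. Expanded by hand (illustrative only), one entry becomes a map initializer from the cuDNN enum to a `CudnnAlgoPack::Attr`:

```cpp
// DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true, true) expands to roughly:
{
    CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
    {
        DEF_NAME(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT),  // "<name>v<major>.<minor>.<patch>"
        true,  // is_reproducible
        true,  // accuracy_depend_on_batch: FFT-based algos are marked batch-sensitive
    }
}
```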
@@ -483,19 +483,18 @@ const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr>
 CudnnAlgoPack::conv_bwd_data_algos() {
     static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t,
                                     CudnnAlgoPack::Attr>
-            algos = {
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true),
+            algos =
+    { DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false),
+      DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false),
+      DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true, true),
+      DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, true),
 #if CUDNN_MAJOR >= 5
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true),
+      DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true, false),
 #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED,
-                             true),
+      DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true, false),
 #endif
 #endif
-            };
+    };
 
     return algos;
 }
@@ -505,15 +504,16 @@ CudnnAlgoPack::conv_bwd_flt_algos() {
     static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t,
                                     CudnnAlgoPack::Attr>
             algos = {
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true, true),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false, false),
 #if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1)
                     DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
-                             true),
+                             true, false),
 #if CUDNN_MAJOR >= 6
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true,
+                             true),
 #endif
 #endif
@@ -522,28 +522,30 @@ CudnnAlgoPack::conv_bwd_flt_algos() {
     return algos;
 }
 
 const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr>
 CudnnAlgoPack::conv_fwd_algos() {
     static const std::unordered_map<cudnnConvolutionFwdAlgo_t,
                                     CudnnAlgoPack::Attr>
-            algos = {
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
-                             true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true),
+            algos =
+    { DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false),
+#if CUDNN_VERSION == 8004
+      DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, true),
+#else
+      DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, false),
+#endif
+      DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true, false),
+      DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true, false),
+      DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true, true),
+      DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true),
 #if CUDNN_MAJOR >= 5
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true),
+      DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true, false),
 #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true),
+      DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true, false),
 #endif
 #endif
-            };
+    };
 
     return algos;
 }
@@ -553,9 +555,10 @@ CudnnAlgoPack::conv3d_bwd_data_algos() {
     static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t,
                                     CudnnAlgoPack::Attr>
             algos = {
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true,
+                             true),
             };
 
     return algos;
@@ -568,9 +571,9 @@ CudnnAlgoPack::conv3d_bwd_flt_algos() {
     static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t,
                                     CudnnAlgoPack::Attr>
             algos = {
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false),
+                    DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false, false),
             };
 
     return algos;
@@ -581,10 +584,15 @@ CudnnAlgoPack::conv3d_fwd_algos() {
     static const std::unordered_map<cudnnConvolutionFwdAlgo_t,
                                     CudnnAlgoPack::Attr>
             algos = {
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
-                             true),
-                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true),
+                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false),
+#if CUDNN_VERSION == 8004
+                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true,
+                             true),
+#else
+                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true,
+                             false),
+#endif
+                    DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true),
             };
 
     return algos;
@@ -112,6 +112,7 @@ public:
     struct Attr {
         std::string name;
        bool is_reproducible;
+        bool accuracy_depend_on_batch;
     };
 
     static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr>
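A sketch of how the new field is consumed (this mirrors the `attribute()` overrides earlier in the patch; the free function itself is hypothetical):

```cpp
megdnn::AlgoAttribute attr_of(cudnnConvolutionFwdAlgo_t algo) {
    using megdnn::AlgoAttribute;
    auto&& attr = megdnn::cuda::CudnnAlgoPack::conv_fwd_algos().at(algo);
    auto ret = AlgoAttribute::DEFAULT;
    if (attr.is_reproducible) {
        ret |= AlgoAttribute::REPRODUCIBLE;
    }
    if (attr.accuracy_depend_on_batch) {
        ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    }
    return ret;
}
```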
@@ -115,7 +115,8 @@ public:
     MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
     AlgoAttribute attribute() const override {
         return AlgoAttribute::REPRODUCIBLE |
-               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
 };
@@ -142,7 +143,8 @@ public:
     void exec(const ExecArgs& args) const override;
     MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT)
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
 };
 #endif
@@ -25,7 +25,8 @@ public:
     size_t get_workspace(const KernSizeParam&) const override { return 0; }
     kern_t get_kern(const KernSizeParam&) const override;
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
     PackMode packmode() const override { return PackMode::NO_PACK; }
     MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
@@ -36,7 +37,8 @@ public:
 class MatrixMulImpl::AlgoF32MKLPackA : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
     }
     const char* name() const override { return "X86_F32_MKL_PACKA"; }
     bool usable(const KernSizeParam&) const override;
@@ -0,0 +1,109 @@
+/**
+ * \file dnn/test/common/accuracy_shake_checker.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "test/common/accuracy_shake_checker.h"
+
+using namespace megdnn;
+using namespace test;
+
+namespace {
+
+template <typename ctype>
+::testing::AssertionResult assert_tensor_binary_eq(
+        const char* expr0, const char* expr1, const char* /*expr2*/,
+        const TensorND& v0, const TensorND& v1, const std::string& algo_name) {
+    ctype* it0_orig = v0.ptr<ctype>();
+    ctype* it1 = v1.ptr<ctype>();
+    ctype* it0 = it0_orig;
+    auto nr_elem = v1.layout.total_nr_elems();
+    auto nr_elem_single_batch = v0.layout.total_nr_elems();
+    for (size_t i = 0; i < nr_elem; ++i) {
+        if (i % nr_elem_single_batch == 0) {
+            it0 = it0_orig;
+        }
+        ctype iv0 = *it0, iv1 = *it1;
+        if (!good_float(iv0) || !good_float(iv1) ||
+            memcmp(it0, it1, sizeof(ctype))) {
+            Index index(v1.layout, i);
+            return ::testing::AssertionFailure()
+                   << "Unequal value\n"
+                   << "Value of: " << expr1 << "\n"
+                   << "  Actual: " << (iv1 + 0) << "\n"
+                   << "Expected: " << expr0 << "\n"
+                   << "Which is: " << (iv0 + 0) << "\n"
+                   << "At index: " << index.to_string() << "/"
+                   << v1.layout.TensorShape::to_string() << "\n"
+                   << "   DType: " << v1.layout.dtype.name() << "\n"
+                   << "algo: " << algo_name;
+        }
+        ++it0;
+        ++it1;
+    }
+    return ::testing::AssertionSuccess();
+}
+
+}  // namespace
+
+::testing::AssertionResult test::__assert_tensor_binary_eq(
+        const char* expr0, const char* expr1, const char* expr2,
+        const TensorND& v0, const TensorND& v1,
+        const Algorithm::Info::Desc& algo) {
+    bool shape_match = v0.layout[0] == 1;
+    for (size_t i = 1; i < v0.layout.ndim; ++i) {
+        shape_match &= v0.layout[i] == v1.layout[i];
+    }
+    if (!shape_match) {
+        return ::testing::AssertionFailure()
+               << "Shape mismatch\n"
+               << "Value of: " << expr1 << "\n"
+               << "  Actual: " << v1.layout.TensorShape::to_string() << "\n"
+               << "Expected: " << expr0 << "\n"
+               << "Which is: " << v0.layout.TensorShape::to_string() << "\n"
+               << "algo: " << algo.name << "\n";
+    }
+
+    if (!v0.layout.is_physical_contiguous() ||
+        !v1.layout.is_physical_contiguous()) {
+        return ::testing::AssertionFailure()
+               << "layout should be physical contiguous\n"
+               << "Value of: " << expr1 << "\n"
+               << "  Actual: " << v1.layout.is_physical_contiguous() << "\n"
+               << "Expected: " << expr0 << "\n"
+               << "Which is: " << v0.layout.is_physical_contiguous() << "\n"
+               << "algo: " << algo.name << "\n";
+    }
+
+    auto dtype = v0.layout.dtype;
+    if (dtype != v1.layout.dtype) {
+        return ::testing::AssertionFailure()
+               << "Data type should match\n"
+               << "Value of: " << expr1 << "\n"
+               << "  Actual: " << v1.layout.dtype.name() << "\n"
+               << "Expected: " << expr0 << "\n"
+               << "Which is: " << v0.layout.dtype.name() << "\n"
+               << "algo: " << algo.name << "\n";
+    }
+
+    switch (dtype.enumv()) {
+#define cb(_dt)                                                 \
+    case DTypeTrait<_dt>::enumv:                                \
+        return assert_tensor_binary_eq<DTypeTrait<_dt>::ctype>( \
+                expr0, expr1, expr2, v0, v1, algo.name);
+        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+        MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
+#undef cb
+        default:
+            megdnn_trap();
+    }
+}
+
+// vim: syntax=cpp.doxygen
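The key detail in `assert_tensor_binary_eq` is resetting `it0` every `nr_elem_single_batch` elements: the single-batch reference is compared against each batch of the N-batch result in turn, and the comparison is bitwise (`memcmp`), not within an epsilon. A standalone sketch of the same indexing, outside megdnn:

```cpp
#include <cstddef>
#include <cstring>

// True when every batch of `full` bitwise-matches the single-batch `ref`.
bool batches_all_equal(const float* ref, const float* full,
                       size_t batch_elems, size_t nr_batches) {
    for (size_t b = 0; b < nr_batches; ++b) {
        // exact byte comparison, like the memcmp in the checker above
        if (std::memcmp(full + b * batch_elems, ref,
                        batch_elems * sizeof(float)) != 0) {
            return false;
        }
    }
    return true;
}
```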
@@ -0,0 +1,396 @@
+/**
+ * \file dnn/test/common/accuracy_shake_checker.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <vector>
+#include "megdnn/oprs.h"
+#include "src/common/conv_bias.h"
+#include "src/common/utils.h"
+#include "test/common/checker.h"
+#include "test/common/index.h"
+
+namespace megdnn {
+namespace test {
+
+namespace {
+
+template <class Opr>
+struct BatchTrait {
+    //! index of the batch dimension in the tensor, e.g. 3 for CHWN4
+    static size_t index_of_batch(const typename Opr::Param&) { return 0; }
+    //! indices of the inputs/outputs that contain a batch dimension,
+    //! e.g. src(0) and dst(2) for conv
+    static std::vector<size_t> indices_contain_batch;
+    static std::vector<size_t> indices_contain_batch_broadcast;
+};
+template <class Opr>
+std::vector<size_t> BatchTrait<Opr>::indices_contain_batch = {};
+template <class Opr>
+std::vector<size_t> BatchTrait<Opr>::indices_contain_batch_broadcast = {};
+
+#define DEFAULT_INDEX_OF_BATCH(opr) \
+    static size_t index_of_batch(const opr::Param&) { return 0; }
+#define CONV_INDEX_OF_BATCH(opr)                        \
+    static size_t index_of_batch(const opr::Param& p) { \
+        if (p.format == opr::Param::Format::CHWN4) {    \
+            return 3;                                   \
+        }                                               \
+        return 0;                                       \
+    }
+
+#define OPR_WITHOUT_INPUT_BROADCAST(INDEX_OF_BATCH, opr, idxs, idxs_brdcst) \
+    template <>                                                             \
+    struct BatchTrait<opr> {                                                \
+        INDEX_OF_BATCH(opr)                                                 \
+        static std::vector<size_t> indices_contain_batch;                   \
+        static std::vector<size_t> indices_contain_batch_broadcast;         \
+    };                                                                      \
+    std::vector<size_t> BatchTrait<opr>::indices_contain_batch = idxs;      \
+    std::vector<size_t> BatchTrait<opr>::indices_contain_batch_broadcast =  \
+            idxs_brdcst;
+
+OPR_WITHOUT_INPUT_BROADCAST(DEFAULT_INDEX_OF_BATCH,
+                            megdnn::Convolution3DForward,
+                            (std::initializer_list<size_t>{0, 2}), {})
+OPR_WITHOUT_INPUT_BROADCAST(DEFAULT_INDEX_OF_BATCH,
+                            megdnn::Convolution3DBackwardData,
+                            (std::initializer_list<size_t>{1, 2}), {})
+OPR_WITHOUT_INPUT_BROADCAST(DEFAULT_INDEX_OF_BATCH,
+                            megdnn::Convolution3DBackwardFilter,
+                            (std::initializer_list<size_t>{0, 1}), {})
+OPR_WITHOUT_INPUT_BROADCAST(DEFAULT_INDEX_OF_BATCH, megdnn::BatchedMatrixMul,
+                            (std::initializer_list<size_t>{0, 1, 2}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH, megdnn::ConvolutionForward,
+                            (std::initializer_list<size_t>{0, 2}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH,
+                            megdnn::ConvolutionBackwardData,
+                            (std::initializer_list<size_t>{1, 2}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH,
+                            megdnn::ConvolutionBackwardFilter,
+                            (std::initializer_list<size_t>{0, 1}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH, megdnn::LocalShareForward,
+                            (std::initializer_list<size_t>{0, 2}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH, megdnn::LocalShareBackwardData,
+                            (std::initializer_list<size_t>{1, 2}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH,
+                            megdnn::LocalShareBackwardFilter,
+                            (std::initializer_list<size_t>{0, 1}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH, megdnn::DeformableConvForward,
+                            (std::initializer_list<size_t>{0, 2, 3, 4}), {})
+OPR_WITHOUT_INPUT_BROADCAST(
+        CONV_INDEX_OF_BATCH, megdnn::DeformableConvBackwardData,
+        (std::initializer_list<size_t>{0, 2, 3, 4, 5, 6, 7}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH,
+                            megdnn::DeformableConvBackwardFilter,
+                            (std::initializer_list<size_t>{0, 1, 2, 3}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH, megdnn::BatchConvBiasForward,
+                            (std::initializer_list<size_t>{0, 1, 2, 3, 4}), {})
+OPR_WITHOUT_INPUT_BROADCAST(CONV_INDEX_OF_BATCH, megdnn::ConvBiasForward,
+                            (std::initializer_list<size_t>{0, 3, 4}), {2})
+#undef OPR_WITHOUT_INPUT_BROADCAST
+#undef DEFAULT_INDEX_OF_BATCH
+#undef CONV_INDEX_OF_BATCH
+
+template <class Opr>
+struct LayoutsModifier {
+    static void on(TensorLayoutArray& layouts, const typename Opr::Param& p,
+                   size_t new_batch_size) {
+        size_t batch_index = BatchTrait<Opr>::index_of_batch(p);
+        for (size_t index : BatchTrait<Opr>::indices_contain_batch) {
+            layouts.at(index)[batch_index] = new_batch_size;
+        }
+        for (size_t index : BatchTrait<Opr>::indices_contain_batch_broadcast) {
+            if (!check_bias_share_in_channel(layouts.at(index), p.format)) {
+                layouts.at(index)[batch_index] = new_batch_size;
+            }
+        }
+    }
+};
+
+#define OPR_NO_BIAS(opr)                                                      \
+    template <>                                                               \
+    struct LayoutsModifier<opr> {                                             \
+        static void on(TensorLayoutArray& layouts,                            \
+                       const typename opr::Param& p, size_t new_batch_size) { \
+            size_t batch_index = BatchTrait<opr>::index_of_batch(p);          \
+            for (size_t index : BatchTrait<opr>::indices_contain_batch) {     \
+                layouts.at(index)[batch_index] = new_batch_size;              \
+            }                                                                 \
+        }                                                                     \
+    };
+OPR_NO_BIAS(megdnn::Convolution3D)
+OPR_NO_BIAS(megdnn::BatchedMatrixMul)
+#undef OPR_NO_BIAS
+
+template <>
+struct LayoutsModifier<megdnn::MatrixMul> {
+public:
+    static void on(TensorLayoutArray& layouts,
+                   const megdnn::MatrixMul::Param& p, size_t new_batch_size) {
+        assert(!p.transposeA && !p.transposeB);
+        MEGDNN_MARK_USED_VAR(p);
+        layouts.at(0)[0] = new_batch_size;
+        layouts.at(2)[0] = new_batch_size;
+    }
+};
+
+template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>>
+class AlgoGenerator {
+public:
+    AlgoGenerator(ExecutionPolicyAlgoName name) : m_policy_name{name} {}
+
+    std::vector<Algorithm::Info::Desc> operator()(
+            Opr* opr, const CheckerHelper::TensorValueArray& arr) {
+        TensorLayoutArray layouts;
+        for (auto&& val : arr) {
+            layouts.push_back(val.layout);
+        }
+        std::vector<Algorithm::Info::Desc> ret;
+        megdnn_assert(layouts.size() == OprTrait<Opr>::arity);
+        for (auto algo_info :
+             AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info(
+                     opr, layouts)) {
+            if (!(algo_info.attribute &
+                  AlgoAttribute::ACCURACY_DEPEND_ON_BATCH) &&
+                std::regex_match(
+                        algo_info.desc.name,
+                        std::regex("(.*)(" + m_policy_name.name + ")(.*)"))) {
+                ret.push_back(algo_info.desc);
+            } else {
+                continue;
+            }
+        }
+        return ret;
+    }
+
+private:
+    ExecutionPolicyAlgoName m_policy_name;
+};
+
+}  // namespace
+
+::testing::AssertionResult __assert_tensor_binary_eq(
+        const char* expr0, const char* expr1, const char* expr2,
+        const TensorND& v0, const TensorND& v1,
+        const Algorithm::Info::Desc& algo);
+
+template <typename Opr, typename Proxy = OprProxy<Opr>>
+class AccuracyShakeChecker : public CheckerHelper {
+public:
+    static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
+    using Param = typename Opr::Param;
+    using BeforeExecCallback =
+            std::function<std::vector<Algorithm::Info::Desc>(
+                    Opr*, const TensorValueArray&)>;
+
+    AccuracyShakeChecker(Handle* handle, bool check_dispatch = false)
+            : CheckerHelper(handle, check_dispatch),
+              m_before_exec_callback{AlgoGenerator<Opr>("")},
+              m_param(Param()) {}
+
+    TensorLayoutArray make_layouts(const TensorShapeArray& shapes) {
+        TensorLayoutArray layouts(shapes.size());
+        for (size_t i = 0; i < shapes.size(); ++i) {
+            DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i]
+                                                         : dtype::Float32());
+            TensorFormat fmt =
+                    (m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{});
+            layouts[i] = TensorLayout(shapes[i], dt, fmt);
+        }
+        return layouts;
+    }
+
+    /*!
+     * \brief execute opr on current param/dtype/rng config
+     * \param shapes input/output shapes, which would be passed as
+     *      arguments to Opr::deduce_layout
+     *
+     * Checker would construct TensorLayout vectors from shapes and dtypes,
+     * and call exec(TensorLayoutArray&).
+     */
+    AccuracyShakeChecker& exec(const TensorShapeArray& shapes) {
+        exec(make_layouts(shapes));
+        return *this;
+    }
+
+    void exec(TensorLayoutArray layouts);
+
+    AccuracyShakeChecker& set_param(Param p) {
+        m_param = p;
+        opr()->param() = p;
+        return *this;
+    }
+    AccuracyShakeChecker& set_dtype(size_t idx, DType dtype) {
+        m_dtype[idx] = dtype;
+        return *this;
+    }
+    AccuracyShakeChecker& set_rng(size_t idx, RNG* rng) {
+        m_rng[idx] = rng;
+        return *this;
+    }
+
+    //! set a callback to be invoked before executing the operator
+    AccuracyShakeChecker& set_before_exec_callback(
+            const BeforeExecCallback& cb) {
+        m_before_exec_callback = cb;
+        return *this;
+    }
+
+    AccuracyShakeChecker& reset_before_exec_callback() {
+        m_before_exec_callback = nullptr;
+        return *this;
+    }
+
+    //! get the opr impl so that settings other than param() can be modified
+    Opr* opr() {
+        if (!m_opr_cur) {
+            m_opr_cur = m_handle_cur->create_operator<Opr>();
+        }
+        return m_opr_cur.get();
+    }
+
+private:
+    BeforeExecCallback m_before_exec_callback;
+    Param m_param;
+    Proxy m_proxy;
+    std::unique_ptr<Opr> m_opr_cur;
+    std::shared_ptr<TensorValueArray> m_tensors_cur_host,
+            m_tensors_single_batch_host;
+
+    void init_host_values();
+
+    void check_tensors_ignore_batch(
+            const TensorValueArray& tensors_single_batch,
+            const TensorValueArray& tensors, const Algorithm::Info::Desc& desc);
+};
+
+template <typename Opr, typename Proxy>
+void AccuracyShakeChecker<Opr, Proxy>::exec(TensorLayoutArray layouts) {
+    auto opr_cur = this->opr();
+    opr_cur->param() = m_param;
+
+    m_proxy.deduce_layout(opr_cur, layouts);
+    TensorLayoutArray layouts_single_batch = layouts;
+    for (size_t i = 0; i < layouts_single_batch.size(); ++i) {
+        ASSERT_TRUE(layouts[i].is_physical_contiguous())
+                << "layouts should be physical contiguous "
+                << layouts[i].to_string();
+    }
+    ASSERT_TRUE(0 == BatchTrait<Opr>::index_of_batch(opr_cur->param()))
+            << "index of batch should be 0";
+    LayoutsModifier<Opr>::on(layouts_single_batch, opr_cur->param(), 1);
+
+    // allocate input
+    auto tensors_single_batch_storage =
+            alloc_tensors(m_handle_cur, layouts_single_batch, 0);
+    m_tensors_single_batch_host =
+            alloc_tensors(m_handle_naive.get(), layouts_single_batch, 0);
+
+    auto tensors_cur_storage = alloc_tensors(m_handle_cur, layouts, 0);
+    m_tensors_cur_host = alloc_tensors(m_handle_naive.get(), layouts, 0);
+
+    auto&& tensors_single_batch = *tensors_single_batch_storage;
+    auto&& tensors_single_batch_host = *m_tensors_single_batch_host;
+    auto&& tensors_cur = *tensors_cur_storage;
+    auto&& tensors_cur_host = *m_tensors_cur_host;
+
+    // allocate output
+    auto tensors_single_batch_storage_out =
+            alloc_tensors(m_handle_naive.get(), layouts_single_batch, 0);
+    auto tensors_cur_storage_out =
+            alloc_tensors(m_handle_naive.get(), layouts, 0);
+    auto&& tensors_single_batch_out = *tensors_single_batch_storage_out;
+    auto&& tensors_cur_out = *tensors_cur_storage_out;
+
+    init_host_values();
+
+    copy_tensors_to_device(tensors_cur, tensors_cur_host);
+    copy_tensors_to_device(tensors_single_batch, tensors_single_batch_host);
+
+    std::vector<Algorithm::Info::Desc> algo_desc;
+    if (m_before_exec_callback) {
+        algo_desc = m_before_exec_callback(opr_cur, tensors_cur);
+    } else {
+        algo_desc.push_back({});
+    }
+    for (size_t i = 0; i < algo_desc.size(); ++i) {
+        opr_cur->execution_policy().algo = algo_desc[i];
+        m_proxy.exec(opr_cur, tensors_cur);
+        m_proxy.exec(opr_cur, tensors_single_batch);
+
+        copy_tensors_from_device(tensors_cur_out, tensors_cur);
+        copy_tensors_from_device(tensors_single_batch_out,
+                                 tensors_single_batch);
+
+        check_tensors_ignore_batch(tensors_single_batch_out, tensors_cur_out,
+                                   algo_desc[i]);
+    }
+}
+
+template <typename Opr, typename Proxy>
+void AccuracyShakeChecker<Opr, Proxy>::init_host_values() {
+    size_t index_of_batch = 0;
+    auto&& tensors_single_batch = *m_tensors_single_batch_host;
+    auto&& tensors_cur = *m_tensors_cur_host;
+    for (size_t i = 0; i < arity_in; ++i) {
+        auto&& tensor_single_batch = tensors_single_batch[i];
+        auto&& tensor_cur = tensors_cur[i];
+        auto rng = m_rng[i];
+        if (!rng)
+            rng = m_default_rng.get();
+        rng->gen(tensor_single_batch);
+
+        dt_byte* raw_storage_cur = static_cast<dt_byte*>(tensor_cur.raw_ptr) +
+                                   tensor_cur.layout.span().low_byte;
+        dt_byte* raw_storage_single_batch =
+                static_cast<dt_byte*>(tensor_single_batch.raw_ptr) +
+                tensor_single_batch.layout.span().low_byte;
+        const size_t step = tensor_single_batch.layout.span().dist_byte();
+        if (tensor_cur.layout.eq_shape(tensor_single_batch.layout)) {
+            memcpy(raw_storage_cur, raw_storage_single_batch, step);
+        } else {
+            ASSERT_TRUE(1 == tensor_single_batch.layout[index_of_batch])
+                    << "bad batch size "
+                    << tensor_single_batch.layout[index_of_batch];
+            for (size_t b = 0; b < tensor_cur.layout[index_of_batch]; ++b) {
+                memcpy(raw_storage_cur, raw_storage_single_batch, step);
+                raw_storage_cur += step;
+            }
+        }
+    }
+}
+
+template <typename Opr, typename Proxy>
+void AccuracyShakeChecker<Opr, Proxy>::check_tensors_ignore_batch(
+        const TensorValueArray& tensors_single_batch,
+        const TensorValueArray& tensors, const Algorithm::Info::Desc& algo) {
+    for (size_t i = 0; i < tensors_single_batch.size(); ++i) {
+        if (tensors_single_batch[i].layout.ndim == 0 ||
+            tensors_single_batch[i].layout.eq_shape(tensors[i].layout))
+            continue;
+        ASSERT_PRED_FORMAT3(::megdnn::test::__assert_tensor_binary_eq,
+                            tensors_single_batch[i], tensors[i], algo);
+    }
+}
+
+}  // namespace test
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
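Putting the pieces together, a call site looks like the CUDA test added at the end of this patch. A sketch (the shapes and the `handle_cuda()` fixture accessor are illustrative):

```cpp
// Run every conv-bias algo that is NOT marked ACCURACY_DEPEND_ON_BATCH twice,
// once with batch 64 and once with batch 1, and require bitwise-equal
// per-batch output.
AccuracyShakeChecker<ConvBiasForward> checker(handle_cuda());
checker.set_dtype(0, dtype::Float32())
        .set_dtype(1, dtype::Float32())
        .set_dtype(2, dtype::Float32())
        // src (N,C,H,W), filter, bias shared in channel, z, dst (deduced)
        .exec({{64, 16, 32, 32}, {16, 16, 3, 3}, {1, 16, 1, 1}, {}, {}});
```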
@@ -19,50 +19,6 @@ using namespace megdnn;
 using namespace test;
 
 namespace {
-bool good_float(float val) {
-    return std::isfinite(val);
-}
-
-bool good_float(int) {
-    return true;
-}
-
-bool good_float(dt_qint8) {
-    return true;
-}
-
-bool good_float(dt_qint16) {
-    return true;
-}
-
-bool good_float(dt_quint8) {
-    return true;
-}
-
-bool good_float(dt_qint32) {
-    return true;
-}
-
-// A hack for the (x+0) promote to int trick on dt_quint8.
-int operator+(dt_quint8 lhs, int rhs) {
-    megdnn_assert(rhs == 0, "unexpected rhs");
-    return lhs.as_uint8();
-}
-
-int operator+(dt_qint32 lhs, int rhs) {
-    megdnn_assert(rhs == 0, "unexpected rhs");
-    return lhs.as_int32();
-}
-
-int operator+(dt_qint8 lhs, int rhs) {
-    megdnn_assert(rhs == 0, "unexpected rhs");
-    return int8_t(lhs);
-}
-
-int operator+(dt_qint16 lhs, int rhs) {
-    megdnn_assert(rhs == 0, "unexpected rhs");
-    return lhs.as_int16();
-}
-
 template <typename ctype, class Iter>
 ::testing::AssertionResult assert_tensor_eq_with_iter(
@@ -86,6 +86,7 @@ protected:
     size_t m_offset = 0;
 
     CheckerHelper(Handle* handle, bool check_dispatch = true);
     ~CheckerHelper() noexcept;
+
     using OprExec = std::function<void(const TensorValueArray&)>;
@@ -100,14 +101,15 @@ protected:
     void enable_contig_naive() { m_enable_contig_naive = true; }
 
-private:
-    std::shared_ptr<TensorValueArray> m_tensors_naive;
-
-    void init_naive_values();
     void copy_tensors_to_device(const TensorValueArray& dest,
                                 const TensorValueArray& src);
     void copy_tensors_from_device(const TensorValueArray& dest,
                                   const TensorValueArray& src);
+
+private:
+    std::shared_ptr<TensorValueArray> m_tensors_naive;
+
+    void init_naive_values();
     void check_tensors(const TensorValueArray& expected,
                        const TensorValueArray& computed);
 };
@@ -311,6 +311,51 @@ public: | |||||
size_t get_cpu_count(); | size_t get_cpu_count(); | ||||
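// good_float() reports whether a value is finite; only genuine | |||||
// floating-point values can be NaN/Inf, so the integer and quantized | |||||
// overloads always return true. | |||||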
static inline bool good_float(float val) { | |||||
return std::isfinite(val); | |||||
} | |||||
static inline bool good_float(int) { | |||||
return true; | |||||
} | |||||
static inline bool good_float(dt_qint8) { | |||||
return true; | |||||
} | |||||
static inline bool good_float(dt_qint16) { | |||||
return true; | |||||
} | |||||
static inline bool good_float(dt_quint8) { | |||||
return true; | |||||
} | |||||
static inline bool good_float(dt_qint32) { | |||||
return true; | |||||
} | |||||
// A hack for the (x + 0) promote-to-int trick on the quantized dtypes below. | |||||
static inline int operator+(dt_quint8 lhs, int rhs) { | |||||
megdnn_assert(rhs == 0, "unexpected rhs"); | |||||
return lhs.as_uint8(); | |||||
} | |||||
static inline int operator+(dt_qint32 lhs, int rhs) { | |||||
megdnn_assert(rhs == 0, "unexpected rhs"); | |||||
return lhs.as_int32(); | |||||
} | |||||
static inline int operator+(dt_qint8 lhs, int rhs) { | |||||
megdnn_assert(rhs == 0, "unexpected rhs"); | |||||
return int8_t(lhs); | |||||
} | |||||
static inline int operator+(dt_qint16 lhs, int rhs) { | |||||
megdnn_assert(rhs == 0, "unexpected rhs"); | |||||
return lhs.as_int16(); | |||||
} | |||||
} // namespace test | } // namespace test | ||||
static inline bool operator==(const TensorLayout& a, const TensorLayout& b) { | static inline bool operator==(const TensorLayout& a, const TensorLayout& b) { | ||||
@@ -0,0 +1,247 @@ | |||||
/** | |||||
* \file dnn/test/cuda/accuracy_shake.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "test/cuda/fixture.h" | |||||
#include "test/cuda/utils.h" | |||||
#include "test/common/rng.h" | |||||
#include "test/common/accuracy_shake_checker.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
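// ConvBias shapes below follow the {src, filter, bias, z, dst} order; an | |||||
// empty shape {} lets the checker deduce that tensor's layout. | |||||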
TEST_F(CUDA, SHAKE_CONV_BIAS_FORWARD) { | |||||
require_compute_capability(6, 1); | |||||
AccuracyShakeChecker<ConvBiasForward> checker(handle_cuda()); | |||||
NormalRNG default_rng; | |||||
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.set_rng(0, &default_rng) | |||||
.set_rng(1, &default_rng); | |||||
// convolution | |||||
checker.exec({{64, 16, 32, 32}, {64, 16, 3, 3}, {}, {}, {}}); | |||||
// convbias without z | |||||
checker.exec({{64, 16, 32, 32}, {64, 16, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
// convbias with z | |||||
checker.exec({{64, 16, 32, 32}, | |||||
{64, 16, 3, 3}, | |||||
{1, 64, 1, 1}, | |||||
{64, 64, 30, 30}, | |||||
{}}); | |||||
ConvBias::Param param; | |||||
// group | |||||
param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
checker.set_param(param); | |||||
checker.exec({{64, 16, 32, 32}, {2, 32, 8, 3, 3}, {}, {}, {}}); | |||||
checker.exec({{64, 16, 32, 32}, {2, 32, 8, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
checker.exec({{64, 16, 32, 32}, | |||||
{2, 32, 8, 3, 3}, | |||||
{1, 64, 1, 1}, | |||||
{64, 64, 30, 30}, | |||||
{}}); | |||||
} | |||||
TEST_F(CUDA, SHAKE_CONV_BIAS_FORWARD_QS8_NCHW) { | |||||
require_compute_capability(6, 1); | |||||
AccuracyShakeChecker<ConvBiasForward> checker(handle_cuda()); | |||||
UniformIntRNG int_rng{-128, 127}; | |||||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||||
.set_dtype(3, dtype::QuantizedS8(0.25f)) | |||||
.set_dtype(4, dtype::QuantizedS8(0.25f)) | |||||
.set_rng(0, &int_rng) | |||||
.set_rng(1, &int_rng) | |||||
.set_rng(2, &int_rng) | |||||
.set_rng(3, &int_rng); | |||||
// convolution | |||||
checker.exec({{64, 16, 32, 32}, {64, 16, 3, 3}, {}, {}, {}}); | |||||
// convbias without z | |||||
checker.exec({{64, 16, 32, 32}, {64, 16, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
// convbias with z | |||||
checker.exec({{64, 16, 32, 32}, | |||||
{64, 16, 3, 3}, | |||||
{1, 64, 1, 1}, | |||||
{64, 64, 30, 30}, | |||||
{}}); | |||||
// group | |||||
ConvBias::Param param; | |||||
param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
checker.set_param(param); | |||||
checker.exec({{64, 16, 32, 32}, {2, 32, 8, 3, 3}, {}, {}, {}}); | |||||
checker.exec({{64, 16, 32, 32}, {2, 32, 8, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
checker.exec({{64, 16, 32, 32}, | |||||
{2, 32, 8, 3, 3}, | |||||
{1, 64, 1, 1}, | |||||
{64, 64, 30, 30}, | |||||
{}}); | |||||
} | |||||
TEST_F(CUDA, SHAKE_CONV_BIAS_FORWARD_QS8_NHWC) { | |||||
require_compute_capability(6, 1); | |||||
UniformIntRNG int_rng{-50, 50}; | |||||
AccuracyShakeChecker<ConvBiasForward> checker(handle_cuda()); | |||||
ConvBias::Param param; | |||||
param.format = ConvBias::Param::Format::NHWC; | |||||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||||
.set_dtype(4, dtype::QuantizedS8(60.25f)) | |||||
.set_rng(0, &int_rng) | |||||
.set_rng(1, &int_rng) | |||||
.set_rng(2, &int_rng) | |||||
.set_param(param); | |||||
checker.exec({{20, 32, 32, 4}, {24, 1, 1, 4}, {1, 1, 1, 24}, {}, {}}); | |||||
param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
checker.set_param(param).exec( | |||||
{{20, 32, 32, 16}, {4, 4, 1, 1, 4}, {1, 1, 1, 16}, {}, {}}); | |||||
} | |||||
TEST_F(CUDA, SHAKE_CONV_BIAS_FORWARD_QS8_NCHWX) { | |||||
using Format = ConvBias::Param::Format; | |||||
require_compute_capability(6, 1); | |||||
AccuracyShakeChecker<ConvBiasForward> checker(handle_cuda()); | |||||
UniformIntRNG int_rng{-5, 5}; | |||||
UniformFloatRNG float_rng{-50, 50}; | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.2f)) | |||||
.set_dtype(1, dtype::QuantizedS8(1.3f)) | |||||
.set_dtype(2, dtype::QuantizedS32(1.2f * 1.3f)) | |||||
.set_dtype(3, dtype::QuantizedS8(1.3f)) | |||||
.set_dtype(4, dtype::QuantizedS8(1.3f)) | |||||
.set_rng(0, &int_rng) | |||||
.set_rng(1, &int_rng) | |||||
.set_rng(2, &int_rng) | |||||
.set_rng(3, &int_rng); | |||||
auto run = [&](const TensorShapeArray& shapes, const Format& format) { | |||||
ConvBias::Param param; | |||||
param.format = format; | |||||
checker.set_param(param).exec( | |||||
{shapes[0], shapes[1], shapes[2], {}, {}}); | |||||
}; | |||||
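// NCHWx formats pack x channels into the innermost dimension, so the NCHW4 | |||||
// src {20, 2, 24, 24, 4} is logically a {20, 8, 24, 24} tensor. | |||||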
run({{20, 2, 24, 24, 4}, {24, 2, 3, 3, 4}, {1, 6, 1, 1, 4}}, Format::NCHW4); | |||||
run({{20, 1, 24, 24, 32}, {64, 1, 3, 3, 32}, {1, 2, 1, 1, 32}}, | |||||
Format::NCHW32); | |||||
run({{16, 4, 23, 40, 4}, {32, 4, 3, 3, 4}, {1, 1, 1, 1, 32}}, | |||||
Format::NCHW4_NCHW32); | |||||
checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) | |||||
.set_dtype(1, dtype::QuantizedS8(1.9980927f)) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.set_dtype(3, dtype::Float32()) | |||||
.set_dtype(4, dtype::Float32()) | |||||
.set_rng(0, &int_rng) | |||||
.set_rng(1, &int_rng) | |||||
.set_rng(2, &float_rng) | |||||
.set_rng(3, &float_rng); | |||||
run({{16, 4, 92, 160, 4}, {20, 4, 3, 3, 4}, {1, 20, 1, 1}}, | |||||
Format::NCHW4_NCHW); | |||||
} | |||||
TEST_F(CUDA, SHAKE_MATRIX_MUL_FORWARD) { | |||||
AccuracyShakeChecker<MatrixMul> checker(handle_cuda()); | |||||
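// A single {M, K} x {K, N} product; the empty output shape is deduced as | |||||
// {M, N} = {50, 60}. | |||||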
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.exec({{50, 100}, {100, 60}, {}}); | |||||
} | |||||
TEST_F(CUDA, SHAKE_BATCH_CONV_BIAS_QS8) { | |||||
require_compute_capability(6, 1); | |||||
AccuracyShakeChecker<BatchConvBiasForward> checker(handle_cuda()); | |||||
UniformIntRNG rng{-5, 5}; | |||||
checker.set_rng(0, &rng) | |||||
.set_rng(1, &rng) | |||||
.set_rng(2, &rng) | |||||
.set_rng(3, &rng) | |||||
.set_dtype(0, dtype::QuantizedS8{1.2f}) | |||||
.set_dtype(1, dtype::QuantizedS8{1.3f}) | |||||
.set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f}) | |||||
.set_dtype(3, dtype::QuantizedS8{1.1f}) | |||||
.set_dtype(4, dtype::QuantizedS8{1.1f}); | |||||
param::BatchConvBias param; | |||||
param.pad_h = 2, param.pad_w = 1; | |||||
param.stride_h = 1, param.stride_w = 2; | |||||
param.format = param::BatchConvBias::Format::NCHW4; | |||||
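// BatchConvBias applies a distinct filter to every sample, so the filter | |||||
// shape is prefixed with the batch size: {N, OC, IC/4, FH, FW, 4} here. | |||||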
checker.set_param(param).exec({{32, 4, 24, 24, 4}, | |||||
{32, 32, 4, 1, 1, 4}, | |||||
{1, 8, 1, 1, 4}, | |||||
{}, | |||||
{}}); | |||||
} | |||||
TEST_F(CUDA, SHAKE_BATCHED_MATRIX_MUL) { | |||||
AccuracyShakeChecker<BatchedMatrixMul> checker(handle_cuda()); | |||||
UniformIntRNG int_rng{-127, 127}; | |||||
NormalRNG default_rng; | |||||
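// Run the same {B, M, K} x {B, K, N} shapes twice: first with QuantizedS8 | |||||
// inputs (output dtype deduced), then in Float32. | |||||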
checker.set_dtype(0, dtype::QuantizedS8(1.2f)) | |||||
.set_dtype(1, dtype::QuantizedS8(1.3f)) | |||||
.set_dtype(2, {}) | |||||
.set_rng(0, &int_rng) | |||||
.set_rng(1, &int_rng); | |||||
checker.exec({{20, 424, 368}, {20, 368, 256}, {20, 424, 256}}); | |||||
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.set_rng(0, &default_rng) | |||||
.set_rng(1, &default_rng); | |||||
checker.exec({{20, 424, 368}, {20, 368, 256}, {20, 424, 256}}); | |||||
} | |||||
TEST_F(CUDA, SHAKE_CONVOLUTION3D_FORWARD) { | |||||
AccuracyShakeChecker<Convolution3DForward> checker(handle_cuda()); | |||||
NormalRNG default_rng; | |||||
param::Convolution3D param; | |||||
param.mode = param::Convolution3D::Mode::CROSS_CORRELATION; | |||||
param.stride_d = param.stride_h = param.stride_w = 2; | |||||
param.pad_d = param.pad_h = param.pad_w = 0; | |||||
param.dilate_d = param.dilate_h = param.dilate_w = 1; | |||||
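// NCDHW src {N, C, D, H, W} against an {OC, IC, FD, FH, FW} filter, | |||||
// stride 2 and zero padding in every spatial dimension. | |||||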
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_rng(0, &default_rng) | |||||
.set_rng(1, &default_rng) | |||||
.set_param(param) | |||||
.exec({{20, 5, 12, 12, 16}, {5, 5, 3, 3, 3}, {}}); | |||||
} | |||||
TEST_F(CUDA, SHAKE_LOCAL_SHARE) { | |||||
AccuracyShakeChecker<LocalShare> checker(handle_cuda()); | |||||
using Param = LocalShare::Param; | |||||
Param param; | |||||
param.spatial_groups_h = param.spatial_groups_w = 3; | |||||
checker.set_param(param); | |||||
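// Dense LocalShare filters are laid out as {spatial_groups_h, | |||||
// spatial_groups_w, IC, FH, FW, OC}: 3x3 spatial groups mapping 16 input | |||||
// channels to 64 outputs. | |||||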
checker.exec({{20, 16, 32, 32}, {3, 3, 16, 3, 3, 64}, {}}); | |||||
} | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -20,6 +20,7 @@ | |||||
#include "test/common/rng.h" | #include "test/common/rng.h" | ||||
#include "test/cuda/benchmark.h" | #include "test/cuda/benchmark.h" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "test/common/accuracy_shake_checker.h" | |||||
#define V1(x) #x | #define V1(x) #x | ||||
#define V(x) V1(x) | #define V(x) V1(x) | ||||
@@ -0,0 +1,104 @@ | |||||
/** | |||||
* \file dnn/test/x86/accuracy_shake.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "test/x86/fixture.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/accuracy_shake_checker.h" | |||||
#include "test/common/convolution.h" | |||||
#include "test/common/rng.h" | |||||
#include "test/common/tensor.h" | |||||
#include "test/common/workspace_wrapper.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
TEST_F(X86, SHAKE_CONV_BIAS_FORWARD) { | |||||
AccuracyShakeChecker<ConvBiasForward> checker(handle()); | |||||
NormalRNG default_rng; | |||||
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.set_rng(0, &default_rng) | |||||
.set_rng(1, &default_rng); | |||||
checker.set_before_exec_callback(AlgoGenerator<ConvBiasForward>("X86")); | |||||
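// Restrict the shake check to ConvBias algorithms from the x86 backend | |||||
// (selected by algorithm name). | |||||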
// convolution | |||||
checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {}, {}, {}}); | |||||
// convbias without z | |||||
checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
// convbias with z | |||||
checker.exec({{6, 16, 32, 32}, | |||||
{64, 16, 3, 3}, | |||||
{1, 64, 1, 1}, | |||||
{6, 64, 30, 30}, | |||||
{}}); | |||||
// group | |||||
ConvBias::Param param; | |||||
param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
checker.set_param(param); | |||||
checker.exec({{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {}, {}, {}}); | |||||
checker.exec({{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
checker.exec({{6, 16, 32, 32}, | |||||
{2, 32, 8, 3, 3}, | |||||
{1, 64, 1, 1}, | |||||
{6, 64, 30, 30}, | |||||
{}}); | |||||
} | |||||
TEST_F(X86, SHAKE_CONV_BIAS_FORWARD_INT8) { | |||||
AccuracyShakeChecker<ConvBiasForward> checker(handle()); | |||||
UniformIntRNG rng{-50, 50}; | |||||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||||
.set_dtype(3, dtype::QuantizedS32(6.25f)) | |||||
.set_dtype(4, {}) | |||||
.set_rng(0, &rng) | |||||
.set_rng(1, &rng) | |||||
.set_rng(2, &rng); | |||||
checker.set_before_exec_callback(AlgoGenerator<ConvBiasForward>("X86")); | |||||
// convolution | |||||
checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {}, {}, {}}); | |||||
// convbias without z | |||||
checker.exec({{6, 16, 32, 32}, {64, 16, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
// convbias with z | |||||
checker.exec({{6, 16, 32, 32}, | |||||
{64, 16, 3, 3}, | |||||
{1, 64, 1, 1}, | |||||
{6, 64, 30, 30}, | |||||
{}}); | |||||
// group | |||||
ConvBias::Param param; | |||||
param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
checker.set_param(param); | |||||
checker.exec({{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {}, {}, {}}); | |||||
checker.exec({{6, 16, 32, 32}, {2, 32, 8, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
checker.exec({{6, 16, 32, 32}, | |||||
{2, 32, 8, 3, 3}, | |||||
{1, 64, 1, 1}, | |||||
{6, 64, 30, 30}, | |||||
{}}); | |||||
} | |||||
TEST_F(X86, SHAKE_MATRIX_MUL_FORWARD) { | |||||
AccuracyShakeChecker<MatrixMul> checker(handle()); | |||||
checker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.exec({{20, 100}, {100, 60}, {}}); | |||||
} | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -15,6 +15,7 @@ | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "test/common/benchmarker.h" | #include "test/common/benchmarker.h" | ||||
#include "test/common/checker.h" | #include "test/common/checker.h" | ||||
#include "test/common/accuracy_shake_checker.h" | |||||
#include "test/common/convolution.h" | #include "test/common/convolution.h" | ||||
#include "test/common/rng.h" | #include "test/common/rng.h" | ||||
#include "test/common/tensor.h" | #include "test/common/tensor.h" | ||||
@@ -18,9 +18,7 @@ | |||||
#include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include "megdnn/oprs/base.h" | |||||
#include "megdnn/oprs/linalg.h" | |||||
#include "megdnn/oprs/nn.h" | |||||
#include "megdnn/oprs.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace opr { | namespace opr { | ||||
@@ -46,39 +44,6 @@ namespace opr { | |||||
// clang-format on | // clang-format on | ||||
template <typename Opr> | template <typename Opr> | ||||
struct OprArityTrait; | |||||
template <typename Opr, int _arity_in, int _arity_out> | |||||
struct OprArityTraitTmpl { | |||||
static constexpr int arity_in = _arity_in; | |||||
static constexpr int arity_out = _arity_out; | |||||
static constexpr int arity = arity_in + arity_out; | |||||
}; | |||||
#define INST_ARITY(_Opr, _in, _out) \ | |||||
template <> \ | |||||
struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {}; | |||||
INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1); | |||||
INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1); | |||||
INST_ARITY(megdnn::Convolution3DForward, 2, 1); | |||||
INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1); | |||||
INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1); | |||||
INST_ARITY(megdnn::LocalShareForward, 2, 1); | |||||
INST_ARITY(megdnn::LocalShareBackwardData, 2, 1); | |||||
INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1); | |||||
INST_ARITY(megdnn::Convolution, 2, 1); | |||||
INST_ARITY(megdnn::DeformableConvForward, 4, 1); | |||||
INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1); | |||||
INST_ARITY(megdnn::BatchConvBiasForward, 4, 1); | |||||
INST_ARITY(megdnn::ConvBias, 4, 1); | |||||
INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3); | |||||
INST_ARITY(megdnn::MatrixMul, 2, 1); | |||||
INST_ARITY(megdnn::BatchedMatrixMul, 2, 1); | |||||
#undef INST_ARITY | |||||
template <typename Opr> | |||||
constexpr bool opr_supports_preprocess() { | constexpr bool opr_supports_preprocess() { | ||||
return std::is_same<Opr, megdnn::ConvolutionForward>::value || | return std::is_same<Opr, megdnn::ConvolutionForward>::value || | ||||
std::is_same<Opr, megdnn::ConvBias>::value; | std::is_same<Opr, megdnn::ConvBias>::value; | ||||