Browse Source

refactor(dnn): add get_algorithm_from_desc interface

GitOrigin-RevId: 6d211ca167
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
85fa988348
38 changed files with 373 additions and 194 deletions
  1. +3
    -0
      dnn/include/megdnn/oprs/base.h
  2. +6
    -5
      dnn/src/common/algo_base.h
  3. +2
    -2
      dnn/src/common/algo_chooser.h
  4. +1
    -1
      dnn/src/cuda/batch_conv_bias/opr_impl.h
  5. +1
    -1
      dnn/src/cuda/batched_matrix_mul/opr_impl.h
  6. +1
    -1
      dnn/src/cuda/conv_bias/opr_impl.h
  7. +22
    -0
      dnn/src/cuda/convolution/opr_impl.cpp
  8. +4
    -2
      dnn/src/cuda/convolution/opr_impl.h
  9. +3
    -3
      dnn/src/cuda/convolution3d/opr_impl.h
  10. +3
    -3
      dnn/src/cuda/deformable_conv/opr_impl.h
  11. +3
    -3
      dnn/src/cuda/local_share/opr_impl.h
  12. +1
    -1
      dnn/src/cuda/matrix_mul/opr_impl.h
  13. +1
    -2
      dnn/src/fallback/batched_matrix_mul/opr_impl.h
  14. +3
    -3
      dnn/src/fallback/conv_bias/opr_impl.cpp
  15. +1
    -1
      dnn/src/fallback/conv_bias/opr_impl.h
  16. +6
    -6
      dnn/src/fallback/convolution/opr_impl.cpp
  17. +2
    -2
      dnn/src/fallback/convolution/opr_impl.h
  18. +3
    -2
      dnn/src/fallback/matrix_mul/opr_impl.cpp
  19. +2
    -1
      dnn/src/fallback/matrix_mul/opr_impl.h
  20. +8
    -0
      dnn/src/naive/batch_conv_bias/opr_impl.cpp
  21. +2
    -0
      dnn/src/naive/batch_conv_bias/opr_impl.h
  22. +9
    -0
      dnn/src/naive/batched_matrix_mul/opr_impl.cpp
  23. +2
    -0
      dnn/src/naive/batched_matrix_mul/opr_impl.h
  24. +9
    -0
      dnn/src/naive/conv_bias/opr_impl.cpp
  25. +2
    -0
      dnn/src/naive/conv_bias/opr_impl.h
  26. +26
    -0
      dnn/src/naive/convolution/convolution.cpp
  27. +6
    -0
      dnn/src/naive/convolution/opr_impl.h
  28. +114
    -84
      dnn/src/naive/convolution3d/convolution3d.cpp
  29. +63
    -65
      dnn/src/naive/convolution3d/opr_impl.h
  30. +12
    -0
      dnn/src/naive/deformable_conv/opr_impl.h
  31. +27
    -0
      dnn/src/naive/local_share/opr_impl.cpp
  32. +5
    -1
      dnn/src/naive/local_share/opr_impl.h
  33. +8
    -0
      dnn/src/naive/matrix_mul/opr_impl.cpp
  34. +2
    -0
      dnn/src/naive/matrix_mul/opr_impl.h
  35. +1
    -1
      dnn/src/rocm/batched_matrix_mul/opr_impl.h
  36. +3
    -3
      dnn/src/rocm/convolution/opr_impl.h
  37. +2
    -1
      dnn/src/rocm/matrix_mul/opr_impl.h
  38. +4
    -0
      src/opr/test/dnn/convolution.cpp

+ 3
- 0
dnn/include/megdnn/oprs/base.h View File

@@ -188,6 +188,7 @@ public:
using AlgorithmInfo = detail::Algorithm::Info; using AlgorithmInfo = detail::Algorithm::Info;
using AlgorithmDesc = detail::Algorithm::Info::Desc; using AlgorithmDesc = detail::Algorithm::Info::Desc;
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;

/*! /*!
* \brief get a string representation for current algorithm set; * \brief get a string representation for current algorithm set;
* *
@@ -209,6 +210,8 @@ public:
return m_execution_policy; return m_execution_policy;
} }


virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0;

protected: protected:
~MultiAlgoOpr() = default; ~MultiAlgoOpr() = default;




+ 6
- 5
dnn/src/common/algo_base.h View File

@@ -38,11 +38,12 @@ namespace megdnn {
return algo_pack().all_algos_map().at(desc); \ return algo_pack().all_algos_map().at(desc); \
} }


#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \
megdnn_assert(algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
_opr::Algorithm* _opr::get_algorithm_from_desc( \
const AlgorithmDesc& desc) { \
megdnn_assert(algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
} }


/** /**


+ 2
- 2
dnn/src/common/algo_chooser.h View File

@@ -34,7 +34,8 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
false); false);
} }
return opr->get_algo_from_desc(ret.desc);
return static_cast<typename Opr::AlgoBase*>(
opr->get_algorithm_from_desc(ret.desc));
} }


/*! /*!
@@ -43,7 +44,6 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
*/ */
template <class Opr, typename... Args> template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
typename Opr::AlgorithmInfo ret;
auto set = opr->execution_policy().algo; auto set = opr->execution_policy().algo;
if (set.valid()) { if (set.valid()) {
return opr->algo_pack().construct_and_get_algo(set.desc); return opr->algo_pack().construct_and_get_algo(set.desc);


+ 1
- 1
dnn/src/cuda/batch_conv_bias/opr_impl.h View File

@@ -35,7 +35,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(


+ 1
- 1
dnn/src/cuda/batched_matrix_mul/opr_impl.h View File

@@ -39,7 +39,7 @@ public:


bool is_thread_safe() const override { return true; } bool is_thread_safe() const override { return true; }
static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,


+ 1
- 1
dnn/src/cuda/conv_bias/opr_impl.h View File

@@ -69,7 +69,7 @@ public:


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }


static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,


+ 22
- 0
dnn/src/cuda/convolution/opr_impl.cpp View File

@@ -86,6 +86,28 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
workspace_limit_in_bytes, reproducible); workspace_limit_in_bytes, reproducible);
} }


/*!
 * \brief resolve an algorithm descriptor for the CUDA forward convolution.
 *
 * The lookup is forwarded to a temporary ConvBiasForward operator that is
 * configured with this operator's convolution parameters (using the IDENTITY
 * nonlinearity) and with this operator's currently selected algorithm.
 *
 * \param desc descriptor of the algorithm to look up
 * \return whatever ConvBiasForwardImpl::get_algorithm_from_desc returns for
 *         \p desc
 */
ConvolutionForwardImpl::Algorithm*
ConvolutionForwardImpl::get_algorithm_from_desc(
        const ConvolutionForward::AlgorithmDesc& desc) {
    auto conv_param = param();

    // Build a throw-away ConvBias operator on the same handle.
    auto bias_opr = this->handle()->create_operator<ConvBiasForward>();
    bias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY,
                         conv_param.mode,
                         conv_param.sparse,
                         conv_param.format,
                         conv_param.pad_h,
                         conv_param.pad_w,
                         conv_param.stride_h,
                         conv_param.stride_w,
                         conv_param.dilate_h,
                         conv_param.dilate_w,
                         conv_param.compute_mode};
    bias_opr->execution_policy() = {this->execution_policy().algo};

    // NOTE(review): bias_opr is destroyed when this function returns; the
    // returned pointer presumably refers to storage that outlives the
    // operator (e.g. a static algorithm pack) -- TODO confirm.
    auto* bias_impl = static_cast<ConvBiasForwardImpl*>(bias_opr.get());
    return bias_impl->get_algorithm_from_desc(desc);
}

std::vector<ConvolutionForwardImpl::Algorithm*> std::vector<ConvolutionForwardImpl::Algorithm*>
ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& filter,


+ 4
- 2
dnn/src/cuda/convolution/opr_impl.h View File

@@ -46,6 +46,8 @@ class ConvolutionForwardImpl: public ConvolutionForward {
megdnn_throw("cuda exec_preprocess has not implemeted yet"); megdnn_throw("cuda exec_preprocess has not implemeted yet");
} }


Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;

protected: protected:
struct ConvBiasExtraData{ struct ConvBiasExtraData{
std::unique_ptr<ConvBiasForward> convbias_opr; std::unique_ptr<ConvBiasForward> convbias_opr;
@@ -98,7 +100,7 @@ public:


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }


static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -152,7 +154,7 @@ public:


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }


static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(


+ 3
- 3
dnn/src/cuda/convolution3d/opr_impl.h View File

@@ -42,7 +42,7 @@ public:
class AlgoGroupConvGeneral; class AlgoGroupConvGeneral;
class AlgoPack; class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -92,7 +92,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -143,7 +143,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(


+ 3
- 3
dnn/src/cuda/deformable_conv/opr_impl.h View File

@@ -46,7 +46,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -97,7 +97,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -151,7 +151,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(


+ 3
- 3
dnn/src/cuda/local_share/opr_impl.h View File

@@ -33,7 +33,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -65,7 +65,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -98,7 +98,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(


+ 1
- 1
dnn/src/cuda/matrix_mul/opr_impl.h View File

@@ -46,7 +46,7 @@ public:
static const AlgoPack& algo_pack() { static const AlgoPack& algo_pack() {
return sm_algo_pack; return sm_algo_pack;
} }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


protected: protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,


+ 1
- 2
dnn/src/fallback/batched_matrix_mul/opr_impl.h View File

@@ -29,8 +29,7 @@ public:
class AlgoDefault; class AlgoDefault;
class AlgoPack; class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }

static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;


private: private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(


+ 3
- 3
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -454,8 +454,8 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
return algos; return algos;
} }


ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc(
const AlgorithmDesc& desc) const {
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
if (!desc.valid()) { if (!desc.valid()) {
return nullptr; return nullptr;
} else { } else {
@@ -495,7 +495,7 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc(


ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
const NCBKernSizeParam& param, size_t workspace_size) { const NCBKernSizeParam& param, size_t workspace_size) {
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) {
return algo; return algo;
} }
if (!m_prev_selected_algo || if (!m_prev_selected_algo ||


+ 1
- 1
dnn/src/fallback/conv_bias/opr_impl.h View File

@@ -381,7 +381,7 @@ private:


bool is_naive_algo(ConvBiasImpl::Algorithm* algo); bool is_naive_algo(ConvBiasImpl::Algorithm* algo);


Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


//! get algorithm set by user or by heuristic //! get algorithm set by user or by heuristic
Algorithm* get_algorithm( Algorithm* get_algorithm(


+ 6
- 6
dnn/src/fallback/convolution/opr_impl.cpp View File

@@ -361,8 +361,8 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
return ret; return ret;
} }


ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc(
const AlgorithmDesc& desc) const {
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
if (!desc.valid()) { if (!desc.valid()) {
return nullptr; return nullptr;
} else { } else {
@@ -387,7 +387,7 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc(


ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
const NCBKernSizeParam& param, size_t workspace_size) { const NCBKernSizeParam& param, size_t workspace_size) {
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) {
return algo; return algo;
} }
if (!m_prev_selected_algo || if (!m_prev_selected_algo ||
@@ -749,8 +749,8 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
} }


ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algo_from_desc(
const AlgorithmDesc& desc) const {
ConvolutionBackwardDataImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
if (!desc.valid()) { if (!desc.valid()) {
return nullptr; return nullptr;
} else { } else {
@@ -783,7 +783,7 @@ ConvolutionBackwardDataImpl::get_algo_from_desc(


ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) {
return algo; return algo;
} }
if (!m_prev_selected_algo || if (!m_prev_selected_algo ||


+ 2
- 2
dnn/src/fallback/convolution/opr_impl.h View File

@@ -284,7 +284,7 @@ private:
NCBKernSizeParam m_prev_selected_algo_sizep; NCBKernSizeParam m_prev_selected_algo_sizep;
Algorithm* m_prev_selected_algo = nullptr; Algorithm* m_prev_selected_algo = nullptr;


Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
bool is_naive_algo(ConvolutionImpl::Algorithm* algo); bool is_naive_algo(ConvolutionImpl::Algorithm* algo);
Algorithm* get_algorithm( Algorithm* get_algorithm(
const NCBKernSizeParam& param, const NCBKernSizeParam& param,
@@ -493,7 +493,7 @@ private:
class AlgoDirect; class AlgoDirect;
class AlgoMatrixMul; class AlgoMatrixMul;
class AlgoPack; class AlgoPack;
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


public: public:
//! maintain all the algos of in the opr of fallback //! maintain all the algos of in the opr of fallback


+ 3
- 2
dnn/src/fallback/matrix_mul/opr_impl.cpp View File

@@ -96,7 +96,7 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms(
return gemv_algos; return gemv_algos;
} }


MatrixMulImpl::AlgoBase* MatrixMulImpl::get_algo_from_desc(
MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) { const AlgorithmDesc& desc) {
if (!desc.valid()) { if (!desc.valid()) {
return nullptr; return nullptr;
@@ -133,7 +133,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) { size_t workspace_limit_in_bytes, bool reproducible) {
auto kern_size_param = make_kern_size_param(A, B, C); auto kern_size_param = make_kern_size_param(A, B, C);
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
if (auto algo = static_cast<AlgoBase*>(
get_algorithm_from_desc(execution_policy().algo.desc))) {
megdnn_assert(algo->get_workspace(kern_size_param) < megdnn_assert(algo->get_workspace(kern_size_param) <
workspace_limit_in_bytes); workspace_limit_in_bytes);
auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo, auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo,


+ 2
- 1
dnn/src/fallback/matrix_mul/opr_impl.h View File

@@ -238,7 +238,8 @@ private:
class AlgoPack; class AlgoPack;
//! maintain all the algos of in the opr of fallback //! maintain all the algos of in the opr of fallback
static const AlgoPack& algo_pack(); static const AlgoPack& algo_pack();
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;

public: public:


/** /**


+ 8
- 0
dnn/src/naive/batch_conv_bias/opr_impl.cpp View File

@@ -138,4 +138,12 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic(
return algo; return algo;
} }


/*!
 * \brief resolve an algorithm descriptor for the naive batch conv-bias opr.
 *
 * Always yields the handle's default batch conv-bias forward algorithm and
 * asserts that \p desc is the descriptor of that algorithm.
 */
BatchConvBiasForward::Algorithm*
BatchConvBiasForwardImpl::get_algorithm_from_desc(const AlgorithmDesc& desc) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    Algorithm* default_algo = naive_handle->default_batch_conv_bias_fwd_algo();
    // Any other descriptor is a caller error for this implementation.
    megdnn_assert(desc == default_algo->info().desc);
    return default_algo;
}

// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 2
- 0
dnn/src/naive/batch_conv_bias/opr_impl.h View File

@@ -39,6 +39,8 @@ public:
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; bool reproducible) override;


Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;

const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
private: private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,


+ 9
- 0
dnn/src/naive/batched_matrix_mul/opr_impl.cpp View File

@@ -81,6 +81,15 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
->default_batched_matmul_fwd_algo(); ->default_batched_matmul_fwd_algo();
} }


/*!
 * \brief resolve an algorithm descriptor for the naive batched matmul opr.
 *
 * Always yields the handle's default batched matmul forward algorithm and
 * asserts that \p desc is the descriptor of that algorithm.
 */
BatchedMatrixMulForward::Algorithm*
BatchedMatrixMulForwardImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    Algorithm* default_algo = naive_handle->default_batched_matmul_fwd_algo();
    // Any other descriptor is a caller error for this implementation.
    megdnn_assert(desc == default_algo->info().desc);
    return default_algo;
}

} // namespace naive } // namespace naive
} // namespace megdnn } // namespace megdnn




+ 2
- 0
dnn/src/naive/batched_matrix_mul/opr_impl.h View File

@@ -34,6 +34,8 @@ public:
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override; bool /* reproducible */) override;


Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;

const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }


private: private:


+ 9
- 0
dnn/src/naive/conv_bias/opr_impl.cpp View File

@@ -256,6 +256,15 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
return algo; return algo;
} }


/*!
 * \brief resolve an algorithm descriptor for the naive conv-bias opr.
 *
 * Always yields the handle's default conv-bias forward algorithm and asserts
 * that \p desc is the descriptor of that algorithm.
 */
ConvBiasForward::Algorithm*
ConvBiasForwardImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    Algorithm* default_algo = naive_handle->default_conv_bias_fwd_algo();
    // Any other descriptor is a caller error for this implementation.
    megdnn_assert(desc == default_algo->info().desc);
    return default_algo;
}

const char* ConvBiasForwardImpl::get_algorithm_set_name() const { const char* ConvBiasForwardImpl::get_algorithm_set_name() const {
return "DEFAULT"; return "DEFAULT";
} }


+ 2
- 0
dnn/src/naive/conv_bias/opr_impl.h View File

@@ -64,6 +64,8 @@ public:
_megdnn_workspace) override {} _megdnn_workspace) override {}


const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;

Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
}; };


void handle_z_inp_and_activation_naive( void handle_z_inp_and_activation_naive(


+ 26
- 0
dnn/src/naive/convolution/convolution.cpp View File

@@ -285,6 +285,14 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
return algo; return algo;
} }


/*!
 * \brief resolve an algorithm descriptor for the naive forward convolution.
 *
 * Always yields the handle's default forward convolution algorithm and
 * asserts that \p desc is the descriptor of that algorithm.
 */
ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    Algorithm* default_algo = naive_handle->default_conv_fwd_algo();
    // Any other descriptor is a caller error for this implementation.
    megdnn_assert(desc == default_algo->info().desc);
    return default_algo;
}

std::vector<ConvolutionBackwardData::Algorithm *> std::vector<ConvolutionBackwardData::Algorithm *>
ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &) const TensorLayout &, const TensorLayout &)
@@ -309,6 +317,15 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
return algo; return algo;
} }


/*!
 * \brief resolve an algorithm descriptor for the naive backward-data
 *        convolution.
 *
 * Always yields the handle's default backward-data algorithm and asserts that
 * \p desc is the descriptor of that algorithm.
 */
ConvolutionBackwardData::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    Algorithm* default_algo = naive_handle->default_conv_bwd_data_algo();
    // Any other descriptor is a caller error for this implementation.
    megdnn_assert(desc == default_algo->info().desc);
    return default_algo;
}

std::vector<ConvolutionBackwardFilter::Algorithm *> std::vector<ConvolutionBackwardFilter::Algorithm *>
ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &) const TensorLayout &, const TensorLayout &)
@@ -333,6 +350,15 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
return algo; return algo;
} }


/*!
 * \brief resolve an algorithm descriptor for the naive backward-filter
 *        convolution.
 *
 * Always yields the handle's default backward-filter algorithm and asserts
 * that \p desc is the descriptor of that algorithm.
 */
ConvolutionBackwardFilter::Algorithm*
ConvolutionBackwardFilterImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    Algorithm* default_algo = naive_handle->default_conv_bwd_filter_algo();
    // Any other descriptor is a caller error for this implementation.
    megdnn_assert(desc == default_algo->info().desc);
    return default_algo;
}

const char* ConvolutionForwardImpl::get_algorithm_set_name() const { const char* ConvolutionForwardImpl::get_algorithm_set_name() const {
return "DEFAULT"; return "DEFAULT";
} }


+ 6
- 0
dnn/src/naive/convolution/opr_impl.h View File

@@ -52,6 +52,8 @@ class ConvolutionForwardImpl: public ConvolutionForward {
return {}; return {};
} }


Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;

const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;
}; };


@@ -74,6 +76,8 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
const TensorLayout&) override; const TensorLayout&) override;


const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;

Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
}; };


class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
@@ -95,6 +99,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
const TensorLayout&) override; const TensorLayout&) override;


const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;

Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
}; };


} // namespace naive } // namespace naive


+ 114
- 84
dnn/src/naive/convolution3d/convolution3d.cpp View File

@@ -6,15 +6,15 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "./opr_impl.h"
#include "./helper.h" #include "./helper.h"
#include "./opr_impl.h"


#include "src/naive/handle.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "megdnn/dtype.h" #include "megdnn/dtype.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"


#include <cstring> #include <cstring>


@@ -25,93 +25,95 @@ using namespace megdnn;
using namespace naive; using namespace naive;


void Convolution3DForwardImpl::exec(_megdnn_tensor_in src, void Convolution3DForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace)
{
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_conv3d_fwd) { MIDOUT_BEGIN(megdnn_naive_conv3d_fwd) {
auto filter_meta = check_exec(
src.layout, filter.layout, dst.layout, workspace.size);
switch (param().data_type) {
case Param::DataType::FLOAT:
#define cb(dt) do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \
convolution3d::forward< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, filter, dst, filter_meta); \
); \
return; \
} \
} while(0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
auto filter_meta = check_exec(src.layout, filter.layout, dst.layout,
workspace.size);
switch (param().data_type) {
case Param::DataType::FLOAT:
#define cb(dt) \
do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution3d::forward< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, filter, dst, filter_meta);); \
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb #undef cb
break;
case Param::DataType::FLOAT_IO16xC32:
MEGDNN_INC_FLOAT16(
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()),
convolution3d::forward<
dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA dt_float32>(
src, filter, dst, filter_meta);));
return;
break;
case Param::DataType::FLOAT_IO16xC32:
MEGDNN_INC_FLOAT16(MEGDNN_DISPATCH_CPU_KERN(
static_cast<HandleImpl*>(handle()),
convolution3d::forward<
dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA
dt_float32>(src, filter, dst,
filter_meta);));
return;
}
megdnn_assert_internal(0);
} }
megdnn_assert_internal(0);

} MIDOUT_END();
MIDOUT_END();
} }


void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter, void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace)
{
auto filter_meta = check_exec(
filter.layout, diff.layout, grad.layout, workspace.size);
#define cb(dt) do { \
if (filter.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \
convolution3d::backward_data< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
filter, diff, grad, filter_meta);); \
return; \
} \
} while(0);
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
auto filter_meta =
check_exec(filter.layout, diff.layout, grad.layout, workspace.size);
#define cb(dt) \
do { \
if (filter.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution3d::backward_data< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
filter, diff, grad, filter_meta);); \
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb #undef cb


megdnn_assert_internal(0); megdnn_assert_internal(0);
} }
void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace)
{
auto filter_meta = check_exec(
src.layout, diff.layout, grad.layout, workspace.size);
#define cb(dt) do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \
convolution3d::backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, diff, grad, filter_meta);); \
return; \
} \
} while(0);
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
auto filter_meta =
check_exec(src.layout, diff.layout, grad.layout, workspace.size);
#define cb(dt) \
do { \
if (src.layout.dtype == dt()) { \
using ctype = DTypeTrait<dt>::ctype; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<HandleImpl*>(handle()), \
convolution3d::backward_filter< \
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
src, diff, grad, filter_meta);); \
return; \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb #undef cb


megdnn_assert_internal(0); megdnn_assert_internal(0);
} }


std::vector<Convolution3DForward::Algorithm *>
Convolution3DForwardImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl *>(handle())->default_conv3d_fwd_algo()};
std::vector<Convolution3DForward::Algorithm*>
Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()};
} }


Convolution3DForward::Algorithm* Convolution3DForward::Algorithm*
@@ -130,11 +132,20 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
return algo; return algo;
} }


std::vector<Convolution3DBackwardData::Algorithm *>
Convolution3DBackwardDataImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl *>(handle())->default_conv3d_bwd_data_algo()};
// Resolve an algorithm descriptor to an Algorithm object.
//
// This implementation does not search by descriptor: it always takes the
// handle's default conv3d forward algorithm and only checks, via
// megdnn_assert, that `desc` compares equal to that algorithm's own
// descriptor (ret->info().desc). The behaviour on a mismatch is whatever
// megdnn_assert does (defined outside this file).
Convolution3DForward::Algorithm*
Convolution3DForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
// handle() is downcast to HandleImpl, which provides the default algorithm.
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo();
// The caller-supplied descriptor must describe exactly this algorithm.
megdnn_assert(desc == ret->info().desc);
return ret;
}

std::vector<Convolution3DBackwardData::Algorithm*>
Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()};
} }


Convolution3DBackwardData::Algorithm* Convolution3DBackwardData::Algorithm*
@@ -154,11 +165,21 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic(
return algo; return algo;
} }


std::vector<Convolution3DBackwardFilter::Algorithm *>
Convolution3DBackwardFilterImpl:: get_all_algorithms(const TensorLayout &,
const TensorLayout &, const TensorLayout &)
{
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_filter_algo()};
// Resolve an algorithm descriptor to an Algorithm object.
//
// This implementation does not search by descriptor: it always takes the
// handle's default conv3d backward-data algorithm and only checks, via
// megdnn_assert, that `desc` compares equal to that algorithm's own
// descriptor (ret->info().desc). The behaviour on a mismatch is whatever
// megdnn_assert does (defined outside this file).
Convolution3DBackwardData::Algorithm*
Convolution3DBackwardDataImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
// handle() is downcast to HandleImpl, which provides the default algorithm.
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo();
// The caller-supplied descriptor must describe exactly this algorithm.
megdnn_assert(desc == ret->info().desc);
return ret;
}

std::vector<Convolution3DBackwardFilter::Algorithm*>
Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) {
return {static_cast<HandleImpl*>(handle())
->default_conv3d_bwd_filter_algo()};
} }


Convolution3DBackwardFilter::Algorithm* Convolution3DBackwardFilter::Algorithm*
@@ -179,6 +200,15 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic(
return algo; return algo;
} }


// Resolve an algorithm descriptor to an Algorithm object.
//
// This implementation does not search by descriptor: it always takes the
// handle's default conv3d backward-filter algorithm and only checks, via
// megdnn_assert, that `desc` compares equal to that algorithm's own
// descriptor (ret->info().desc). The behaviour on a mismatch is whatever
// megdnn_assert does (defined outside this file).
Convolution3DBackwardFilter::Algorithm*
Convolution3DBackwardFilterImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
// handle() is downcast to HandleImpl, which provides the default algorithm.
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_conv3d_bwd_filter_algo();
// The caller-supplied descriptor must describe exactly this algorithm.
megdnn_assert(desc == ret->info().desc);
return ret;
}

const char* Convolution3DForwardImpl::get_algorithm_set_name() const { const char* Convolution3DForwardImpl::get_algorithm_set_name() const {
return "DEFAULT"; return "DEFAULT";
} }


+ 63
- 65
dnn/src/naive/convolution3d/opr_impl.h View File

@@ -6,81 +6,79 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
#include "megdnn/oprs.h" #include "megdnn/oprs.h"


namespace megdnn { namespace megdnn {
namespace naive { namespace naive {
class Convolution3DForwardImpl: public Convolution3DForward {
public:
using Convolution3DForward::Convolution3DForward;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
const char* get_algorithm_set_name() const override;
class Convolution3DForwardImpl : public Convolution3DForward {
public:
using Convolution3DForward::Convolution3DForward;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override;
}; };


class Convolution3DBackwardDataImpl: public Convolution3DBackwardData {
public:
using Convolution3DBackwardData::Convolution3DBackwardData;
void exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
class Convolution3DBackwardDataImpl : public Convolution3DBackwardData {
public:
using Convolution3DBackwardData::Convolution3DBackwardData;
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}


const char* get_algorithm_set_name() const override;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override;
}; };


class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter {
public:
using Convolution3DBackwardFilter::Convolution3DBackwardFilter;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
const char* get_algorithm_set_name() const override;
class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter {
public:
using Convolution3DBackwardFilter::Convolution3DBackwardFilter;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&) override {
return 0;
}
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override;
}; };



} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen

+ 12
- 0
dnn/src/naive/deformable_conv/opr_impl.h View File

@@ -48,6 +48,10 @@ public:
return "DEFORMABLE_CONV2_NAIVE"; return "DEFORMABLE_CONV2_NAIVE";
}; };


// Descriptor lookup stub: the descriptor argument is ignored and a null
// Algorithm pointer is always returned.
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& /*desc*/) override {
    return nullptr;
}

void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_in offset, _megdnn_tensor_in mask, _megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
@@ -84,6 +88,10 @@ public:
return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE"; return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE";
}; };


// Descriptor lookup stub: the descriptor argument is ignored and a null
// Algorithm pointer is always returned.
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& /*desc*/) override {
    return nullptr;
}

void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
_megdnn_tensor_in mask, _megdnn_tensor_in out_grad, _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
_megdnn_tensor_out filter_grad, _megdnn_tensor_out filter_grad,
@@ -130,6 +138,10 @@ public:
return "DEFORMABLE_CONV2_BWD_DATA_NAIVE"; return "DEFORMABLE_CONV2_BWD_DATA_NAIVE";
}; };


// Descriptor lookup stub: the descriptor argument is ignored and a null
// Algorithm pointer is always returned.
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& /*desc*/) override {
    return nullptr;
}

void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
_megdnn_tensor_in offset, _megdnn_tensor_in mask, _megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,


+ 27
- 0
dnn/src/naive/local_share/opr_impl.cpp View File

@@ -175,6 +175,15 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic(
return algo; return algo;
} }


// Resolve an algorithm descriptor to an Algorithm object.
//
// This implementation does not search by descriptor: it always takes the
// handle's default local-share forward algorithm and only checks, via
// megdnn_assert, that `desc` compares equal to that algorithm's own
// descriptor (ret->info().desc). The behaviour on a mismatch is whatever
// megdnn_assert does (defined outside this file).
LocalShareForward::Algorithm*
LocalShareForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
// handle() is downcast to HandleImpl, which provides the default algorithm.
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo();
// The caller-supplied descriptor must describe exactly this algorithm.
megdnn_assert(desc == ret->info().desc);
return ret;
}

std::vector<LocalShareBackwardData::Algorithm*> std::vector<LocalShareBackwardData::Algorithm*>
LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&, const TensorLayout&,
@@ -200,6 +209,15 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic(
return algo; return algo;
} }


// Resolve an algorithm descriptor to an Algorithm object.
//
// This implementation does not search by descriptor: it always takes the
// handle's default local-share backward-data algorithm and only checks, via
// megdnn_assert, that `desc` compares equal to that algorithm's own
// descriptor (ret->info().desc). The behaviour on a mismatch is whatever
// megdnn_assert does (defined outside this file).
LocalShareBackwardData::Algorithm*
LocalShareBackwardDataImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
// handle() is downcast to HandleImpl, which provides the default algorithm.
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_data_algo();
// The caller-supplied descriptor must describe exactly this algorithm.
megdnn_assert(desc == ret->info().desc);
return ret;
}

std::vector<LocalShareBackwardFilter::Algorithm*> std::vector<LocalShareBackwardFilter::Algorithm*>
LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&,
const TensorLayout&, const TensorLayout&,
@@ -225,4 +243,13 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic(
return algo; return algo;
} }


// Resolve an algorithm descriptor to an Algorithm object.
//
// This implementation does not search by descriptor: it always takes the
// handle's default local-share backward-filter algorithm and only checks,
// via megdnn_assert, that `desc` compares equal to that algorithm's own
// descriptor (ret->info().desc). The behaviour on a mismatch is whatever
// megdnn_assert does (defined outside this file).
LocalShareBackwardFilter::Algorithm*
LocalShareBackwardFilterImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
// handle() is downcast to HandleImpl, which provides the default algorithm.
Algorithm* ret = static_cast<HandleImpl*>(handle())
->default_local_share_bwd_filter_algo();
// The caller-supplied descriptor must describe exactly this algorithm.
megdnn_assert(desc == ret->info().desc);
return ret;
}

// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 5
- 1
dnn/src/naive/local_share/opr_impl.h View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
@@ -35,6 +36,7 @@ public:
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; bool /*reproducible*/) override;


Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
}; };


@@ -59,6 +61,7 @@ public:
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; bool /*reproducible*/) override;


Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
}; };


@@ -83,6 +86,7 @@ public:
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; bool /*reproducible*/) override;


Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }
}; };




+ 8
- 0
dnn/src/naive/matrix_mul/opr_impl.cpp View File

@@ -95,6 +95,14 @@ MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
} }


// Resolve an algorithm descriptor to an Algorithm object.
//
// This implementation does not search by descriptor: it always takes the
// handle's default matmul forward algorithm and only checks, via
// megdnn_assert, that `desc` compares equal to that algorithm's own
// descriptor (ret->info().desc). The behaviour on a mismatch is whatever
// megdnn_assert does (defined outside this file).
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_from_desc(
const AlgorithmDesc& desc) {
// handle() is downcast to HandleImpl, which provides the default algorithm.
Algorithm* ret =
static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo();
// The caller-supplied descriptor must describe exactly this algorithm.
megdnn_assert(desc == ret->info().desc);
return ret;
}

} // namespace naive } // namespace naive
} // namespace megdnn } // namespace megdnn




+ 2
- 0
dnn/src/naive/matrix_mul/opr_impl.h View File

@@ -35,6 +35,8 @@ public:
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /* reproducible */) override; bool /* reproducible */) override;


Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;

const char* get_algorithm_set_name() const override { return "DEFAULT"; } const char* get_algorithm_set_name() const override { return "DEFAULT"; }


private: private:


+ 1
- 1
dnn/src/rocm/batched_matrix_mul/opr_impl.h View File

@@ -29,8 +29,8 @@ public:
class AlgoBlas; class AlgoBlas;
class AlgoPack; class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;


static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
private: private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*A*/, const TensorLayout& /*B*/,


+ 3
- 3
dnn/src/rocm/convolution/opr_impl.h View File

@@ -66,7 +66,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


private: private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -112,7 +112,7 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;


private: private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -158,7 +158,7 @@ public:


class AlgoPack; class AlgoPack;


static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }


private: private:


+ 2
- 1
dnn/src/rocm/matrix_mul/opr_impl.h View File

@@ -29,7 +29,7 @@ public:
class AlgoPack; class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }


static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;


private: private:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
@@ -41,6 +41,7 @@ private:
const TensorLayout& /*C*/, const TensorLayout& /*C*/,
size_t /*workspace_limit_in_bytes*/, size_t /*workspace_limit_in_bytes*/,
bool /*reproducible*/) override; bool /*reproducible*/) override;

const char* get_algorithm_set_name() const override { const char* get_algorithm_set_name() const override {
return "ROCM MATMUL"; return "ROCM MATMUL";
} }


+ 4
- 0
src/opr/test/dnn/convolution.cpp View File

@@ -2204,6 +2204,10 @@ public:
const TensorLayout& p2, const TensorLayout& p2,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible)); bool reproducible));

MOCK_METHOD1(get_algorithm_from_desc,
Algorithm*(const AlgorithmDesc&));

protected: protected:
const char* get_algorithm_set_name() const override { const char* get_algorithm_set_name() const override {
return m_algorithm_set_name; return m_algorithm_set_name;


Loading…
Cancel
Save