GitOrigin-RevId: 6d211ca167
tags/v1.3.0
@@ -188,6 +188,7 @@ public: | |||||
using AlgorithmInfo = detail::Algorithm::Info; | using AlgorithmInfo = detail::Algorithm::Info; | ||||
using AlgorithmDesc = detail::Algorithm::Info::Desc; | using AlgorithmDesc = detail::Algorithm::Info::Desc; | ||||
using Algorithm = detail::Algorithm; | using Algorithm = detail::Algorithm; | ||||
/*! | /*! | ||||
* \brief get a string representation for current algorithm set; | * \brief get a string representation for current algorithm set; | ||||
* | * | ||||
@@ -209,6 +210,8 @@ public: | |||||
return m_execution_policy; | return m_execution_policy; | ||||
} | } | ||||
virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | |||||
protected: | protected: | ||||
~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
@@ -38,11 +38,12 @@ namespace megdnn { | |||||
return algo_pack().all_algos_map().at(desc); \ | return algo_pack().all_algos_map().at(desc); \ | ||||
} | } | ||||
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||||
_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ | |||||
megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||||
algo_pack().all_algos_map().end()); \ | |||||
return algo_pack().all_algos_map().at(desc); \ | |||||
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||||
_opr::Algorithm* _opr::get_algorithm_from_desc( \ | |||||
const AlgorithmDesc& desc) { \ | |||||
megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||||
algo_pack().all_algos_map().end()); \ | |||||
return algo_pack().all_algos_map().at(desc); \ | |||||
} | } | ||||
/** | /** | ||||
@@ -34,7 +34,8 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||||
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | ||||
false); | false); | ||||
} | } | ||||
return opr->get_algo_from_desc(ret.desc); | |||||
return static_cast<typename Opr::AlgoBase*>( | |||||
opr->get_algorithm_from_desc(ret.desc)); | |||||
} | } | ||||
/*! | /*! | ||||
@@ -43,7 +44,6 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||||
*/ | */ | ||||
template <class Opr, typename... Args> | template <class Opr, typename... Args> | ||||
typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { | typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { | ||||
typename Opr::AlgorithmInfo ret; | |||||
auto set = opr->execution_policy().algo; | auto set = opr->execution_policy().algo; | ||||
if (set.valid()) { | if (set.valid()) { | ||||
return opr->algo_pack().construct_and_get_algo(set.desc); | return opr->algo_pack().construct_and_get_algo(set.desc); | ||||
@@ -35,7 +35,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -39,7 +39,7 @@ public: | |||||
bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
@@ -69,7 +69,7 @@ public: | |||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
@@ -86,6 +86,28 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||||
workspace_limit_in_bytes, reproducible); | workspace_limit_in_bytes, reproducible); | ||||
} | } | ||||
ConvolutionForwardImpl::Algorithm* | |||||
ConvolutionForwardImpl::get_algorithm_from_desc( | |||||
const ConvolutionForward::AlgorithmDesc& desc) { | |||||
auto conv_param = param(); | |||||
auto convbias_opr = this->handle()->create_operator<ConvBiasForward>(); | |||||
convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY, | |||||
conv_param.mode, | |||||
conv_param.sparse, | |||||
conv_param.format, | |||||
conv_param.pad_h, | |||||
conv_param.pad_w, | |||||
conv_param.stride_h, | |||||
conv_param.stride_w, | |||||
conv_param.dilate_h, | |||||
conv_param.dilate_w, | |||||
conv_param.compute_mode}; | |||||
convbias_opr->execution_policy() = {this->execution_policy().algo}; | |||||
return static_cast<ConvBiasForwardImpl*>(convbias_opr.get()) | |||||
->get_algorithm_from_desc(desc); | |||||
} | |||||
std::vector<ConvolutionForwardImpl::Algorithm*> | std::vector<ConvolutionForwardImpl::Algorithm*> | ||||
ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | ||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
@@ -46,6 +46,8 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
megdnn_throw("cuda exec_preprocess has not implemeted yet"); | megdnn_throw("cuda exec_preprocess has not implemeted yet"); | ||||
} | } | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
struct ConvBiasExtraData{ | struct ConvBiasExtraData{ | ||||
std::unique_ptr<ConvBiasForward> convbias_opr; | std::unique_ptr<ConvBiasForward> convbias_opr; | ||||
@@ -98,7 +100,7 @@ public: | |||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -152,7 +154,7 @@ public: | |||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -42,7 +42,7 @@ public: | |||||
class AlgoGroupConvGeneral; | class AlgoGroupConvGeneral; | ||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -92,7 +92,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -143,7 +143,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -46,7 +46,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -97,7 +97,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -151,7 +151,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -33,7 +33,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -65,7 +65,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -98,7 +98,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -46,7 +46,7 @@ public: | |||||
static const AlgoPack& algo_pack() { | static const AlgoPack& algo_pack() { | ||||
return sm_algo_pack; | return sm_algo_pack; | ||||
} | } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
@@ -29,8 +29,7 @@ public: | |||||
class AlgoDefault; | class AlgoDefault; | ||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
private: | private: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -454,8 +454,8 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||||
return algos; | return algos; | ||||
} | } | ||||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||||
const AlgorithmDesc& desc) const { | |||||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
if (!desc.valid()) { | if (!desc.valid()) { | ||||
return nullptr; | return nullptr; | ||||
} else { | } else { | ||||
@@ -495,7 +495,7 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | ||||
const NCBKernSizeParam& param, size_t workspace_size) { | const NCBKernSizeParam& param, size_t workspace_size) { | ||||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||||
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||||
return algo; | return algo; | ||||
} | } | ||||
if (!m_prev_selected_algo || | if (!m_prev_selected_algo || | ||||
@@ -381,7 +381,7 @@ private: | |||||
bool is_naive_algo(ConvBiasImpl::Algorithm* algo); | bool is_naive_algo(ConvBiasImpl::Algorithm* algo); | ||||
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
//! get algorithm set by user or by heuristic | //! get algorithm set by user or by heuristic | ||||
Algorithm* get_algorithm( | Algorithm* get_algorithm( | ||||
@@ -361,8 +361,8 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) { | |||||
return ret; | return ret; | ||||
} | } | ||||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( | |||||
const AlgorithmDesc& desc) const { | |||||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
if (!desc.valid()) { | if (!desc.valid()) { | ||||
return nullptr; | return nullptr; | ||||
} else { | } else { | ||||
@@ -387,7 +387,7 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( | |||||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( | ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( | ||||
const NCBKernSizeParam& param, size_t workspace_size) { | const NCBKernSizeParam& param, size_t workspace_size) { | ||||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||||
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||||
return algo; | return algo; | ||||
} | } | ||||
if (!m_prev_selected_algo || | if (!m_prev_selected_algo || | ||||
@@ -749,8 +749,8 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( | |||||
} | } | ||||
ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
ConvolutionBackwardDataImpl::get_algo_from_desc( | |||||
const AlgorithmDesc& desc) const { | |||||
ConvolutionBackwardDataImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
if (!desc.valid()) { | if (!desc.valid()) { | ||||
return nullptr; | return nullptr; | ||||
} else { | } else { | ||||
@@ -783,7 +783,7 @@ ConvolutionBackwardDataImpl::get_algo_from_desc( | |||||
ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { | ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { | ||||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||||
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||||
return algo; | return algo; | ||||
} | } | ||||
if (!m_prev_selected_algo || | if (!m_prev_selected_algo || | ||||
@@ -284,7 +284,7 @@ private: | |||||
NCBKernSizeParam m_prev_selected_algo_sizep; | NCBKernSizeParam m_prev_selected_algo_sizep; | ||||
Algorithm* m_prev_selected_algo = nullptr; | Algorithm* m_prev_selected_algo = nullptr; | ||||
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
bool is_naive_algo(ConvolutionImpl::Algorithm* algo); | bool is_naive_algo(ConvolutionImpl::Algorithm* algo); | ||||
Algorithm* get_algorithm( | Algorithm* get_algorithm( | ||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
@@ -493,7 +493,7 @@ private: | |||||
class AlgoDirect; | class AlgoDirect; | ||||
class AlgoMatrixMul; | class AlgoMatrixMul; | ||||
class AlgoPack; | class AlgoPack; | ||||
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
public: | public: | ||||
//! maintain all the algos of in the opr of fallback | //! maintain all the algos of in the opr of fallback | ||||
@@ -96,7 +96,7 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | |||||
return gemv_algos; | return gemv_algos; | ||||
} | } | ||||
MatrixMulImpl::AlgoBase* MatrixMulImpl::get_algo_from_desc( | |||||
MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | const AlgorithmDesc& desc) { | ||||
if (!desc.valid()) { | if (!desc.valid()) { | ||||
return nullptr; | return nullptr; | ||||
@@ -133,7 +133,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, bool reproducible) { | size_t workspace_limit_in_bytes, bool reproducible) { | ||||
auto kern_size_param = make_kern_size_param(A, B, C); | auto kern_size_param = make_kern_size_param(A, B, C); | ||||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||||
if (auto algo = static_cast<AlgoBase*>( | |||||
get_algorithm_from_desc(execution_policy().algo.desc))) { | |||||
megdnn_assert(algo->get_workspace(kern_size_param) < | megdnn_assert(algo->get_workspace(kern_size_param) < | ||||
workspace_limit_in_bytes); | workspace_limit_in_bytes); | ||||
auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo, | auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo, | ||||
@@ -238,7 +238,8 @@ private: | |||||
class AlgoPack; | class AlgoPack; | ||||
//! maintain all the algos of in the opr of fallback | //! maintain all the algos of in the opr of fallback | ||||
static const AlgoPack& algo_pack(); | static const AlgoPack& algo_pack(); | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
public: | public: | ||||
/** | /** | ||||
@@ -138,4 +138,12 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
BatchConvBiasForward::Algorithm* | |||||
BatchConvBiasForwardImpl::get_algorithm_from_desc(const AlgorithmDesc& desc) { | |||||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||||
->default_batch_conv_bias_fwd_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -39,6 +39,8 @@ public: | |||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | bool reproducible) override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
@@ -81,6 +81,15 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||||
->default_batched_matmul_fwd_algo(); | ->default_batched_matmul_fwd_algo(); | ||||
} | } | ||||
BatchedMatrixMulForward::Algorithm* | |||||
BatchedMatrixMulForwardImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||||
->default_batched_matmul_fwd_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
} // namespace naive | } // namespace naive | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -34,6 +34,8 @@ public: | |||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /* reproducible */) override; | bool /* reproducible */) override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
private: | private: | ||||
@@ -256,6 +256,15 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
ConvBiasForward::Algorithm* | |||||
ConvBiasForwardImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = | |||||
static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
const char* ConvBiasForwardImpl::get_algorithm_set_name() const { | const char* ConvBiasForwardImpl::get_algorithm_set_name() const { | ||||
return "DEFAULT"; | return "DEFAULT"; | ||||
} | } | ||||
@@ -64,6 +64,8 @@ public: | |||||
_megdnn_workspace) override {} | _megdnn_workspace) override {} | ||||
const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
}; | }; | ||||
void handle_z_inp_and_activation_naive( | void handle_z_inp_and_activation_naive( | ||||
@@ -285,6 +285,14 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = | |||||
static_cast<HandleImpl*>(handle())->default_conv_fwd_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
std::vector<ConvolutionBackwardData::Algorithm *> | std::vector<ConvolutionBackwardData::Algorithm *> | ||||
ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | ||||
const TensorLayout &, const TensorLayout &) | const TensorLayout &, const TensorLayout &) | ||||
@@ -309,6 +317,15 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
ConvolutionBackwardData::Algorithm* | |||||
ConvolutionBackwardDataImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = | |||||
static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
std::vector<ConvolutionBackwardFilter::Algorithm *> | std::vector<ConvolutionBackwardFilter::Algorithm *> | ||||
ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | ||||
const TensorLayout &, const TensorLayout &) | const TensorLayout &, const TensorLayout &) | ||||
@@ -333,6 +350,15 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
ConvolutionBackwardFilter::Algorithm* | |||||
ConvolutionBackwardFilterImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = | |||||
static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
const char* ConvolutionForwardImpl::get_algorithm_set_name() const { | const char* ConvolutionForwardImpl::get_algorithm_set_name() const { | ||||
return "DEFAULT"; | return "DEFAULT"; | ||||
} | } | ||||
@@ -52,6 +52,8 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
return {}; | return {}; | ||||
} | } | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
}; | }; | ||||
@@ -74,6 +76,8 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||||
const TensorLayout&) override; | const TensorLayout&) override; | ||||
const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
}; | }; | ||||
class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | ||||
@@ -95,6 +99,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||||
const TensorLayout&) override; | const TensorLayout&) override; | ||||
const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
}; | }; | ||||
} // namespace naive | } // namespace naive | ||||
@@ -6,15 +6,15 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "./opr_impl.h" | |||||
#include "./helper.h" | #include "./helper.h" | ||||
#include "./opr_impl.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/common/utils.h" | |||||
#include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
#include <cstring> | #include <cstring> | ||||
@@ -25,93 +25,95 @@ using namespace megdnn; | |||||
using namespace naive; | using namespace naive; | ||||
void Convolution3DForwardImpl::exec(_megdnn_tensor_in src, | void Convolution3DForwardImpl::exec(_megdnn_tensor_in src, | ||||
_megdnn_tensor_in filter, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) | |||||
{ | |||||
_megdnn_tensor_in filter, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
MIDOUT_BEGIN(megdnn_naive_conv3d_fwd) { | MIDOUT_BEGIN(megdnn_naive_conv3d_fwd) { | ||||
auto filter_meta = check_exec( | |||||
src.layout, filter.layout, dst.layout, workspace.size); | |||||
switch (param().data_type) { | |||||
case Param::DataType::FLOAT: | |||||
#define cb(dt) do { \ | |||||
if (src.layout.dtype == dt()) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||||
convolution3d::forward< \ | |||||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||||
src, filter, dst, filter_meta); \ | |||||
); \ | |||||
return; \ | |||||
} \ | |||||
} while(0); | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||||
auto filter_meta = check_exec(src.layout, filter.layout, dst.layout, | |||||
workspace.size); | |||||
switch (param().data_type) { | |||||
case Param::DataType::FLOAT: | |||||
#define cb(dt) \ | |||||
do { \ | |||||
if (src.layout.dtype == dt()) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
static_cast<HandleImpl*>(handle()), \ | |||||
convolution3d::forward< \ | |||||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||||
src, filter, dst, filter_meta);); \ | |||||
return; \ | |||||
} \ | |||||
} while (0); | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||||
#undef cb | #undef cb | ||||
break; | |||||
case Param::DataType::FLOAT_IO16xC32: | |||||
MEGDNN_INC_FLOAT16( | |||||
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), | |||||
convolution3d::forward< | |||||
dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA dt_float32>( | |||||
src, filter, dst, filter_meta);)); | |||||
return; | |||||
break; | |||||
case Param::DataType::FLOAT_IO16xC32: | |||||
MEGDNN_INC_FLOAT16(MEGDNN_DISPATCH_CPU_KERN( | |||||
static_cast<HandleImpl*>(handle()), | |||||
convolution3d::forward< | |||||
dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA | |||||
dt_float32>(src, filter, dst, | |||||
filter_meta);)); | |||||
return; | |||||
} | |||||
megdnn_assert_internal(0); | |||||
} | } | ||||
megdnn_assert_internal(0); | |||||
} MIDOUT_END(); | |||||
MIDOUT_END(); | |||||
} | } | ||||
void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter, | void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter, | ||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) | |||||
{ | |||||
auto filter_meta = check_exec( | |||||
filter.layout, diff.layout, grad.layout, workspace.size); | |||||
#define cb(dt) do { \ | |||||
if (filter.layout.dtype == dt()) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||||
convolution3d::backward_data< \ | |||||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||||
filter, diff, grad, filter_meta);); \ | |||||
return; \ | |||||
} \ | |||||
} while(0); | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) { | |||||
auto filter_meta = | |||||
check_exec(filter.layout, diff.layout, grad.layout, workspace.size); | |||||
#define cb(dt) \ | |||||
do { \ | |||||
if (filter.layout.dtype == dt()) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
static_cast<HandleImpl*>(handle()), \ | |||||
convolution3d::backward_data< \ | |||||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||||
filter, diff, grad, filter_meta);); \ | |||||
return; \ | |||||
} \ | |||||
} while (0); | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | ||||
#undef cb | #undef cb | ||||
megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
} | } | ||||
void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, | void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, | ||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) | |||||
{ | |||||
auto filter_meta = check_exec( | |||||
src.layout, diff.layout, grad.layout, workspace.size); | |||||
#define cb(dt) do { \ | |||||
if (src.layout.dtype == dt()) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||||
convolution3d::backward_filter< \ | |||||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||||
src, diff, grad, filter_meta);); \ | |||||
return; \ | |||||
} \ | |||||
} while(0); | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) { | |||||
auto filter_meta = | |||||
check_exec(src.layout, diff.layout, grad.layout, workspace.size); | |||||
#define cb(dt) \ | |||||
do { \ | |||||
if (src.layout.dtype == dt()) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
MEGDNN_DISPATCH_CPU_KERN( \ | |||||
static_cast<HandleImpl*>(handle()), \ | |||||
convolution3d::backward_filter< \ | |||||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||||
src, diff, grad, filter_meta);); \ | |||||
return; \ | |||||
} \ | |||||
} while (0); | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | ||||
#undef cb | #undef cb | ||||
megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
} | } | ||||
std::vector<Convolution3DForward::Algorithm *> | |||||
Convolution3DForwardImpl:: get_all_algorithms(const TensorLayout &, | |||||
const TensorLayout &, const TensorLayout &) | |||||
{ | |||||
return {static_cast<HandleImpl *>(handle())->default_conv3d_fwd_algo()}; | |||||
std::vector<Convolution3DForward::Algorithm*> | |||||
Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | |||||
} | } | ||||
Convolution3DForward::Algorithm* | Convolution3DForward::Algorithm* | ||||
@@ -130,11 +132,20 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
std::vector<Convolution3DBackwardData::Algorithm *> | |||||
Convolution3DBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | |||||
const TensorLayout &, const TensorLayout &) | |||||
{ | |||||
return {static_cast<HandleImpl *>(handle())->default_conv3d_bwd_data_algo()}; | |||||
Convolution3DForward::Algorithm* | |||||
Convolution3DForwardImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = | |||||
static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
std::vector<Convolution3DBackwardData::Algorithm*> | |||||
Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | |||||
} | } | ||||
Convolution3DBackwardData::Algorithm* | Convolution3DBackwardData::Algorithm* | ||||
@@ -154,11 +165,21 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
std::vector<Convolution3DBackwardFilter::Algorithm *> | |||||
Convolution3DBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | |||||
const TensorLayout &, const TensorLayout &) | |||||
{ | |||||
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_filter_algo()}; | |||||
Convolution3DBackwardData::Algorithm* | |||||
Convolution3DBackwardDataImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = | |||||
static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
std::vector<Convolution3DBackwardFilter::Algorithm*> | |||||
Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle()) | |||||
->default_conv3d_bwd_filter_algo()}; | |||||
} | } | ||||
Convolution3DBackwardFilter::Algorithm* | Convolution3DBackwardFilter::Algorithm* | ||||
@@ -179,6 +200,15 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
Convolution3DBackwardFilter::Algorithm* | |||||
Convolution3DBackwardFilterImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||||
->default_conv3d_bwd_filter_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
const char* Convolution3DForwardImpl::get_algorithm_set_name() const { | const char* Convolution3DForwardImpl::get_algorithm_set_name() const { | ||||
return "DEFAULT"; | return "DEFAULT"; | ||||
} | } | ||||
@@ -6,81 +6,79 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace naive { | namespace naive { | ||||
class Convolution3DForwardImpl: public Convolution3DForward { | |||||
public: | |||||
using Convolution3DForward::Convolution3DForward; | |||||
void exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_in filter, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||||
const TensorLayout &filter, | |||||
const TensorLayout &dst) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& dst, | |||||
size_t workspace_limit_in_bytes, | |||||
bool reproducible) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
const char* get_algorithm_set_name() const override; | |||||
class Convolution3DForwardImpl : public Convolution3DForward { | |||||
public: | |||||
using Convolution3DForward::Convolution3DForward; | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& dst, | |||||
size_t workspace_limit_in_bytes, | |||||
bool reproducible) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override; | |||||
}; | }; | ||||
class Convolution3DBackwardDataImpl: public Convolution3DBackwardData { | |||||
public: | |||||
using Convolution3DBackwardData::Convolution3DBackwardData; | |||||
void exec(_megdnn_tensor_in filter, | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) override; | |||||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, | |||||
bool reproducible) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { | |||||
public: | |||||
using Convolution3DBackwardData::Convolution3DBackwardData; | |||||
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||||
std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& filter, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, | |||||
bool reproducible) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
const char* get_algorithm_set_name() const override; | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override; | |||||
}; | }; | ||||
class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter { | |||||
public: | |||||
using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||||
void exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) override; | |||||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, | |||||
bool reproducible) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
const char* get_algorithm_set_name() const override; | |||||
class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { | |||||
public: | |||||
using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||||
std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& src, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad, | |||||
size_t workspace_limit_in_bytes, | |||||
bool reproducible) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override; | |||||
}; | }; | ||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -48,6 +48,10 @@ public: | |||||
return "DEFORMABLE_CONV2_NAIVE"; | return "DEFORMABLE_CONV2_NAIVE"; | ||||
}; | }; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||||
return {}; | |||||
} | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | ||||
_megdnn_tensor_in offset, _megdnn_tensor_in mask, | _megdnn_tensor_in offset, _megdnn_tensor_in mask, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | ||||
@@ -84,6 +88,10 @@ public: | |||||
return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE"; | return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE"; | ||||
}; | }; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||||
return {}; | |||||
} | |||||
void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, | void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, | ||||
_megdnn_tensor_in mask, _megdnn_tensor_in out_grad, | _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, | ||||
_megdnn_tensor_out filter_grad, | _megdnn_tensor_out filter_grad, | ||||
@@ -130,6 +138,10 @@ public: | |||||
return "DEFORMABLE_CONV2_BWD_DATA_NAIVE"; | return "DEFORMABLE_CONV2_BWD_DATA_NAIVE"; | ||||
}; | }; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||||
return {}; | |||||
} | |||||
void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, | void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, | ||||
_megdnn_tensor_in offset, _megdnn_tensor_in mask, | _megdnn_tensor_in offset, _megdnn_tensor_in mask, | ||||
_megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, | _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, | ||||
@@ -175,6 +175,15 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
LocalShareForward::Algorithm* | |||||
LocalShareForwardImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = | |||||
static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
std::vector<LocalShareBackwardData::Algorithm*> | std::vector<LocalShareBackwardData::Algorithm*> | ||||
LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, | LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, | ||||
const TensorLayout&, | const TensorLayout&, | ||||
@@ -200,6 +209,15 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
LocalShareBackwardData::Algorithm* | |||||
LocalShareBackwardDataImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||||
->default_local_share_bwd_data_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
std::vector<LocalShareBackwardFilter::Algorithm*> | std::vector<LocalShareBackwardFilter::Algorithm*> | ||||
LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | ||||
const TensorLayout&, | const TensorLayout&, | ||||
@@ -225,4 +243,13 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic( | |||||
return algo; | return algo; | ||||
} | } | ||||
LocalShareBackwardFilter::Algorithm* | |||||
LocalShareBackwardFilterImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||||
->default_local_share_bwd_filter_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
@@ -35,6 +36,7 @@ public: | |||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override; | bool /*reproducible*/) override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
}; | }; | ||||
@@ -59,6 +61,7 @@ public: | |||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override; | bool /*reproducible*/) override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
}; | }; | ||||
@@ -83,6 +86,7 @@ public: | |||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override; | bool /*reproducible*/) override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
}; | }; | ||||
@@ -95,6 +95,14 @@ MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||||
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | ||||
} | } | ||||
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_from_desc( | |||||
const AlgorithmDesc& desc) { | |||||
Algorithm* ret = | |||||
static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | |||||
megdnn_assert(desc == ret->info().desc); | |||||
return ret; | |||||
} | |||||
} // namespace naive | } // namespace naive | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -35,6 +35,8 @@ public: | |||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /* reproducible */) override; | bool /* reproducible */) override; | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
private: | private: | ||||
@@ -29,8 +29,8 @@ public: | |||||
class AlgoBlas; | class AlgoBlas; | ||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
private: | private: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
@@ -66,7 +66,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
private: | private: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -112,7 +112,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
private: | private: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -158,7 +158,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
private: | private: | ||||
@@ -29,7 +29,7 @@ public: | |||||
class AlgoPack; | class AlgoPack; | ||||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | static const AlgoPack& algo_pack() { return sm_algo_pack; } | ||||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||||
private: | private: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
@@ -41,6 +41,7 @@ private: | |||||
const TensorLayout& /*C*/, | const TensorLayout& /*C*/, | ||||
size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
bool /*reproducible*/) override; | bool /*reproducible*/) override; | ||||
const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
return "ROCM MATMUL"; | return "ROCM MATMUL"; | ||||
} | } | ||||
@@ -2204,6 +2204,10 @@ public: | |||||
const TensorLayout& p2, | const TensorLayout& p2, | ||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible)); | bool reproducible)); | ||||
MOCK_METHOD1(get_algorithm_from_desc, | |||||
Algorithm*(const AlgorithmDesc&)); | |||||
protected: | protected: | ||||
const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
return m_algorithm_set_name; | return m_algorithm_set_name; | ||||