GitOrigin-RevId: 6d211ca167
tags/v1.3.0
@@ -188,6 +188,7 @@ public: | |||
using AlgorithmInfo = detail::Algorithm::Info; | |||
using AlgorithmDesc = detail::Algorithm::Info::Desc; | |||
using Algorithm = detail::Algorithm; | |||
/*! | |||
* \brief get a string representation for current algorithm set; | |||
* | |||
@@ -209,6 +210,8 @@ public: | |||
return m_execution_policy; | |||
} | |||
virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
@@ -38,11 +38,12 @@ namespace megdnn { | |||
return algo_pack().all_algos_map().at(desc); \ | |||
} | |||
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||
_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ | |||
megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||
algo_pack().all_algos_map().end()); \ | |||
return algo_pack().all_algos_map().at(desc); \ | |||
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||
_opr::Algorithm* _opr::get_algorithm_from_desc( \ | |||
const AlgorithmDesc& desc) { \ | |||
megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||
algo_pack().all_algos_map().end()); \ | |||
return algo_pack().all_algos_map().at(desc); \ | |||
} | |||
/** | |||
@@ -34,7 +34,8 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | |||
false); | |||
} | |||
return opr->get_algo_from_desc(ret.desc); | |||
return static_cast<typename Opr::AlgoBase*>( | |||
opr->get_algorithm_from_desc(ret.desc)); | |||
} | |||
/*! | |||
@@ -43,7 +44,6 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
*/ | |||
template <class Opr, typename... Args> | |||
typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { | |||
typename Opr::AlgorithmInfo ret; | |||
auto set = opr->execution_policy().algo; | |||
if (set.valid()) { | |||
return opr->algo_pack().construct_and_get_algo(set.desc); | |||
@@ -35,7 +35,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -39,7 +39,7 @@ public: | |||
bool is_thread_safe() const override { return true; } | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
@@ -69,7 +69,7 @@ public: | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
@@ -86,6 +86,28 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||
workspace_limit_in_bytes, reproducible); | |||
} | |||
ConvolutionForwardImpl::Algorithm* | |||
ConvolutionForwardImpl::get_algorithm_from_desc( | |||
const ConvolutionForward::AlgorithmDesc& desc) { | |||
auto conv_param = param(); | |||
auto convbias_opr = this->handle()->create_operator<ConvBiasForward>(); | |||
convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY, | |||
conv_param.mode, | |||
conv_param.sparse, | |||
conv_param.format, | |||
conv_param.pad_h, | |||
conv_param.pad_w, | |||
conv_param.stride_h, | |||
conv_param.stride_w, | |||
conv_param.dilate_h, | |||
conv_param.dilate_w, | |||
conv_param.compute_mode}; | |||
convbias_opr->execution_policy() = {this->execution_policy().algo}; | |||
return static_cast<ConvBiasForwardImpl*>(convbias_opr.get()) | |||
->get_algorithm_from_desc(desc); | |||
} | |||
std::vector<ConvolutionForwardImpl::Algorithm*> | |||
ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
@@ -46,6 +46,8 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
megdnn_throw("cuda exec_preprocess has not implemeted yet"); | |||
} | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
struct ConvBiasExtraData{ | |||
std::unique_ptr<ConvBiasForward> convbias_opr; | |||
@@ -98,7 +100,7 @@ public: | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -152,7 +154,7 @@ public: | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -42,7 +42,7 @@ public: | |||
class AlgoGroupConvGeneral; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -92,7 +92,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -143,7 +143,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -46,7 +46,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -97,7 +97,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -151,7 +151,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -33,7 +33,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -65,7 +65,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -98,7 +98,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -46,7 +46,7 @@ public: | |||
static const AlgoPack& algo_pack() { | |||
return sm_algo_pack; | |||
} | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
protected: | |||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
@@ -29,8 +29,7 @@ public: | |||
class AlgoDefault; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
private: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -454,8 +454,8 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||
return algos; | |||
} | |||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||
const AlgorithmDesc& desc) const { | |||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
if (!desc.valid()) { | |||
return nullptr; | |||
} else { | |||
@@ -495,7 +495,7 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | |||
const NCBKernSizeParam& param, size_t workspace_size) { | |||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||
return algo; | |||
} | |||
if (!m_prev_selected_algo || | |||
@@ -381,7 +381,7 @@ private: | |||
bool is_naive_algo(ConvBiasImpl::Algorithm* algo); | |||
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
//! get algorithm set by user or by heuristic | |||
Algorithm* get_algorithm( | |||
@@ -361,8 +361,8 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) { | |||
return ret; | |||
} | |||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( | |||
const AlgorithmDesc& desc) const { | |||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
if (!desc.valid()) { | |||
return nullptr; | |||
} else { | |||
@@ -387,7 +387,7 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( | |||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( | |||
const NCBKernSizeParam& param, size_t workspace_size) { | |||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||
return algo; | |||
} | |||
if (!m_prev_selected_algo || | |||
@@ -749,8 +749,8 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( | |||
} | |||
ConvolutionBackwardDataImpl::Algorithm* | |||
ConvolutionBackwardDataImpl::get_algo_from_desc( | |||
const AlgorithmDesc& desc) const { | |||
ConvolutionBackwardDataImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
if (!desc.valid()) { | |||
return nullptr; | |||
} else { | |||
@@ -783,7 +783,7 @@ ConvolutionBackwardDataImpl::get_algo_from_desc( | |||
ConvolutionBackwardDataImpl::Algorithm* | |||
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { | |||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||
return algo; | |||
} | |||
if (!m_prev_selected_algo || | |||
@@ -284,7 +284,7 @@ private: | |||
NCBKernSizeParam m_prev_selected_algo_sizep; | |||
Algorithm* m_prev_selected_algo = nullptr; | |||
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
bool is_naive_algo(ConvolutionImpl::Algorithm* algo); | |||
Algorithm* get_algorithm( | |||
const NCBKernSizeParam& param, | |||
@@ -493,7 +493,7 @@ private: | |||
class AlgoDirect; | |||
class AlgoMatrixMul; | |||
class AlgoPack; | |||
Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
public: | |||
//! maintain all the algos of in the opr of fallback | |||
@@ -96,7 +96,7 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | |||
return gemv_algos; | |||
} | |||
MatrixMulImpl::AlgoBase* MatrixMulImpl::get_algo_from_desc( | |||
MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
if (!desc.valid()) { | |||
return nullptr; | |||
@@ -133,7 +133,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
size_t workspace_limit_in_bytes, bool reproducible) { | |||
auto kern_size_param = make_kern_size_param(A, B, C); | |||
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
if (auto algo = static_cast<AlgoBase*>( | |||
get_algorithm_from_desc(execution_policy().algo.desc))) { | |||
megdnn_assert(algo->get_workspace(kern_size_param) < | |||
workspace_limit_in_bytes); | |||
auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo, | |||
@@ -238,7 +238,8 @@ private: | |||
class AlgoPack; | |||
//! maintain all the algos of in the opr of fallback | |||
static const AlgoPack& algo_pack(); | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
public: | |||
/** | |||
@@ -138,4 +138,12 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
BatchConvBiasForward::Algorithm* | |||
BatchConvBiasForwardImpl::get_algorithm_from_desc(const AlgorithmDesc& desc) { | |||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
->default_batch_conv_bias_fwd_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -39,6 +39,8 @@ public: | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
private: | |||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
@@ -81,6 +81,15 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
->default_batched_matmul_fwd_algo(); | |||
} | |||
BatchedMatrixMulForward::Algorithm* | |||
BatchedMatrixMulForwardImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
->default_batched_matmul_fwd_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
@@ -34,6 +34,8 @@ public: | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /* reproducible */) override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
private: | |||
@@ -256,6 +256,15 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
ConvBiasForward::Algorithm* | |||
ConvBiasForwardImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = | |||
static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
const char* ConvBiasForwardImpl::get_algorithm_set_name() const { | |||
return "DEFAULT"; | |||
} | |||
@@ -64,6 +64,8 @@ public: | |||
_megdnn_workspace) override {} | |||
const char* get_algorithm_set_name() const override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
}; | |||
void handle_z_inp_and_activation_naive( | |||
@@ -285,6 +285,14 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = | |||
static_cast<HandleImpl*>(handle())->default_conv_fwd_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
std::vector<ConvolutionBackwardData::Algorithm *> | |||
ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | |||
const TensorLayout &, const TensorLayout &) | |||
@@ -309,6 +317,15 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
ConvolutionBackwardData::Algorithm* | |||
ConvolutionBackwardDataImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = | |||
static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
std::vector<ConvolutionBackwardFilter::Algorithm *> | |||
ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | |||
const TensorLayout &, const TensorLayout &) | |||
@@ -333,6 +350,15 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
ConvolutionBackwardFilter::Algorithm* | |||
ConvolutionBackwardFilterImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = | |||
static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
const char* ConvolutionForwardImpl::get_algorithm_set_name() const { | |||
return "DEFAULT"; | |||
} | |||
@@ -52,6 +52,8 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
return {}; | |||
} | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override; | |||
}; | |||
@@ -74,6 +76,8 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
const TensorLayout&) override; | |||
const char* get_algorithm_set_name() const override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
}; | |||
class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
@@ -95,6 +99,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
const TensorLayout&) override; | |||
const char* get_algorithm_set_name() const override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
}; | |||
} // namespace naive | |||
@@ -6,15 +6,15 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "./opr_impl.h" | |||
#include "./helper.h" | |||
#include "./opr_impl.h" | |||
#include "src/naive/handle.h" | |||
#include "src/naive/handle.h" | |||
#include "src/common/utils.h" | |||
#include "megdnn/dtype.h" | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
#include <cstring> | |||
@@ -25,93 +25,95 @@ using namespace megdnn; | |||
using namespace naive; | |||
void Convolution3DForwardImpl::exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in filter, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) | |||
{ | |||
_megdnn_tensor_in filter, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
MIDOUT_BEGIN(megdnn_naive_conv3d_fwd) { | |||
auto filter_meta = check_exec( | |||
src.layout, filter.layout, dst.layout, workspace.size); | |||
switch (param().data_type) { | |||
case Param::DataType::FLOAT: | |||
#define cb(dt) do { \ | |||
if (src.layout.dtype == dt()) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||
convolution3d::forward< \ | |||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
src, filter, dst, filter_meta); \ | |||
); \ | |||
return; \ | |||
} \ | |||
} while(0); | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||
auto filter_meta = check_exec(src.layout, filter.layout, dst.layout, | |||
workspace.size); | |||
switch (param().data_type) { | |||
case Param::DataType::FLOAT: | |||
#define cb(dt) \ | |||
do { \ | |||
if (src.layout.dtype == dt()) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<HandleImpl*>(handle()), \ | |||
convolution3d::forward< \ | |||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
src, filter, dst, filter_meta);); \ | |||
return; \ | |||
} \ | |||
} while (0); | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||
#undef cb | |||
break; | |||
case Param::DataType::FLOAT_IO16xC32: | |||
MEGDNN_INC_FLOAT16( | |||
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), | |||
convolution3d::forward< | |||
dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA dt_float32>( | |||
src, filter, dst, filter_meta);)); | |||
return; | |||
break; | |||
case Param::DataType::FLOAT_IO16xC32: | |||
MEGDNN_INC_FLOAT16(MEGDNN_DISPATCH_CPU_KERN( | |||
static_cast<HandleImpl*>(handle()), | |||
convolution3d::forward< | |||
dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA | |||
dt_float32>(src, filter, dst, | |||
filter_meta);)); | |||
return; | |||
} | |||
megdnn_assert_internal(0); | |||
} | |||
megdnn_assert_internal(0); | |||
} MIDOUT_END(); | |||
MIDOUT_END(); | |||
} | |||
void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) | |||
{ | |||
auto filter_meta = check_exec( | |||
filter.layout, diff.layout, grad.layout, workspace.size); | |||
#define cb(dt) do { \ | |||
if (filter.layout.dtype == dt()) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||
convolution3d::backward_data< \ | |||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
filter, diff, grad, filter_meta);); \ | |||
return; \ | |||
} \ | |||
} while(0); | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) { | |||
auto filter_meta = | |||
check_exec(filter.layout, diff.layout, grad.layout, workspace.size); | |||
#define cb(dt) \ | |||
do { \ | |||
if (filter.layout.dtype == dt()) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<HandleImpl*>(handle()), \ | |||
convolution3d::backward_data< \ | |||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
filter, diff, grad, filter_meta);); \ | |||
return; \ | |||
} \ | |||
} while (0); | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||
#undef cb | |||
megdnn_assert_internal(0); | |||
} | |||
void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) | |||
{ | |||
auto filter_meta = check_exec( | |||
src.layout, diff.layout, grad.layout, workspace.size); | |||
#define cb(dt) do { \ | |||
if (src.layout.dtype == dt()) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||
convolution3d::backward_filter< \ | |||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
src, diff, grad, filter_meta);); \ | |||
return; \ | |||
} \ | |||
} while(0); | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) { | |||
auto filter_meta = | |||
check_exec(src.layout, diff.layout, grad.layout, workspace.size); | |||
#define cb(dt) \ | |||
do { \ | |||
if (src.layout.dtype == dt()) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN( \ | |||
static_cast<HandleImpl*>(handle()), \ | |||
convolution3d::backward_filter< \ | |||
ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
src, diff, grad, filter_meta);); \ | |||
return; \ | |||
} \ | |||
} while (0); | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||
#undef cb | |||
megdnn_assert_internal(0); | |||
} | |||
std::vector<Convolution3DForward::Algorithm *> | |||
Convolution3DForwardImpl:: get_all_algorithms(const TensorLayout &, | |||
const TensorLayout &, const TensorLayout &) | |||
{ | |||
return {static_cast<HandleImpl *>(handle())->default_conv3d_fwd_algo()}; | |||
std::vector<Convolution3DForward::Algorithm*> | |||
Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) { | |||
return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | |||
} | |||
Convolution3DForward::Algorithm* | |||
@@ -130,11 +132,20 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
std::vector<Convolution3DBackwardData::Algorithm *> | |||
Convolution3DBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | |||
const TensorLayout &, const TensorLayout &) | |||
{ | |||
return {static_cast<HandleImpl *>(handle())->default_conv3d_bwd_data_algo()}; | |||
Convolution3DForward::Algorithm* | |||
Convolution3DForwardImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = | |||
static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
std::vector<Convolution3DBackwardData::Algorithm*> | |||
Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) { | |||
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | |||
} | |||
Convolution3DBackwardData::Algorithm* | |||
@@ -154,11 +165,21 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
std::vector<Convolution3DBackwardFilter::Algorithm *> | |||
Convolution3DBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | |||
const TensorLayout &, const TensorLayout &) | |||
{ | |||
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_filter_algo()}; | |||
Convolution3DBackwardData::Algorithm* | |||
Convolution3DBackwardDataImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = | |||
static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
std::vector<Convolution3DBackwardFilter::Algorithm*> | |||
Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) { | |||
return {static_cast<HandleImpl*>(handle()) | |||
->default_conv3d_bwd_filter_algo()}; | |||
} | |||
Convolution3DBackwardFilter::Algorithm* | |||
@@ -179,6 +200,15 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
Convolution3DBackwardFilter::Algorithm* | |||
Convolution3DBackwardFilterImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
->default_conv3d_bwd_filter_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
const char* Convolution3DForwardImpl::get_algorithm_set_name() const { | |||
return "DEFAULT"; | |||
} | |||
@@ -6,81 +6,79 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class Convolution3DForwardImpl: public Convolution3DForward { | |||
public: | |||
using Convolution3DForward::Convolution3DForward; | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in filter, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
const TensorLayout &filter, | |||
const TensorLayout &dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
const char* get_algorithm_set_name() const override; | |||
class Convolution3DForwardImpl : public Convolution3DForward { | |||
public: | |||
using Convolution3DForward::Convolution3DForward; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& dst) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& dst, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override; | |||
}; | |||
class Convolution3DBackwardDataImpl: public Convolution3DBackwardData { | |||
public: | |||
using Convolution3DBackwardData::Convolution3DBackwardData; | |||
void exec(_megdnn_tensor_in filter, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { | |||
public: | |||
using Convolution3DBackwardData::Convolution3DBackwardData; | |||
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& filter, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
const char* get_algorithm_set_name() const override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override; | |||
}; | |||
class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter { | |||
public: | |||
using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) override; | |||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
const char* get_algorithm_set_name() const override; | |||
class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { | |||
public: | |||
using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& src, const TensorLayout& diff, | |||
const TensorLayout& grad) override; | |||
Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override; | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -48,6 +48,10 @@ public: | |||
return "DEFORMABLE_CONV2_NAIVE"; | |||
}; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||
return {}; | |||
} | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
_megdnn_tensor_in offset, _megdnn_tensor_in mask, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
@@ -84,6 +88,10 @@ public: | |||
return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE"; | |||
}; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||
return {}; | |||
} | |||
void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, | |||
_megdnn_tensor_in mask, _megdnn_tensor_in out_grad, | |||
_megdnn_tensor_out filter_grad, | |||
@@ -130,6 +138,10 @@ public: | |||
return "DEFORMABLE_CONV2_BWD_DATA_NAIVE"; | |||
}; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||
return {}; | |||
} | |||
void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, | |||
_megdnn_tensor_in offset, _megdnn_tensor_in mask, | |||
_megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, | |||
@@ -175,6 +175,15 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
LocalShareForward::Algorithm* | |||
LocalShareForwardImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = | |||
static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
std::vector<LocalShareBackwardData::Algorithm*> | |||
LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||
const TensorLayout&, | |||
@@ -200,6 +209,15 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
LocalShareBackwardData::Algorithm* | |||
LocalShareBackwardDataImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
->default_local_share_bwd_data_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
std::vector<LocalShareBackwardFilter::Algorithm*> | |||
LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||
const TensorLayout&, | |||
@@ -225,4 +243,13 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic( | |||
return algo; | |||
} | |||
LocalShareBackwardFilter::Algorithm* | |||
LocalShareBackwardFilterImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
->default_local_share_bwd_filter_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
@@ -35,6 +36,7 @@ public: | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /*reproducible*/) override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
}; | |||
@@ -59,6 +61,7 @@ public: | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /*reproducible*/) override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
}; | |||
@@ -83,6 +86,7 @@ public: | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /*reproducible*/) override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
}; | |||
@@ -95,6 +95,14 @@ MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | |||
} | |||
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_from_desc( | |||
const AlgorithmDesc& desc) { | |||
Algorithm* ret = | |||
static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | |||
megdnn_assert(desc == ret->info().desc); | |||
return ret; | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
@@ -35,6 +35,8 @@ public: | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /* reproducible */) override; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
private: | |||
@@ -29,8 +29,8 @@ public: | |||
class AlgoBlas; | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
private: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
@@ -66,7 +66,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
private: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -112,7 +112,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
private: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -158,7 +158,7 @@ public: | |||
class AlgoPack; | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
private: | |||
@@ -29,7 +29,7 @@ public: | |||
class AlgoPack; | |||
static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
private: | |||
std::vector<Algorithm*> get_all_algorithms( | |||
@@ -41,6 +41,7 @@ private: | |||
const TensorLayout& /*C*/, | |||
size_t /*workspace_limit_in_bytes*/, | |||
bool /*reproducible*/) override; | |||
const char* get_algorithm_set_name() const override { | |||
return "ROCM MATMUL"; | |||
} | |||
@@ -2204,6 +2204,10 @@ public: | |||
const TensorLayout& p2, | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible)); | |||
MOCK_METHOD1(get_algorithm_from_desc, | |||
Algorithm*(const AlgorithmDesc&)); | |||
protected: | |||
const char* get_algorithm_set_name() const override { | |||
return m_algorithm_set_name; | |||