GitOrigin-RevId: e3734e4531
release-1.6
@@ -315,7 +315,7 @@ public: | |||||
/*! | /*! | ||||
* \brief get a string representation for current algorithm set; | * \brief get a string representation for current algorithm set; | ||||
* | * | ||||
* get_all_algorithms() may return different algorithms only if | |||||
* get_all_algorithms_safe() may return different algorithms only if | |||||
* algorithm set name differs. This is used for checking cache | * algorithm set name differs. This is used for checking cache | ||||
* validity. | * validity. | ||||
*/ | */ | ||||
@@ -354,6 +354,15 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
const TensorLayout& p1) { | |||||
std::vector<AlgorithmInfo> ret; | |||||
for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | |||||
ret.emplace_back(algo->info()); | |||||
} | |||||
return ret; | |||||
} | |||||
/** | /** | ||||
* \brief Returns the best algorithm information which indicate the | * \brief Returns the best algorithm information which indicate the | ||||
* algorithm by heuristic. | * algorithm by heuristic. | ||||
@@ -378,6 +387,8 @@ protected: | |||||
//! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& p0, const TensorLayout& p1) = 0; | const TensorLayout& p0, const TensorLayout& p1) = 0; | ||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1) = 0; | |||||
/** | /** | ||||
* \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
@@ -412,6 +423,16 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2) { | |||||
std::vector<AlgorithmInfo> ret; | |||||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | |||||
ret.emplace_back(algo->info()); | |||||
} | |||||
return ret; | |||||
} | |||||
/** | /** | ||||
* \brief Returns the best algorithm information which indicate the | * \brief Returns the best algorithm information which indicate the | ||||
* algorithm by heuristic. | * algorithm by heuristic. | ||||
@@ -438,6 +459,9 @@ protected: | |||||
virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
const TensorLayout& p2) = 0; | const TensorLayout& p2) = 0; | ||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2) = 0; | |||||
/** | /** | ||||
* \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
@@ -463,7 +487,7 @@ public: | |||||
using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
//! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
const TensorLayout& p1, | const TensorLayout& p1, | ||||
const TensorLayout& p2, | const TensorLayout& p2, | ||||
const TensorLayout& p3) { | const TensorLayout& p3) { | ||||
@@ -474,6 +498,17 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2, | |||||
const TensorLayout& p3) { | |||||
std::vector<AlgorithmInfo> ret; | |||||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | |||||
ret.emplace_back(algo->info()); | |||||
} | |||||
return ret; | |||||
} | |||||
/** | /** | ||||
* \brief Returns the best algorithm information which indicate the | * \brief Returns the best algorithm information which indicate the | ||||
* algorithm by heuristic. | * algorithm by heuristic. | ||||
@@ -500,6 +535,9 @@ protected: | |||||
virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
const TensorLayout& p2, const TensorLayout& p3) = 0; | const TensorLayout& p2, const TensorLayout& p3) = 0; | ||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3) = 0; | |||||
/** | /** | ||||
* \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
@@ -537,6 +575,18 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2, | |||||
const TensorLayout& p3, | |||||
const TensorLayout& p4) { | |||||
std::vector<AlgorithmInfo> ret; | |||||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | |||||
ret.emplace_back(algo->info()); | |||||
} | |||||
return ret; | |||||
} | |||||
/** | /** | ||||
* \brief Returns the best algorithm information which indicate the | * \brief Returns the best algorithm information which indicate the | ||||
* algorithm by heuristic. | * algorithm by heuristic. | ||||
@@ -562,7 +612,11 @@ protected: | |||||
~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
//! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
virtual std::vector<Algorithm*> get_all_algorithms( | |||||
virtual std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4) = 0; | |||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
const TensorLayout& p2, const TensorLayout& p3, | const TensorLayout& p2, const TensorLayout& p3, | ||||
const TensorLayout& p4) = 0; | const TensorLayout& p4) = 0; | ||||
@@ -604,6 +658,18 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p6, const TensorLayout& p7) { | |||||
std::vector<AlgorithmInfo> ret; | |||||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||||
ret.emplace_back(algo->info()); | |||||
} | |||||
return ret; | |||||
} | |||||
/** | /** | ||||
* \brief Returns the best algorithm information which indicate the | * \brief Returns the best algorithm information which indicate the | ||||
* algorithm by heuristic. | * algorithm by heuristic. | ||||
@@ -629,7 +695,12 @@ protected: | |||||
~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
//! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
virtual std::vector<Algorithm*> get_all_algorithms( | |||||
virtual std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p6, const TensorLayout& p7) = 0; | |||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
const TensorLayout& p2, const TensorLayout& p3, | const TensorLayout& p2, const TensorLayout& p3, | ||||
const TensorLayout& p4, const TensorLayout& p5, | const TensorLayout& p4, const TensorLayout& p5, | ||||
@@ -172,9 +172,14 @@ std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | |||||
ret.push_back(i); | ret.push_back(i); | ||||
} | } | ||||
} | } | ||||
megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm"); | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst) { | |||||
auto ret_safe = get_all_algorithms(src,dst); | |||||
megdnn_assert(!ret_safe.empty(), "no usable pooling fwd algorithm"); | |||||
return ret_safe; | |||||
} | |||||
Algorithm* PoolingImpl::get_algorithm_heuristic( | Algorithm* PoolingImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
@@ -131,6 +131,8 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
@@ -100,10 +100,16 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms( | |||||
ret.push_back(i); | ret.push_back(i); | ||||
} | } | ||||
} | } | ||||
megdnn_assert(!ret.empty(), "no algorithm for %s", | |||||
args.to_string().c_str()); | |||||
return ret; | return ret; | ||||
} | } | ||||
template <class Opr> | |||||
std::vector<typename Opr::Algorithm*> get_all_algorithms_safe( | |||||
const typename Opr::AlgoBase::SizeArgs& args) { | |||||
auto ret_safe = get_all_algorithms<Opr>(args); | |||||
megdnn_assert(!ret_safe.empty(), "no algorithm for %s", | |||||
args.to_string().c_str()); | |||||
return ret_safe; | |||||
} | |||||
/*! | /*! | ||||
* \brief a helper function to get an algorithm match attribute. If require a | * \brief a helper function to get an algorithm match attribute. If require a | ||||
@@ -51,6 +51,15 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | ||||
return megdnn::get_all_algorithms<BatchConvBiasForwardImpl>(args); | return megdnn::get_all_algorithms<BatchConvBiasForwardImpl>(args); | ||||
} | } | ||||
std::vector<BatchConvBiasForwardImpl::Algorithm*> | |||||
BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& bias, | |||||
const TensorLayout& z, | |||||
const TensorLayout& dst) { | |||||
AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | |||||
return megdnn::get_all_algorithms_safe<BatchConvBiasForwardImpl>(args); | |||||
} | |||||
size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( | size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
@@ -42,6 +42,10 @@ protected: | |||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& bias, const TensorLayout& z, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
@@ -51,6 +51,12 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms( | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms_safe( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||||
auto ret_safe = get_all_algorithms(A,B,C); | |||||
megdnn_assert(!ret_safe.empty(), "no usable batchedmatrixmulForward fwd algorithm"); | |||||
return ret_safe; | |||||
} | |||||
Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
@@ -45,6 +45,9 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
const TensorLayout& B, | const TensorLayout& B, | ||||
const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
@@ -49,6 +49,16 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
{this, src, filter, bias, z, dst}); | {this, src, filter, bias, z, dst}); | ||||
} | } | ||||
std::vector<ConvBiasForward::Algorithm*> | |||||
ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& bias, | |||||
const TensorLayout& z, | |||||
const TensorLayout& dst) { | |||||
return megdnn::get_all_algorithms_safe<ConvBiasForwardImpl>( | |||||
{this, src, filter, bias, z, dst}); | |||||
} | |||||
ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
@@ -84,6 +84,10 @@ public: | |||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& bias, const TensorLayout& z, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
@@ -53,6 +53,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
return megdnn::get_all_algorithms<ConvolutionForwardImpl>(args); | return megdnn::get_all_algorithms<ConvolutionForwardImpl>(args); | ||||
} | } | ||||
std::vector<ConvolutionForwardImpl::Algorithm*> | |||||
ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& dst) { | |||||
AlgoBase::SizeArgs args{this, src, filter, dst}; | |||||
return megdnn::get_all_algorithms_safe<ConvolutionForwardImpl>(args); | |||||
} | |||||
size_t ConvolutionForwardImpl::get_workspace_in_bytes( | size_t ConvolutionForwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, | const TensorLayout& dst, | ||||
@@ -97,6 +105,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
{this, filter, diff, grad}); | {this, filter, diff, grad}); | ||||
} | } | ||||
std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||||
ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms_safe<ConvolutionBackwardDataImpl>( | |||||
{this, filter, diff, grad}); | |||||
} | |||||
ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
@@ -222,6 +238,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||||
{this, src, diff, grad}); | {this, src, diff, grad}); | ||||
} | } | ||||
std::vector<ConvolutionBackwardFilterImpl::Algorithm*> | |||||
ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms_safe<ConvolutionBackwardFilterImpl>( | |||||
{this, src, diff, grad}); | |||||
} | |||||
ConvolutionBackwardFilterImpl::Algorithm* | ConvolutionBackwardFilterImpl::Algorithm* | ||||
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
@@ -59,6 +59,10 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
@@ -111,6 +115,10 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& filter, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -159,6 +167,10 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -108,6 +108,14 @@ Convolution3DForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
{this, src, filter, dst}); | {this, src, filter, dst}); | ||||
} | } | ||||
std::vector<Convolution3DForwardImpl::Algorithm*> | |||||
Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& dst) { | |||||
return megdnn::get_all_algorithms_safe<Convolution3DForwardImpl>( | |||||
{this, src, filter, dst}); | |||||
} | |||||
size_t Convolution3DForwardImpl::get_workspace_in_bytes( | size_t Convolution3DForwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
@@ -146,6 +154,14 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
{this, filter, diff, grad}); | {this, filter, diff, grad}); | ||||
} | } | ||||
std::vector<Convolution3DBackwardDataImpl::Algorithm*> | |||||
Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms_safe<Convolution3DBackwardDataImpl>( | |||||
{this, filter, diff, grad}); | |||||
} | |||||
Convolution3DBackwardDataImpl::Algorithm* | Convolution3DBackwardDataImpl::Algorithm* | ||||
Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
@@ -226,6 +242,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||||
{this, src, diff, grad}); | {this, src, diff, grad}); | ||||
} | } | ||||
std::vector<Convolution3DBackwardFilterImpl::Algorithm*> | |||||
Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms_safe<Convolution3DBackwardFilterImpl>( | |||||
{this, src, diff, grad}); | |||||
} | |||||
Convolution3DBackwardFilterImpl::Algorithm* | Convolution3DBackwardFilterImpl::Algorithm* | ||||
Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
@@ -39,6 +39,9 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
@@ -72,6 +75,9 @@ public: | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& dst) override; | |||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
@@ -109,6 +115,9 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -51,6 +51,15 @@ std::vector<AlgoFwd*> Fwd::get_all_algorithms(const TensorLayout& /* im */, | |||||
return algos; | return algos; | ||||
} | } | ||||
std::vector<AlgoFwd*> Fwd::get_all_algorithms_safe(const TensorLayout& im, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& offset, | |||||
const TensorLayout& mask, | |||||
const TensorLayout& dst) { | |||||
auto ret_safe = Fwd::get_all_algorithms(im,filter,offset,mask,dst); | |||||
megdnn_assert(!ret_safe.empty(), "no usable deformable_conv fwd algorithm"); | |||||
return ret_safe; | |||||
} | |||||
AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | ||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
@@ -115,6 +124,14 @@ std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms(const TensorLayout& /* im */ | |||||
return algos; | return algos; | ||||
} | } | ||||
std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms_safe(const TensorLayout& im, | |||||
const TensorLayout& offset, const TensorLayout& mask, | |||||
const TensorLayout& out_grad, const TensorLayout& filter_grad) { | |||||
auto ret_safe = BwdFlt::get_all_algorithms(im,offset,mask,out_grad,filter_grad); | |||||
megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd filter algorithm"); | |||||
return ret_safe; | |||||
} | |||||
AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | ||||
const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
@@ -181,6 +198,14 @@ std::vector<AlgoBwdData*> BwdData::get_all_algorithms( | |||||
algos.push_back(static_cast<AlgoBwdData*>(i)); | algos.push_back(static_cast<AlgoBwdData*>(i)); | ||||
return algos; | return algos; | ||||
} | } | ||||
std::vector<AlgoBwdData*> BwdData::get_all_algorithms_safe( | |||||
const TensorLayout& im, const TensorLayout& filter, | |||||
const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, | |||||
const TensorLayout& im_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad ) { | |||||
auto ret_safe = BwdData::get_all_algorithms(im,filter,offset,mask,out_grad,im_grad,offset_grad,mask_grad); | |||||
megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd data algorithm"); | |||||
return ret_safe; | |||||
} | |||||
AlgoBwdData* BwdData::get_algorithm_heuristic( | AlgoBwdData* BwdData::get_algorithm_heuristic( | ||||
const TensorLayout& im, const TensorLayout& filter, | const TensorLayout& im, const TensorLayout& filter, | ||||
@@ -54,6 +54,10 @@ protected: | |||||
const TensorLayout& im, const TensorLayout& filter, | const TensorLayout& im, const TensorLayout& filter, | ||||
const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& im, const TensorLayout& filter, | |||||
const TensorLayout& offset, const TensorLayout& mask, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& im, const TensorLayout& filter, | const TensorLayout& im, const TensorLayout& filter, | ||||
@@ -105,6 +109,10 @@ protected: | |||||
const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
const TensorLayout& filter_grad) override; | const TensorLayout& filter_grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& im, const TensorLayout& offset, | |||||
const TensorLayout& mask, const TensorLayout& out_grad, | |||||
const TensorLayout& filter_grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
@@ -161,6 +169,13 @@ protected: | |||||
const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
const TensorLayout& offset_grad, | const TensorLayout& offset_grad, | ||||
const TensorLayout& mask_grad) override; | const TensorLayout& mask_grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& im, const TensorLayout& filter, | |||||
const TensorLayout& offset, const TensorLayout& mask, | |||||
const TensorLayout& out_grad, const TensorLayout& im_grad, | |||||
const TensorLayout& offset_grad, | |||||
const TensorLayout& mask_grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& im, const TensorLayout& filter, | const TensorLayout& im, const TensorLayout& filter, | ||||
@@ -47,7 +47,6 @@ LocalShareForwardImpl::get_algorithm_heuristic( | |||||
Algorithm::attribute_str(positive_attr).c_str(), | Algorithm::attribute_str(positive_attr).c_str(), | ||||
args.to_string().c_str(), workspace_limit_in_bytes)); | args.to_string().c_str(), workspace_limit_in_bytes)); | ||||
} | } | ||||
std::vector<LocalShareForwardImpl::Algorithm*> | std::vector<LocalShareForwardImpl::Algorithm*> | ||||
LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, | LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, | ||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
@@ -56,6 +55,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
return megdnn::get_all_algorithms<LocalShareForwardImpl>(args); | return megdnn::get_all_algorithms<LocalShareForwardImpl>(args); | ||||
} | } | ||||
std::vector<LocalShareForwardImpl::Algorithm*> | |||||
LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& dst) { | |||||
AlgoBase::SizeArgs args{this, src, filter, dst}; | |||||
return megdnn::get_all_algorithms_safe<LocalShareForwardImpl>(args); | |||||
} | |||||
size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | ||||
const TensorLayout& filter, | const TensorLayout& filter, | ||||
const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
@@ -109,6 +116,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
return megdnn::get_all_algorithms<LocalShareBackwardDataImpl>(args); | return megdnn::get_all_algorithms<LocalShareBackwardDataImpl>(args); | ||||
} | } | ||||
std::vector<LocalShareBackwardDataImpl::Algorithm*> | |||||
LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
AlgoBase::SizeArgs args{this, filter, diff, grad}; | |||||
return megdnn::get_all_algorithms_safe<LocalShareBackwardDataImpl>(args); | |||||
} | |||||
size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter, | size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad) { | const TensorLayout& grad) { | ||||
@@ -162,6 +177,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||||
return megdnn::get_all_algorithms<LocalShareBackwardFilterImpl>(args); | return megdnn::get_all_algorithms<LocalShareBackwardFilterImpl>(args); | ||||
} | } | ||||
std::vector<LocalShareBackwardFilterImpl::Algorithm*> | |||||
LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
AlgoBase::SizeArgs args{this, src, diff, grad}; | |||||
return megdnn::get_all_algorithms_safe<LocalShareBackwardFilterImpl>(args); | |||||
} | |||||
size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src, | size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src, | ||||
const TensorLayout& diff, | const TensorLayout& diff, | ||||
const TensorLayout& grad) { | const TensorLayout& grad) { | ||||
@@ -37,6 +37,9 @@ public: | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& filter, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
@@ -72,6 +75,9 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& filter, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -105,6 +111,9 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -28,6 +28,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | ||||
} | } | ||||
std::vector<MatrixMulForwardImpl::Algorithm*> | |||||
MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) { | |||||
AlgoBase::SizeArgs args{this, A, B, C}; | |||||
return megdnn::get_all_algorithms_safe<MatrixMulForwardImpl>(args); | |||||
} | |||||
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
@@ -60,6 +60,10 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
const TensorLayout& B, | const TensorLayout& B, | ||||
const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
@@ -33,6 +33,11 @@ PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | ||||
} | } | ||||
std::vector<PoolingForwardImpl::Algorithm*> | |||||
PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& dst) { | |||||
return megdnn::get_all_algorithms_safe<PoolingForwardImpl>({this, src, dst}); | |||||
} | |||||
PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
@@ -77,6 +82,15 @@ PoolingBackwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
{this, src, dst, diff, grad}); | {this, src, dst, diff, grad}); | ||||
} | } | ||||
std::vector<PoolingBackwardImpl::Algorithm*> | |||||
PoolingBackwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& dst, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>( | |||||
{this, src, dst, diff, grad}); | |||||
} | |||||
PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
@@ -55,6 +55,8 @@ public: | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
@@ -99,6 +101,9 @@ protected: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
const TensorLayout& diff, const TensorLayout& grad) override; | const TensorLayout& diff, const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
@@ -26,6 +26,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | ||||
} | } | ||||
std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||||
BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) { | |||||
AlgoBase::SizeArgs args{this, A, B, C}; | |||||
return megdnn::get_all_algorithms_safe<BatchedMatrixMulForwardImpl>(args); | |||||
} | |||||
BatchedMatrixMulForwardImpl::Algorithm* | BatchedMatrixMulForwardImpl::Algorithm* | ||||
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
@@ -35,6 +35,9 @@ private: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
@@ -279,11 +279,18 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms( | |||||
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | ||||
auto ret = get_all_algorithms_with_ncb(fparam); | auto ret = get_all_algorithms_with_ncb(fparam); | ||||
if (ret.empty()) { | if (ret.empty()) { | ||||
return naive::ConvBiasForwardImpl::get_all_algorithms(src, filter, bias, | |||||
return naive::ConvBiasForwardImpl::get_all_algorithms_safe(src, filter, bias, | |||||
z, dst); | z, dst); | ||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& bias, const TensorLayout& z, | |||||
const TensorLayout& dst) { | |||||
auto ret_safe = ConvBiasImpl::get_all_algorithms(src,filter,bias,z,dst); | |||||
return ret_safe; | |||||
} | |||||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
@@ -87,6 +87,10 @@ public: | |||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& bias, const TensorLayout& z, | |||||
const TensorLayout& dst) override; | |||||
//! implemented by get_algorithm_heuristic_with_ncb() | //! implemented by get_algorithm_heuristic_with_ncb() | ||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
@@ -198,12 +198,19 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms( | |||||
auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); | ||||
auto ret = get_all_algorithms_with_ncb(fparam); | auto ret = get_all_algorithms_with_ncb(fparam); | ||||
if (ret.empty()) { | if (ret.empty()) { | ||||
return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter, | |||||
return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter, | |||||
dst); | dst); | ||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& dst) { | |||||
auto ret_safe = ConvolutionImpl::get_all_algorithms(src,filter,dst); | |||||
return ret_safe; | |||||
} | |||||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( | ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
@@ -536,10 +543,19 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
} | } | ||||
auto fparam = make_ncb_kern_size_param(filter, diff, grad); | auto fparam = make_ncb_kern_size_param(filter, diff, grad); | ||||
auto ret = get_all_algorithms_with_ncb(fparam); | auto ret = get_all_algorithms_with_ncb(fparam); | ||||
megdnn_assert(!ret.empty(), "no usable conv fwd algorithm"); | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||||
ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
auto ret_safe = ConvolutionBackwardDataImpl::get_all_algorithms(filter,diff,grad); | |||||
megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm"); | |||||
return ret_safe; | |||||
} | |||||
ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
@@ -85,6 +85,10 @@ public: | |||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& dst) override; | |||||
//! implemented by get_algorithm_heuristic_with_ncb() | //! implemented by get_algorithm_heuristic_with_ncb() | ||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
@@ -326,6 +330,9 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& filter, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -96,6 +96,13 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | |||||
return gemv_algos; | return gemv_algos; | ||||
} | } | ||||
std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms_safe( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||||
auto gemv_algos_safe = get_all_algorithms(A,B,C); | |||||
megdnn_assert(!gemv_algos_safe.empty(), "no usable MatrixMul fwd algorithm"); | |||||
return gemv_algos_safe; | |||||
} | |||||
MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | ||||
const AlgorithmDesc& desc) { | const AlgorithmDesc& desc) { | ||||
if (!desc.valid()) { | if (!desc.valid()) { | ||||
@@ -270,6 +270,10 @@ protected: | |||||
const TensorLayout& B, | const TensorLayout& B, | ||||
const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
@@ -128,6 +128,16 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
->default_batch_conv_bias_fwd_algo()}; | ->default_batch_conv_bias_fwd_algo()}; | ||||
} | } | ||||
std::vector<BatchConvBiasForward::Algorithm*> | |||||
BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle()) | |||||
->default_batch_conv_bias_fwd_algo()}; | |||||
} | |||||
BatchConvBiasForward::Algorithm* | BatchConvBiasForward::Algorithm* | ||||
BatchConvBiasForwardImpl::get_algorithm_heuristic( | BatchConvBiasForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
@@ -30,6 +30,11 @@ public: | |||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& bias, const TensorLayout& z, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
@@ -63,7 +63,6 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||||
} | } | ||||
} | } | ||||
std::vector<BatchedMatrixMulForward::Algorithm*> | std::vector<BatchedMatrixMulForward::Algorithm*> | ||||
BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | ||||
const TensorLayout& /*B*/, | const TensorLayout& /*B*/, | ||||
@@ -71,6 +70,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||||
return {static_cast<HandleImpl*>(handle()) | return {static_cast<HandleImpl*>(handle()) | ||||
->default_batched_matmul_fwd_algo()}; | ->default_batched_matmul_fwd_algo()}; | ||||
} | } | ||||
std::vector<BatchedMatrixMulForward::Algorithm*> | |||||
BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/, | |||||
const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/) { | |||||
return {static_cast<HandleImpl*>(handle()) | |||||
->default_batched_matmul_fwd_algo()}; | |||||
} | |||||
BatchedMatrixMulForward::Algorithm* | BatchedMatrixMulForward::Algorithm* | ||||
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
@@ -27,6 +27,9 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
@@ -321,6 +321,15 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; | ||||
} | } | ||||
std::vector<ConvBiasForward::Algorithm*> | |||||
ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; | |||||
} | |||||
ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
const TensorLayout& /* bias */, const TensorLayout& /* z */, | const TensorLayout& /* bias */, const TensorLayout& /* z */, | ||||
@@ -31,6 +31,11 @@ public: | |||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& bias, const TensorLayout& z, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
@@ -287,6 +287,13 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &, | |||||
return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; | return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; | ||||
} | } | ||||
std::vector<ConvolutionForward::Algorithm *> | |||||
ConvolutionForwardImpl:: get_all_algorithms_safe(const TensorLayout &, | |||||
const TensorLayout &, const TensorLayout &) | |||||
{ | |||||
return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; | |||||
} | |||||
ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | ||||
@@ -313,6 +320,13 @@ ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | |||||
return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; | return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; | ||||
} | } | ||||
std::vector<ConvolutionBackwardData::Algorithm *> | |||||
ConvolutionBackwardDataImpl:: get_all_algorithms_safe(const TensorLayout &, | |||||
const TensorLayout &, const TensorLayout &) | |||||
{ | |||||
return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; | |||||
} | |||||
ConvolutionBackwardData::Algorithm* | ConvolutionBackwardData::Algorithm* | ||||
ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
@@ -341,6 +355,13 @@ ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | |||||
return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; | return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; | ||||
} | } | ||||
std::vector<ConvolutionBackwardFilter::Algorithm *> | |||||
ConvolutionBackwardFilterImpl:: get_all_algorithms_safe(const TensorLayout &, | |||||
const TensorLayout &, const TensorLayout &) | |||||
{ | |||||
return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; | |||||
} | |||||
ConvolutionBackwardFilter::Algorithm* | ConvolutionBackwardFilter::Algorithm* | ||||
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
@@ -25,6 +25,9 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | ||||
const TensorLayout &filter, | const TensorLayout &filter, | ||||
const TensorLayout &dst) override; | const TensorLayout &dst) override; | ||||
std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &src, | |||||
const TensorLayout &filter, | |||||
const TensorLayout &dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
@@ -67,6 +70,9 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | ||||
const TensorLayout &diff, | const TensorLayout &diff, | ||||
const TensorLayout &grad) override; | const TensorLayout &grad) override; | ||||
std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &filter, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -90,6 +96,9 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||||
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | ||||
const TensorLayout &diff, | const TensorLayout &diff, | ||||
const TensorLayout &grad) override; | const TensorLayout &grad) override; | ||||
std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &src, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -108,13 +108,18 @@ void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||||
megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
} | } | ||||
std::vector<Convolution3DForward::Algorithm*> | std::vector<Convolution3DForward::Algorithm*> | ||||
Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, | Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, | ||||
const TensorLayout&, | const TensorLayout&, | ||||
const TensorLayout&) { | const TensorLayout&) { | ||||
return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | ||||
} | } | ||||
std::vector<Convolution3DForward::Algorithm*> | |||||
Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | |||||
} | |||||
Convolution3DForward::Algorithm* | Convolution3DForward::Algorithm* | ||||
Convolution3DForwardImpl::get_algorithm_heuristic( | Convolution3DForwardImpl::get_algorithm_heuristic( | ||||
@@ -143,6 +148,13 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||||
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | ||||
} | } | ||||
std::vector<Convolution3DBackwardData::Algorithm*> | |||||
Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | |||||
} | |||||
Convolution3DBackwardData::Algorithm* | Convolution3DBackwardData::Algorithm* | ||||
Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
@@ -172,6 +184,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||||
->default_conv3d_bwd_filter_algo()}; | ->default_conv3d_bwd_filter_algo()}; | ||||
} | } | ||||
std::vector<Convolution3DBackwardFilter::Algorithm*> | |||||
Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle()) | |||||
->default_conv3d_bwd_filter_algo()}; | |||||
} | |||||
Convolution3DBackwardFilter::Algorithm* | Convolution3DBackwardFilter::Algorithm* | ||||
Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
@@ -22,6 +22,9 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
@@ -44,6 +47,9 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& filter, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -66,6 +72,9 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -25,6 +25,12 @@ public: | |||||
const TensorLayout& /* dst */) override { | const TensorLayout& /* dst */) override { | ||||
return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
}; | }; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /* im */, const TensorLayout& /* filter */, | |||||
const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||||
const TensorLayout& /* dst */) override { | |||||
return std::vector<Algorithm*>(); | |||||
}; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
@@ -67,6 +73,13 @@ public: | |||||
const TensorLayout& /* filter_grad */) override { | const TensorLayout& /* filter_grad */) override { | ||||
return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
}; | }; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /* im */, const TensorLayout& /* offset */, | |||||
const TensorLayout& /* mask */, const TensorLayout& /* out_grad */, | |||||
const TensorLayout& /* filter_grad */) override { | |||||
return std::vector<Algorithm*>(); | |||||
}; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /* im */, const TensorLayout& /* offset */, | const TensorLayout& /* im */, const TensorLayout& /* offset */, | ||||
@@ -112,6 +125,16 @@ public: | |||||
return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
}; | }; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /* im */, const TensorLayout& /* filter */, | |||||
const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||||
const TensorLayout& /* out_grad */, | |||||
const TensorLayout& /* im_grad */, | |||||
const TensorLayout& /* offset_grad */, | |||||
const TensorLayout& /* mask_grad */) override { | |||||
return std::vector<Algorithm*>(); | |||||
}; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /* im */, const TensorLayout& /* filter */, | const TensorLayout& /* im */, const TensorLayout& /* filter */, | ||||
const TensorLayout& /* offset */, const TensorLayout& /* mask */, | const TensorLayout& /* offset */, const TensorLayout& /* mask */, | ||||
@@ -159,6 +159,13 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | ||||
} | } | ||||
std::vector<LocalShareForward::Algorithm*> | |||||
LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | |||||
} | |||||
LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
@@ -187,6 +194,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||||
->default_local_share_bwd_data_algo()}; | ->default_local_share_bwd_data_algo()}; | ||||
} | } | ||||
std::vector<LocalShareBackwardData::Algorithm*> | |||||
LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle()) | |||||
->default_local_share_bwd_data_algo()}; | |||||
} | |||||
LocalShareBackwardData::Algorithm* | LocalShareBackwardData::Algorithm* | ||||
LocalShareBackwardDataImpl::get_algorithm_heuristic( | LocalShareBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
@@ -216,6 +231,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||||
->default_local_share_bwd_filter_algo()}; | ->default_local_share_bwd_filter_algo()}; | ||||
} | } | ||||
std::vector<LocalShareBackwardFilter::Algorithm*> | |||||
LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle()) | |||||
->default_local_share_bwd_filter_algo()}; | |||||
} | |||||
LocalShareBackwardFilter::Algorithm* | LocalShareBackwardFilter::Algorithm* | ||||
LocalShareBackwardFilterImpl::get_algorithm_heuristic( | LocalShareBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
@@ -30,6 +30,10 @@ public: | |||||
const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | ||||
const TensorLayout& /*dst*/) override; | const TensorLayout& /*dst*/) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | |||||
const TensorLayout& /*dst*/) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | ||||
const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, | ||||
@@ -55,6 +59,10 @@ public: | |||||
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | ||||
const TensorLayout& /*grad*/) override; | const TensorLayout& /*grad*/) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | |||||
const TensorLayout& /*grad*/) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | ||||
const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | ||||
@@ -75,11 +83,14 @@ public: | |||||
const TensorLayout&) override { | const TensorLayout&) override { | ||||
return 0; | return 0; | ||||
} | } | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | ||||
const TensorLayout& /*grad*/) override; | const TensorLayout& /*grad*/) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | |||||
const TensorLayout& /*grad*/) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | ||||
const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | ||||
@@ -88,6 +88,13 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||||
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | ||||
} | } | ||||
std::vector<MatrixMulForward::Algorithm*> | |||||
MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/, | |||||
const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/) { | |||||
return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | |||||
} | |||||
MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | ||||
@@ -29,6 +29,10 @@ public: | |||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | ||||
@@ -603,6 +603,10 @@ std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms( | |||||
const TensorLayout&, const TensorLayout&) { | const TensorLayout&, const TensorLayout&) { | ||||
return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; | ||||
} | } | ||||
std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms_safe( | |||||
const TensorLayout&, const TensorLayout&) { | |||||
return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; | |||||
} | |||||
Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | ||||
@@ -626,6 +630,11 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms( | |||||
const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { | const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { | ||||
return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; | ||||
} | } | ||||
std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe( | |||||
const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | |||||
const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { | |||||
return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; | |||||
} | |||||
Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | ||||
@@ -35,6 +35,8 @@ class PoolingForwardImpl: public PoolingForward { | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
@@ -60,6 +62,9 @@ public: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
const TensorLayout& diff, const TensorLayout& grad) override; | const TensorLayout& diff, const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
@@ -29,6 +29,14 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | ||||
} | } | ||||
std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||||
BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) { | |||||
AlgoBase::SizeArgs args{this, A, B, C}; | |||||
return megdnn::get_all_algorithms_safe<BatchedMatrixMulForwardImpl>(args); | |||||
} | |||||
BatchedMatrixMulForwardImpl::Algorithm* | BatchedMatrixMulForwardImpl::Algorithm* | ||||
BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
@@ -35,6 +35,9 @@ private: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
@@ -109,6 +109,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
{this, src, filter, dst}); | {this, src, filter, dst}); | ||||
} | } | ||||
std::vector<ConvolutionForwardImpl::Algorithm*> | |||||
ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& filter, | |||||
const TensorLayout& dst) { | |||||
return megdnn::get_all_algorithms_safe<ConvolutionForwardImpl>( | |||||
{this, src, filter, dst}); | |||||
} | |||||
size_t ConvolutionForwardImpl::get_workspace_in_bytes( | size_t ConvolutionForwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, const PreprocessedFilter*) { | const TensorLayout& dst, const PreprocessedFilter*) { | ||||
@@ -162,6 +170,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
{this, filter, diff, grad}); | {this, filter, diff, grad}); | ||||
} | } | ||||
std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||||
ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms_safe<ConvolutionBackwardDataImpl>( | |||||
{this, filter, diff, grad}); | |||||
} | |||||
ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
@@ -243,6 +259,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||||
{this, src, diff, grad}); | {this, src, diff, grad}); | ||||
} | } | ||||
std::vector<ConvolutionBackwardFilterImpl::Algorithm*> | |||||
ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms_safe<ConvolutionBackwardFilterImpl>( | |||||
{this, src, diff, grad}); | |||||
} | |||||
ConvolutionBackwardFilterImpl::Algorithm* | ConvolutionBackwardFilterImpl::Algorithm* | ||||
ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
@@ -74,6 +74,9 @@ private: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& filter, | |||||
const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
@@ -123,6 +126,9 @@ private: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& filter, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -172,6 +178,9 @@ private: | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& diff, | |||||
const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
@@ -27,6 +27,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | ||||
} | } | ||||
std::vector<MatrixMulForwardImpl::Algorithm*> | |||||
MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) { | |||||
AlgoBase::SizeArgs args{this, A, B, C}; | |||||
return megdnn::get_all_algorithms_safe<MatrixMulForwardImpl>(args); | |||||
} | |||||
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
@@ -36,6 +36,10 @@ private: | |||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | ||||
@@ -25,12 +25,16 @@ size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
const char* PoolingForwardImpl::get_algorithm_set_name() const { | const char* PoolingForwardImpl::get_algorithm_set_name() const { | ||||
return "ROCM_POOLING_FORWARD"; | return "ROCM_POOLING_FORWARD"; | ||||
} | } | ||||
std::vector<PoolingForwardImpl::Algorithm*> | std::vector<PoolingForwardImpl::Algorithm*> | ||||
PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | ||||
const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | ||||
} | } | ||||
std::vector<PoolingForwardImpl::Algorithm*> | |||||
PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
const TensorLayout& dst) { | |||||
return megdnn::get_all_algorithms_safe<PoolingForwardImpl>({this, src, dst}); | |||||
} | |||||
PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
@@ -82,6 +86,13 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms( | |||||
{this, src, dst, diff, grad}); | {this, src, dst, diff, grad}); | ||||
} | } | ||||
std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad) { | |||||
return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>( | |||||
{this, src, dst, diff, grad}); | |||||
} | |||||
Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
@@ -46,6 +46,8 @@ class PoolingForwardImpl final: public PoolingForward { | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
@@ -93,6 +95,9 @@ class PoolingBackwardImpl final: public PoolingBackward { | |||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
const TensorLayout& diff, const TensorLayout& grad) override; | const TensorLayout& diff, const TensorLayout& grad) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& diff, const TensorLayout& grad) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
@@ -74,11 +74,14 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
return fallback_worksapce; | return fallback_worksapce; | ||||
} | } | ||||
} | } | ||||
std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
return megdnn::get_all_algorithms<PoolingImpl>({this, src, dst}); | return megdnn::get_all_algorithms<PoolingImpl>({this, src, dst}); | ||||
} | } | ||||
std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst) { | |||||
return megdnn::get_all_algorithms_safe<PoolingImpl>({this, src, dst}); | |||||
} | |||||
Algorithm* PoolingImpl::get_algorithm_heuristic( | Algorithm* PoolingImpl::get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
@@ -63,6 +63,8 @@ public: | |||||
protected: | protected: | ||||
std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
std::vector<Algorithm*> get_all_algorithms_safe( | |||||
const TensorLayout& src, const TensorLayout& dst) override; | |||||
Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
@@ -164,7 +164,7 @@ public: | |||||
} | } | ||||
std::vector<Algorithm::Info::Desc> ret; | std::vector<Algorithm::Info::Desc> ret; | ||||
megdnn_assert(layouts.size() == OprTrait<Opr>::arity); | megdnn_assert(layouts.size() == OprTrait<Opr>::arity); | ||||
auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info( | |||||
auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe( | |||||
opr, layouts); | opr, layouts); | ||||
for (auto algo_info : vec) { | for (auto algo_info : vec) { | ||||
if (!(algo_info.attribute & | if (!(algo_info.attribute & | ||||
@@ -377,7 +377,7 @@ float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorLayoutArray layouts, | |||||
auto opr = benchmark.opr(); | auto opr = benchmark.opr(); | ||||
opr->param() = benchmark.param(); | opr->param() = benchmark.param(); | ||||
proxy.deduce_layout(opr, layouts); | proxy.deduce_layout(opr, layouts); | ||||
auto algos = OprAlgoProxy<Opr>::get_all_algorithms_info(opr, layouts); | |||||
auto algos = OprAlgoProxy<Opr>::get_all_algorithms_info_safe(opr, layouts); | |||||
float min_used = std::numeric_limits<float>::max(); | float min_used = std::numeric_limits<float>::max(); | ||||
bool execed = false; | bool execed = false; | ||||
for (auto i : algos) { | for (auto i : algos) { | ||||
@@ -514,7 +514,7 @@ struct ExecutionPolicyAlgoName { | |||||
* \brief a callable to check that given algorithm is used for heuristic | * \brief a callable to check that given algorithm is used for heuristic | ||||
* \param require_algo if its value is true, then requires | * \param require_algo if its value is true, then requires | ||||
* get_algorithm_heuristic() to return the expected algo; otherwise the | * get_algorithm_heuristic() to return the expected algo; otherwise the | ||||
* expected algo must exist in get_all_algorithms() and it would be set to | |||||
* expected algo must exist in get_all_algorithms_safe() and it would be set to | |||||
* be used | * be used | ||||
*/ | */ | ||||
template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>> | template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>> | ||||
@@ -536,7 +536,7 @@ public: | |||||
opr->param() = | opr->param() = | ||||
Algorithm::deserialize_read_pod<typename Opr::Param>(param); | Algorithm::deserialize_read_pod<typename Opr::Param>(param); | ||||
for (auto algo_info : | for (auto algo_info : | ||||
AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info( | |||||
AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe( | |||||
opr.get(), layouts)) { | opr.get(), layouts)) { | ||||
if (std::regex_match( | if (std::regex_match( | ||||
algo_info.desc.name, | algo_info.desc.name, | ||||
@@ -695,7 +695,7 @@ Checker<Convolution> checker(handle); | |||||
float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); | float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); | ||||
UniformFloatRNG rng(scale, 2 * scale); | UniformFloatRNG rng(scale, 2 * scale); | ||||
checker.set_rng(0, &rng).set_rng(1, &rng); | checker.set_rng(0, &rng).set_rng(1, &rng); | ||||
for (auto algo : opr->get_all_algorithms_info(ily, fly, oly)) { | |||||
for (auto algo : opr->get_all_algorithms_info_safe(ily, fly, oly)) { | |||||
used_algos.insert(algo.desc); | used_algos.insert(algo.desc); | ||||
opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
@@ -720,7 +720,7 @@ Checker<Convolution> checker(handle); | |||||
opr->param() = param; | opr->param() = param; | ||||
std::string param_str; | std::string param_str; | ||||
Algorithm::serialize_write_pod(opr->param(), param_str); | Algorithm::serialize_write_pod(opr->param(), param_str); | ||||
for (auto algo : opr->get_all_algorithms_info(fly, oly, ily)) { | |||||
for (auto algo : opr->get_all_algorithms_info_safe(fly, oly, ily)) { | |||||
used_algos_bwd_data.insert(algo.desc); | used_algos_bwd_data.insert(algo.desc); | ||||
opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
construct_sub_execution_policy_heuristic< | construct_sub_execution_policy_heuristic< | ||||
@@ -747,7 +747,7 @@ Checker<Convolution> checker(handle); | |||||
opr->param() = param; | opr->param() = param; | ||||
std::string param_str; | std::string param_str; | ||||
Algorithm::serialize_write_pod(opr->param(), param_str); | Algorithm::serialize_write_pod(opr->param(), param_str); | ||||
for (auto algo : opr->get_all_algorithms_info(ily, oly, fly)) { | |||||
for (auto algo : opr->get_all_algorithms_info_safe(ily, oly, fly)) { | |||||
used_algos_bwd_flt.insert(algo.desc); | used_algos_bwd_flt.insert(algo.desc); | ||||
opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
construct_sub_execution_policy_heuristic< | construct_sub_execution_policy_heuristic< | ||||
@@ -25,9 +25,9 @@ struct AlgoProxy; | |||||
template <typename Opr> \ | template <typename Opr> \ | ||||
struct AlgoProxy<Opr, arity> { \ | struct AlgoProxy<Opr, arity> { \ | ||||
static std::vector<typename Opr::AlgorithmInfo> \ | static std::vector<typename Opr::AlgorithmInfo> \ | ||||
get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
megdnn_assert(layouts.size() == arity); \ | megdnn_assert(layouts.size() == arity); \ | ||||
return opr->get_all_algorithms_info(LAYOUTS); \ | |||||
return opr->get_all_algorithms_info_safe(LAYOUTS); \ | |||||
} \ | } \ | ||||
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | ||||
Opr* opr, const TensorLayoutArray& layouts) { \ | Opr* opr, const TensorLayoutArray& layouts) { \ | ||||
@@ -80,9 +80,9 @@ DEF_ALGO_PROXY(8); | |||||
template <> \ | template <> \ | ||||
struct AlgoProxy<Opr, arity> { \ | struct AlgoProxy<Opr, arity> { \ | ||||
static std::vector<typename Opr::AlgorithmInfo> \ | static std::vector<typename Opr::AlgorithmInfo> \ | ||||
get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
megdnn_assert(layouts.size() == arity); \ | megdnn_assert(layouts.size() == arity); \ | ||||
return opr->get_all_algorithms_info(LAYOUTS); \ | |||||
return opr->get_all_algorithms_info_safe(LAYOUTS); \ | |||||
} \ | } \ | ||||
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | ||||
Opr* opr, const TensorLayoutArray& layouts) { \ | Opr* opr, const TensorLayoutArray& layouts) { \ | ||||
@@ -288,7 +288,7 @@ struct OprProxyProfilingBase | |||||
Algorithm::deserialize_read_pod<typename Opr::Param>(param); | Algorithm::deserialize_read_pod<typename Opr::Param>(param); | ||||
std::vector<Algorithm::SearchItem> ret; | std::vector<Algorithm::SearchItem> ret; | ||||
for (auto algo_info : AlgoProxy<Opr, arity>::get_all_algorithms_info( | |||||
for (auto algo_info : AlgoProxy<Opr, arity>::get_all_algorithms_info_safe( | |||||
opr.get(), layouts)) { | opr.get(), layouts)) { | ||||
Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc); | Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc); | ||||
std::vector<Algorithm::SearchItem>&& sub_items = | std::vector<Algorithm::SearchItem>&& sub_items = | ||||
@@ -367,7 +367,7 @@ struct OprProxyProfilingBase | |||||
megdnn_log("Find best algo %s in cache", algo->name()); | megdnn_log("Find best algo %s in cache", algo->name()); | ||||
return; | return; | ||||
} | } | ||||
for (auto algo : AlgoProxy<Opr, arity>::get_all_algorithms_info( | |||||
for (auto algo : AlgoProxy<Opr, arity>::get_all_algorithms_info_safe( | |||||
opr.get(), layouts)) { | opr.get(), layouts)) { | ||||
//! construct execution_policy | //! construct execution_policy | ||||
opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
@@ -492,7 +492,7 @@ struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> { | |||||
if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) { | if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) { | ||||
size_t min_time = std::numeric_limits<size_t>::max(); | size_t min_time = std::numeric_limits<size_t>::max(); | ||||
for (auto algo : | for (auto algo : | ||||
AlgoProxy<Opr, arity>::get_all_algorithms_info(opr, layouts)) { | |||||
AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) { | |||||
opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
auto preprocess_tensors = | auto preprocess_tensors = | ||||
@@ -84,7 +84,7 @@ void test_multibatchsize( | |||||
auto opr_reference = handle_cuda->create_operator<MatrixMulForward>(); | auto opr_reference = handle_cuda->create_operator<MatrixMulForward>(); | ||||
{ | { | ||||
opr_reference->execution_policy().algo.reset(); | opr_reference->execution_policy().algo.reset(); | ||||
for (auto i : opr_reference->get_all_algorithms_info( | |||||
for (auto i : opr_reference->get_all_algorithms_info_safe( | |||||
A_tensor.layout(), B_tensor.layout(), | A_tensor.layout(), B_tensor.layout(), | ||||
C_tensor.layout())) { | C_tensor.layout())) { | ||||
if (std::regex_match( | if (std::regex_match( | ||||
@@ -113,7 +113,7 @@ void test_multibatchsize( | |||||
{{}, {}, C_tensor_prime.tensornd_host()}); | {{}, {}, C_tensor_prime.tensornd_host()}); | ||||
{ | { | ||||
opr_reference->execution_policy().algo.reset(); | opr_reference->execution_policy().algo.reset(); | ||||
for (auto i : opr_reference->get_all_algorithms_info( | |||||
for (auto i : opr_reference->get_all_algorithms_info_safe( | |||||
A_tensor_prime.layout(), B_tensor.layout(), | A_tensor_prime.layout(), B_tensor.layout(), | ||||
C_tensor_batch.layout())) { | C_tensor_batch.layout())) { | ||||
if (std::regex_match( | if (std::regex_match( | ||||
@@ -1938,7 +1938,7 @@ typename megdnn::ExecutionPolicy try_find_any_weight_preprocess_algo( | |||||
return {}; | return {}; | ||||
} | } | ||||
} | } | ||||
for (auto&& algo : dnn_op->get_all_algorithms_info( | |||||
for (auto&& algo : dnn_op->get_all_algorithms_info_safe( | |||||
std::forward<Args>(args)...)) { | std::forward<Args>(args)...)) { | ||||
dnn_op->execution_policy().algo = algo.desc; | dnn_op->execution_policy().algo = algo.desc; | ||||
auto layouts = dnn_op->deduce_preprocessed_filter_layout( | auto layouts = dnn_op->deduce_preprocessed_filter_layout( | ||||
@@ -1972,7 +1972,7 @@ typename megdnn::ExecutionPolicy try_find_any_bias_preprocess_algo( | |||||
return {}; | return {}; | ||||
} | } | ||||
} | } | ||||
for (auto&& algo : dnn_op->get_all_algorithms_info( | |||||
for (auto&& algo : dnn_op->get_all_algorithms_info_safe( | |||||
std::forward<Args>(args)...)) { | std::forward<Args>(args)...)) { | ||||
dnn_op->execution_policy().algo = algo.desc; | dnn_op->execution_policy().algo = algo.desc; | ||||
auto layouts = dnn_op->deduce_preprocessed_filter_layout( | auto layouts = dnn_op->deduce_preprocessed_filter_layout( | ||||
@@ -805,7 +805,7 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo> | |||||
AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { | AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { | ||||
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) | ||||
auto heu = choose_by_heuristic(m_execution_policy.strategy); | auto heu = choose_by_heuristic(m_execution_policy.strategy); | ||||
auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...), | |||||
auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info_safe(args...), | |||||
m_fastrun_layouts); | m_fastrun_layouts); | ||||
bool found = false; | bool found = false; | ||||
for (size_t i = 0; i < ret.size(); ++i) { | for (size_t i = 0; i < ret.size(); ++i) { | ||||
@@ -2473,6 +2473,11 @@ public: | |||||
std::vector<AlgorithmInfo>(const TensorLayout& p0, | std::vector<AlgorithmInfo>(const TensorLayout& p0, | ||||
const TensorLayout& p1, | const TensorLayout& p1, | ||||
const TensorLayout& p2)); | const TensorLayout& p2)); | ||||
MOCK_METHOD3(get_all_algorithms_info_safe, | |||||
std::vector<AlgorithmInfo>(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2)); | |||||
MOCK_METHOD6(get_algorithm_info_heuristic, | MOCK_METHOD6(get_algorithm_info_heuristic, | ||||
AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | ||||
const TensorLayout& p2, | const TensorLayout& p2, | ||||
@@ -2484,6 +2489,11 @@ public: | |||||
std::vector<Algorithm*>(const TensorLayout& p0, | std::vector<Algorithm*>(const TensorLayout& p0, | ||||
const TensorLayout& p1, | const TensorLayout& p1, | ||||
const TensorLayout& p2)); | const TensorLayout& p2)); | ||||
MOCK_METHOD3(get_all_algorithms_safe, | |||||
std::vector<Algorithm*>(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2)); | |||||
MOCK_METHOD6(get_algorithm_heuristic, | MOCK_METHOD6(get_algorithm_heuristic, | ||||
Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | ||||
const TensorLayout& p2, | const TensorLayout& p2, | ||||