diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h index bb50fdbd..6349d501 100644 --- a/dnn/include/megdnn/oprs/base.h +++ b/dnn/include/megdnn/oprs/base.h @@ -315,7 +315,7 @@ public: /*! * \brief get a string representation for current algorithm set; * - * get_all_algorithms() may return different algorithms only if + * get_all_algorithms_safe() may return different algorithms only if * algorithm set name differs. This is used for checking cache * validity. */ @@ -354,6 +354,15 @@ public: return ret; } + std::vector get_all_algorithms_info_safe(const TensorLayout& p0, + const TensorLayout& p1) { + std::vector ret; + for (auto&& algo : get_all_algorithms_safe(p0, p1)) { + ret.emplace_back(algo->info()); + } + return ret; + } + /** * \brief Returns the best algorithm information which indicate the * algorithm by heuristic. @@ -378,6 +387,8 @@ protected: //! get all possible algorithms for the specified layouts virtual std::vector get_all_algorithms( const TensorLayout& p0, const TensorLayout& p1) = 0; + virtual std::vector get_all_algorithms_safe( + const TensorLayout& p0, const TensorLayout& p1) = 0; /** * \brief Returns the best algorithm by heuristic. @@ -412,6 +423,16 @@ public: return ret; } + std::vector get_all_algorithms_info_safe(const TensorLayout& p0, + const TensorLayout& p1, + const TensorLayout& p2) { + std::vector ret; + for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { + ret.emplace_back(algo->info()); + } + return ret; + } + /** * \brief Returns the best algorithm information which indicate the * algorithm by heuristic. @@ -438,6 +459,9 @@ protected: virtual std::vector get_all_algorithms( const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; + virtual std::vector get_all_algorithms_safe( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2) = 0; /** * \brief Returns the best algorithm by heuristic. @@ -463,7 +487,7 @@ public: using AlgoAttribute = detail::Algorithm::Attribute; //! get all possible algorithm decriptions for the specified layouts - std::vector get_all_algorithms_info(const TensorLayout& p0, + std::vector get_all_algorithms_info(const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, const TensorLayout& p3) { @@ -474,6 +498,17 @@ public: return ret; } + std::vector get_all_algorithms_info_safe(const TensorLayout& p0, + const TensorLayout& p1, + const TensorLayout& p2, + const TensorLayout& p3) { + std::vector ret; + for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { + ret.emplace_back(algo->info()); + } + return ret; + } + /** * \brief Returns the best algorithm information which indicate the * algorithm by heuristic. @@ -500,6 +535,9 @@ protected: virtual std::vector get_all_algorithms( const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, const TensorLayout& p3) = 0; + virtual std::vector get_all_algorithms_safe( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, const TensorLayout& p3) = 0; /** * \brief Returns the best algorithm by heuristic. @@ -537,6 +575,18 @@ public: return ret; } + std::vector get_all_algorithms_info_safe(const TensorLayout& p0, + const TensorLayout& p1, + const TensorLayout& p2, + const TensorLayout& p3, + const TensorLayout& p4) { + std::vector ret; + for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { + ret.emplace_back(algo->info()); + } + return ret; + } + /** * \brief Returns the best algorithm information which indicate the * algorithm by heuristic. 
@@ -562,7 +612,11 @@ protected: ~MultiAlgoOpr() = default; //! get all possible algorithms for the specified layouts - virtual std::vector get_all_algorithms( + virtual std::vector get_all_algorithms( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, const TensorLayout& p3, + const TensorLayout& p4) = 0; + virtual std::vector get_all_algorithms_safe( const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, const TensorLayout& p3, const TensorLayout& p4) = 0; @@ -604,6 +658,18 @@ public: return ret; } + std::vector get_all_algorithms_info_safe( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, const TensorLayout& p3, + const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p6, const TensorLayout& p7) { + std::vector ret; + for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { + ret.emplace_back(algo->info()); + } + return ret; + } + /** * \brief Returns the best algorithm information which indicate the * algorithm by heuristic. @@ -629,7 +695,12 @@ protected: ~MultiAlgoOpr() = default; //! get all possible algorithms for the specified layouts - virtual std::vector get_all_algorithms( + virtual std::vector get_all_algorithms( + const TensorLayout& p0, const TensorLayout& p1, + const TensorLayout& p2, const TensorLayout& p3, + const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p6, const TensorLayout& p7) = 0; + virtual std::vector get_all_algorithms_safe( const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index 6e5b62f8..3806c491 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -172,9 +172,14 @@ std::vector PoolingImpl::get_all_algorithms( ret.push_back(i); } } - megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm"); return ret; } +std::vector PoolingImpl::get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& dst) { + auto ret_safe = get_all_algorithms(src,dst); + megdnn_assert(!ret_safe.empty(), "no usable pooling fwd algorithm"); + return ret_safe; +} Algorithm* PoolingImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& dst, diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h index 04ab72e5..1f25f9c0 100644 --- a/dnn/src/arm_common/pooling/opr_impl.h +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -131,6 +131,8 @@ public: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& dst, diff --git a/dnn/src/common/algo_chooser.h b/dnn/src/common/algo_chooser.h index 9c0964d5..12c3e6e2 100644 --- a/dnn/src/common/algo_chooser.h +++ b/dnn/src/common/algo_chooser.h @@ -100,10 +100,16 @@ std::vector get_all_algorithms( ret.push_back(i); } } - megdnn_assert(!ret.empty(), "no algorithm for %s", - args.to_string().c_str()); return ret; } +template +std::vector get_all_algorithms_safe( + const typename Opr::AlgoBase::SizeArgs& args) { + auto ret_safe = get_all_algorithms(args); + megdnn_assert(!ret_safe.empty(), "no algorithm for %s", + args.to_string().c_str()); + return ret_safe; +} /*! 
* \brief a helper function to get an algorithm match attribute. If require a diff --git a/dnn/src/cuda/batch_conv_bias/opr_impl.cpp b/dnn/src/cuda/batch_conv_bias/opr_impl.cpp index 9d3b3711..5e5ff3af 100644 --- a/dnn/src/cuda/batch_conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/batch_conv_bias/opr_impl.cpp @@ -51,6 +51,15 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; return megdnn::get_all_algorithms(args); } +std::vector +BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& bias, + const TensorLayout& z, + const TensorLayout& dst) { + AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; + return megdnn::get_all_algorithms_safe(args); +} size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, diff --git a/dnn/src/cuda/batch_conv_bias/opr_impl.h b/dnn/src/cuda/batch_conv_bias/opr_impl.h index 2114e2ef..29ab7a30 100644 --- a/dnn/src/cuda/batch_conv_bias/opr_impl.h +++ b/dnn/src/cuda/batch_conv_bias/opr_impl.h @@ -42,6 +42,10 @@ protected: const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, diff --git a/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp b/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp index 7bc03766..01e902f3 100644 --- a/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp @@ -51,6 +51,12 @@ std::vector BatchedMatrixMulForwardImpl::get_all_algorithms( } return ret; } +std::vector BatchedMatrixMulForwardImpl::get_all_algorithms_safe( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { + auto ret_safe = get_all_algorithms(A, B, C); + megdnn_assert(!ret_safe.empty(), "no usable BatchedMatrixMulForward algorithm"); + return ret_safe; +} Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, diff --git a/dnn/src/cuda/batched_matrix_mul/opr_impl.h b/dnn/src/cuda/batched_matrix_mul/opr_impl.h index ea3690de..cef99e8f 100644 --- a/dnn/src/cuda/batched_matrix_mul/opr_impl.h +++ b/dnn/src/cuda/batched_matrix_mul/opr_impl.h @@ -45,6 +45,9 @@ protected: std::vector get_all_algorithms(const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) override; + std::vector get_all_algorithms_safe(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) override; Algorithm* get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, diff --git a/dnn/src/cuda/conv_bias/opr_impl.cpp b/dnn/src/cuda/conv_bias/opr_impl.cpp index 2a00f1f6..1d9df4e3 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/conv_bias/opr_impl.cpp @@ -49,6 +49,16 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, {this, src, filter, bias, z, dst}); } +std::vector +ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& bias, + const TensorLayout& z, + const 
TensorLayout& dst) { + return megdnn::get_all_algorithms_safe( + {this, src, filter, bias, z, dst}); +} + ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index d6ea710e..06ee9a95 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -84,6 +84,10 @@ public: const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, diff --git a/dnn/src/cuda/convolution/opr_impl.cpp b/dnn/src/cuda/convolution/opr_impl.cpp index 560da586..633ba6dc 100644 --- a/dnn/src/cuda/convolution/opr_impl.cpp +++ b/dnn/src/cuda/convolution/opr_impl.cpp @@ -53,6 +53,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, return megdnn::get_all_algorithms(args); } +std::vector +ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst) { + AlgoBase::SizeArgs args{this, src, filter, dst}; + return megdnn::get_all_algorithms_safe(args); +} + size_t ConvolutionForwardImpl::get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, @@ -97,6 +105,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, {this, filter, diff, grad}); } +std::vector +ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms_safe( + {this, filter, diff, grad}); +} + ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, @@ -222,6 +238,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, {this, src, diff, grad}); } +std::vector +ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms_safe( + {this, src, diff, grad}); +} + ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, diff --git a/dnn/src/cuda/convolution/opr_impl.h b/dnn/src/cuda/convolution/opr_impl.h index 08bc4e75..8579b498 100644 --- a/dnn/src/cuda/convolution/opr_impl.h +++ b/dnn/src/cuda/convolution/opr_impl.h @@ -59,6 +59,10 @@ protected: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) override; + + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, size_t workspace_limit_in_bytes, @@ -111,6 +115,10 @@ protected: std::vector get_all_algorithms( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) override; + + std::vector get_all_algorithms_safe( + const TensorLayout& filter, const 
TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, @@ -159,6 +167,10 @@ protected: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; + + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, diff --git a/dnn/src/cuda/convolution3d/opr_impl.cpp b/dnn/src/cuda/convolution3d/opr_impl.cpp index 607edf55..f384e428 100644 --- a/dnn/src/cuda/convolution3d/opr_impl.cpp +++ b/dnn/src/cuda/convolution3d/opr_impl.cpp @@ -108,6 +108,14 @@ Convolution3DForwardImpl::get_all_algorithms(const TensorLayout& src, {this, src, filter, dst}); } +std::vector +Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst) { + return megdnn::get_all_algorithms_safe( + {this, src, filter, dst}); +} + size_t Convolution3DForwardImpl::get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { @@ -146,6 +154,14 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, {this, filter, diff, grad}); } +std::vector +Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms_safe( + {this, filter, diff, grad}); +} + Convolution3DBackwardDataImpl::Algorithm* Convolution3DBackwardDataImpl::get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, @@ -226,6 +242,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, {this, src, diff, grad}); } +std::vector +Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms_safe( + {this, src, diff, grad}); +} + Convolution3DBackwardFilterImpl::Algorithm* Convolution3DBackwardFilterImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, diff --git a/dnn/src/cuda/convolution3d/opr_impl.h b/dnn/src/cuda/convolution3d/opr_impl.h index f240ae5d..5b208b4b 100644 --- a/dnn/src/cuda/convolution3d/opr_impl.h +++ b/dnn/src/cuda/convolution3d/opr_impl.h @@ -39,6 +39,9 @@ protected: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, size_t workspace_limit_in_bytes, @@ -72,6 +75,9 @@ public: protected: std::vector get_all_algorithms( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( @@ -109,6 +115,9 @@ protected: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const 
TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, diff --git a/dnn/src/cuda/deformable_conv/opr_impl.cpp b/dnn/src/cuda/deformable_conv/opr_impl.cpp index 167d66dc..909f1e05 100644 --- a/dnn/src/cuda/deformable_conv/opr_impl.cpp +++ b/dnn/src/cuda/deformable_conv/opr_impl.cpp @@ -51,6 +51,15 @@ std::vector Fwd::get_all_algorithms(const TensorLayout& /* im */, return algos; } +std::vector Fwd::get_all_algorithms_safe(const TensorLayout& im, + const TensorLayout& filter, + const TensorLayout& offset, + const TensorLayout& mask, + const TensorLayout& dst) { + auto ret_safe = Fwd::get_all_algorithms(im,filter,offset,mask,dst); + megdnn_assert(!ret_safe.empty(), "no usable deformable_conv fwd algorithm"); + return ret_safe; +} AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, const TensorLayout& filter, @@ -115,6 +124,14 @@ std::vector BwdFlt::get_all_algorithms(const TensorLayout& /* im */ return algos; } +std::vector BwdFlt::get_all_algorithms_safe(const TensorLayout& im, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const TensorLayout& filter_grad) { + auto ret_safe = BwdFlt::get_all_algorithms(im,offset,mask,out_grad,filter_grad); + megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd filter algorithm"); + return ret_safe; +} + AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, @@ -181,6 +198,14 @@ std::vector BwdData::get_all_algorithms( algos.push_back(static_cast(i)); return algos; } +std::vector BwdData::get_all_algorithms_safe( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& im_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad ) { + auto ret_safe = BwdData::get_all_algorithms(im,filter,offset,mask,out_grad,im_grad,offset_grad,mask_grad); + megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd data algorithm"); + return ret_safe; +} AlgoBwdData* BwdData::get_algorithm_heuristic( const TensorLayout& im, const TensorLayout& filter, diff --git a/dnn/src/cuda/deformable_conv/opr_impl.h b/dnn/src/cuda/deformable_conv/opr_impl.h index a56eaaf7..29b0589b 100644 --- a/dnn/src/cuda/deformable_conv/opr_impl.h +++ b/dnn/src/cuda/deformable_conv/opr_impl.h @@ -54,6 +54,10 @@ protected: const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& im, const TensorLayout& filter, @@ -105,6 +109,10 @@ protected: const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& filter_grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& im, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& filter_grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& im, const TensorLayout& offset, @@ -161,6 +169,13 @@ protected: const TensorLayout& out_grad, const TensorLayout& im_grad, 
const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; + + std::vector get_all_algorithms_safe( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const TensorLayout& im_grad, + const TensorLayout& offset_grad, + const TensorLayout& mask_grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& im, const TensorLayout& filter, diff --git a/dnn/src/cuda/local_share/opr_impl.cpp b/dnn/src/cuda/local_share/opr_impl.cpp index 786991e2..c43b6c1b 100644 --- a/dnn/src/cuda/local_share/opr_impl.cpp +++ b/dnn/src/cuda/local_share/opr_impl.cpp @@ -47,7 +47,6 @@ LocalShareForwardImpl::get_algorithm_heuristic( Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), workspace_limit_in_bytes)); } - std::vector LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, const TensorLayout& filter, @@ -56,6 +55,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, dst}; return megdnn::get_all_algorithms(args); } +std::vector +LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst) { + AlgoBase::SizeArgs args{this, src, filter, dst}; + return megdnn::get_all_algorithms_safe(args); +} + size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { @@ -109,6 +116,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, return megdnn::get_all_algorithms(args); } +std::vector +LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) { + AlgoBase::SizeArgs args{this, filter, diff, grad}; + return megdnn::get_all_algorithms_safe(args); +} + size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) { @@ -162,6 +177,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, return megdnn::get_all_algorithms(args); } +std::vector +LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad) { + AlgoBase::SizeArgs args{this, src, diff, grad}; + return megdnn::get_all_algorithms_safe(args); +} + size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { diff --git a/dnn/src/cuda/local_share/opr_impl.h b/dnn/src/cuda/local_share/opr_impl.h index 07800fda..a261bed4 100644 --- a/dnn/src/cuda/local_share/opr_impl.h +++ b/dnn/src/cuda/local_share/opr_impl.h @@ -37,6 +37,9 @@ public: protected: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( @@ -72,6 +75,9 @@ protected: std::vector get_all_algorithms( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, @@ -105,6 +111,9 @@ protected: std::vector get_all_algorithms( const 
TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, diff --git a/dnn/src/cuda/matrix_mul/opr_impl.cpp b/dnn/src/cuda/matrix_mul/opr_impl.cpp index 059998df..b47b31b5 100644 --- a/dnn/src/cuda/matrix_mul/opr_impl.cpp +++ b/dnn/src/cuda/matrix_mul/opr_impl.cpp @@ -28,6 +28,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, return megdnn::get_all_algorithms(args); } +std::vector +MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) { + AlgoBase::SizeArgs args{this, A, B, C}; + return megdnn::get_all_algorithms_safe(args); +} + MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, diff --git a/dnn/src/cuda/matrix_mul/opr_impl.h b/dnn/src/cuda/matrix_mul/opr_impl.h index d4de4fc6..3f9024fa 100644 --- a/dnn/src/cuda/matrix_mul/opr_impl.h +++ b/dnn/src/cuda/matrix_mul/opr_impl.h @@ -60,6 +60,10 @@ protected: std::vector get_all_algorithms(const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) override; + + std::vector get_all_algorithms_safe(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) override; Algorithm* get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, diff --git a/dnn/src/cuda/pooling/opr_impl.cpp b/dnn/src/cuda/pooling/opr_impl.cpp index 88183a0e..c4f3d39d 100644 --- a/dnn/src/cuda/pooling/opr_impl.cpp +++ b/dnn/src/cuda/pooling/opr_impl.cpp @@ -33,6 +33,11 @@ PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, const TensorLayout& dst) { return megdnn::get_all_algorithms({this, src, dst}); } +std::vector +PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& dst) { + return megdnn::get_all_algorithms_safe({this, src, dst}); +} PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& dst, @@ -77,6 +82,15 @@ PoolingBackwardImpl::get_all_algorithms(const TensorLayout& src, {this, src, dst, diff, grad}); } +std::vector +PoolingBackwardImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& dst, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms_safe( + {this, src, dst, diff, grad}); +} + PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, const TensorLayout& grad, diff --git a/dnn/src/cuda/pooling/opr_impl.h b/dnn/src/cuda/pooling/opr_impl.h index a8c3e65f..3096290d 100644 --- a/dnn/src/cuda/pooling/opr_impl.h +++ b/dnn/src/cuda/pooling/opr_impl.h @@ -55,6 +55,8 @@ public: protected: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& dst, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, 
@@ -99,6 +101,9 @@ protected: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, const TensorLayout& grad, diff --git a/dnn/src/fallback/batched_matrix_mul/opr_impl.cpp b/dnn/src/fallback/batched_matrix_mul/opr_impl.cpp index ff32abe3..1d6f028a 100644 --- a/dnn/src/fallback/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/batched_matrix_mul/opr_impl.cpp @@ -26,6 +26,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, AlgoBase::SizeArgs args{this, A, B, C}; return megdnn::get_all_algorithms(args); } +std::vector +BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) { + AlgoBase::SizeArgs args{this, A, B, C}; + return megdnn::get_all_algorithms_safe(args); +} BatchedMatrixMulForwardImpl::Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( diff --git a/dnn/src/fallback/batched_matrix_mul/opr_impl.h b/dnn/src/fallback/batched_matrix_mul/opr_impl.h index d9fbcda5..bda805be 100644 --- a/dnn/src/fallback/batched_matrix_mul/opr_impl.h +++ b/dnn/src/fallback/batched_matrix_mul/opr_impl.h @@ -35,6 +35,9 @@ private: std::vector get_all_algorithms( const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/) override; + std::vector get_all_algorithms_safe( + const TensorLayout& /*A*/, const TensorLayout& /*B*/, + const TensorLayout& /*C*/) override; Algorithm* get_algorithm_heuristic( const TensorLayout& /*A*/, const TensorLayout& /*B*/, diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index fd41e8bf..17f006c2 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -279,11 +279,18 @@ std::vector ConvBiasImpl::get_all_algorithms( auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); auto ret = get_all_algorithms_with_ncb(fparam); if (ret.empty()) { - return naive::ConvBiasForwardImpl::get_all_algorithms(src, filter, bias, + return naive::ConvBiasForwardImpl::get_all_algorithms_safe(src, filter, bias, z, dst); } return ret; } +std::vector ConvBiasImpl::get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) { + auto ret_safe = ConvBiasImpl::get_all_algorithms(src,filter,bias,z,dst); + return ret_safe; +} ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 0e9abf60..fe5393e8 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -87,6 +87,10 @@ public: const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) override; //! 
implemented by get_algorithm_heuristic_with_ncb() Algorithm* get_algorithm_heuristic( diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 9551e3ea..36ac473b 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -198,12 +198,19 @@ std::vector ConvolutionImpl::get_all_algorithms( auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); auto ret = get_all_algorithms_with_ncb(fparam); if (ret.empty()) { - return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter, + return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter, dst); } return ret; } +std::vector ConvolutionImpl::get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) { + auto ret_safe = ConvolutionImpl::get_all_algorithms(src,filter,dst); + return ret_safe; +} + ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, size_t workspace_limit_in_bytes, @@ -536,10 +543,19 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, } auto fparam = make_ncb_kern_size_param(filter, diff, grad); auto ret = get_all_algorithms_with_ncb(fparam); - megdnn_assert(!ret.empty(), "no usable conv fwd algorithm"); return ret; } +std::vector +ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) { + + auto ret_safe = ConvolutionBackwardDataImpl::get_all_algorithms(filter,diff,grad); + megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm"); + return ret_safe; +} + ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, diff --git a/dnn/src/fallback/convolution/opr_impl.h b/dnn/src/fallback/convolution/opr_impl.h index 234671f1..9e0287d0 100644 --- a/dnn/src/fallback/convolution/opr_impl.h +++ b/dnn/src/fallback/convolution/opr_impl.h @@ -85,6 +85,10 @@ public: const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + //! 
implemented by get_algorithm_heuristic_with_ncb() Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, @@ -326,6 +330,9 @@ public: std::vector get_all_algorithms( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp index 82ef1591..706f1658 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp @@ -96,6 +96,13 @@ std::vector MatrixMulImpl::get_all_algorithms( return gemv_algos; } +std::vector MatrixMulImpl::get_all_algorithms_safe( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { + auto gemv_algos_safe = get_all_algorithms(A,B,C); + megdnn_assert(!gemv_algos_safe.empty(), "no usable MatrixMul fwd algorithm"); + return gemv_algos_safe; +} + MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( const AlgorithmDesc& desc) { if (!desc.valid()) { diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index b8a9a689..d13aede4 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -270,6 +270,10 @@ protected: const TensorLayout& B, const TensorLayout& C) override; + std::vector get_all_algorithms_safe(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) override; + Algorithm* get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, diff --git a/dnn/src/naive/batch_conv_bias/opr_impl.cpp b/dnn/src/naive/batch_conv_bias/opr_impl.cpp index 26d1bdc8..dd19fb0c 100644 --- a/dnn/src/naive/batch_conv_bias/opr_impl.cpp +++ b/dnn/src/naive/batch_conv_bias/opr_impl.cpp @@ -128,6 +128,16 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout&, ->default_batch_conv_bias_fwd_algo()}; } +std::vector +BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&, + const TensorLayout&, + const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle()) + ->default_batch_conv_bias_fwd_algo()}; +} + BatchConvBiasForward::Algorithm* BatchConvBiasForwardImpl::get_algorithm_heuristic( const TensorLayout& /* src */, const TensorLayout& /* filter */, diff --git a/dnn/src/naive/batch_conv_bias/opr_impl.h b/dnn/src/naive/batch_conv_bias/opr_impl.h index 06b2e7a4..b449a722 100644 --- a/dnn/src/naive/batch_conv_bias/opr_impl.h +++ b/dnn/src/naive/batch_conv_bias/opr_impl.h @@ -30,6 +30,11 @@ public: const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) override; + + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, diff --git a/dnn/src/naive/batched_matrix_mul/opr_impl.cpp b/dnn/src/naive/batched_matrix_mul/opr_impl.cpp index 01e06a1d..3bcda189 100644 --- a/dnn/src/naive/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/batched_matrix_mul/opr_impl.cpp 
@@ -63,7 +63,6 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, } } - std::vector BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, const TensorLayout& /*B*/, @@ -71,6 +70,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, return {static_cast(handle()) ->default_batched_matmul_fwd_algo()}; } +std::vector +BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/, + const TensorLayout& /*B*/, + const TensorLayout& /*C*/) { + return {static_cast(handle()) + ->default_batched_matmul_fwd_algo()}; +} BatchedMatrixMulForward::Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( diff --git a/dnn/src/naive/batched_matrix_mul/opr_impl.h b/dnn/src/naive/batched_matrix_mul/opr_impl.h index 36433cf0..7755ad7d 100644 --- a/dnn/src/naive/batched_matrix_mul/opr_impl.h +++ b/dnn/src/naive/batched_matrix_mul/opr_impl.h @@ -27,6 +27,9 @@ public: std::vector get_all_algorithms( const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/) override; + std::vector get_all_algorithms_safe( + const TensorLayout& /*A*/, const TensorLayout& /*B*/, + const TensorLayout& /*C*/) override; Algorithm* get_algorithm_heuristic( const TensorLayout& /*A*/, const TensorLayout& /*B*/, diff --git a/dnn/src/naive/conv_bias/opr_impl.cpp b/dnn/src/naive/conv_bias/opr_impl.cpp index 8c31744a..0d450e1b 100644 --- a/dnn/src/naive/conv_bias/opr_impl.cpp +++ b/dnn/src/naive/conv_bias/opr_impl.cpp @@ -321,6 +321,15 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout&, return {static_cast(handle())->default_conv_bias_fwd_algo()}; } +std::vector +ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&, + const TensorLayout&, + const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle())->default_conv_bias_fwd_algo()}; +} + ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( const TensorLayout& /* src */, const TensorLayout& /* filter */, const TensorLayout& /* bias */, const TensorLayout& /* z */, diff --git a/dnn/src/naive/conv_bias/opr_impl.h b/dnn/src/naive/conv_bias/opr_impl.h index f875f0cf..4624d5c7 100644 --- a/dnn/src/naive/conv_bias/opr_impl.h +++ b/dnn/src/naive/conv_bias/opr_impl.h @@ -31,6 +31,11 @@ public: const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) override; + Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, diff --git a/dnn/src/naive/convolution/convolution.cpp b/dnn/src/naive/convolution/convolution.cpp index 658e908e..e1699ae2 100644 --- a/dnn/src/naive/convolution/convolution.cpp +++ b/dnn/src/naive/convolution/convolution.cpp @@ -287,6 +287,13 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &, return {static_cast(handle())->default_conv_fwd_algo()}; } +std::vector +ConvolutionForwardImpl:: get_all_algorithms_safe(const TensorLayout &, + const TensorLayout &, const TensorLayout &) +{ + return {static_cast(handle())->default_conv_fwd_algo()}; +} + ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( const TensorLayout& /* src */, const TensorLayout& /* filter */, const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, @@ -313,6 +320,13 @@ ConvolutionBackwardDataImpl:: 
get_all_algorithms(const TensorLayout &, return {static_cast(handle())->default_conv_bwd_data_algo()}; } +std::vector +ConvolutionBackwardDataImpl:: get_all_algorithms_safe(const TensorLayout &, + const TensorLayout &, const TensorLayout &) +{ + return {static_cast(handle())->default_conv_bwd_data_algo()}; +} + ConvolutionBackwardData::Algorithm* ConvolutionBackwardDataImpl::get_algorithm_heuristic( const TensorLayout& /* filter */, const TensorLayout& /* diff */, @@ -341,6 +355,13 @@ ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, return {static_cast(handle())->default_conv_bwd_filter_algo()}; } +std::vector +ConvolutionBackwardFilterImpl:: get_all_algorithms_safe(const TensorLayout &, + const TensorLayout &, const TensorLayout &) +{ + return {static_cast(handle())->default_conv_bwd_filter_algo()}; +} + ConvolutionBackwardFilter::Algorithm* ConvolutionBackwardFilterImpl::get_algorithm_heuristic( const TensorLayout& /* src */, const TensorLayout& /* diff */, diff --git a/dnn/src/naive/convolution/opr_impl.h b/dnn/src/naive/convolution/opr_impl.h index ddf1a196..53af52ee 100644 --- a/dnn/src/naive/convolution/opr_impl.h +++ b/dnn/src/naive/convolution/opr_impl.h @@ -25,6 +25,9 @@ class ConvolutionForwardImpl: public ConvolutionForward { std::vector get_all_algorithms(const TensorLayout &src, const TensorLayout &filter, const TensorLayout &dst) override; + std::vector get_all_algorithms_safe(const TensorLayout &src, + const TensorLayout &filter, + const TensorLayout &dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, size_t workspace_limit_in_bytes, @@ -67,6 +70,9 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { std::vector get_all_algorithms(const TensorLayout &filter, const TensorLayout &diff, const TensorLayout &grad) override; + std::vector get_all_algorithms_safe(const TensorLayout &filter, + const TensorLayout &diff, + const TensorLayout &grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, @@ -90,6 +96,9 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { std::vector get_all_algorithms(const TensorLayout &src, const TensorLayout &diff, const TensorLayout &grad) override; + std::vector get_all_algorithms_safe(const TensorLayout &src, + const TensorLayout &diff, + const TensorLayout &grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, diff --git a/dnn/src/naive/convolution3d/convolution3d.cpp b/dnn/src/naive/convolution3d/convolution3d.cpp index 2aa08ebe..cd42120f 100644 --- a/dnn/src/naive/convolution3d/convolution3d.cpp +++ b/dnn/src/naive/convolution3d/convolution3d.cpp @@ -108,13 +108,18 @@ void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, megdnn_assert_internal(0); } - std::vector Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, const TensorLayout&, const TensorLayout&) { return {static_cast(handle())->default_conv3d_fwd_algo()}; } +std::vector +Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle())->default_conv3d_fwd_algo()}; +} Convolution3DForward::Algorithm* Convolution3DForwardImpl::get_algorithm_heuristic( @@ -143,6 +148,13 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const 
TensorLayout&, return {static_cast(handle())->default_conv3d_bwd_data_algo()}; } +std::vector +Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle())->default_conv3d_bwd_data_algo()}; +} + Convolution3DBackwardData::Algorithm* Convolution3DBackwardDataImpl::get_algorithm_heuristic( const TensorLayout& /* filter */, const TensorLayout& /* diff */, @@ -172,6 +184,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&, ->default_conv3d_bwd_filter_algo()}; } +std::vector +Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle()) + ->default_conv3d_bwd_filter_algo()}; +} + Convolution3DBackwardFilter::Algorithm* Convolution3DBackwardFilterImpl::get_algorithm_heuristic( const TensorLayout& /* src */, const TensorLayout& /* diff */, diff --git a/dnn/src/naive/convolution3d/opr_impl.h b/dnn/src/naive/convolution3d/opr_impl.h index 708271f9..9deca831 100644 --- a/dnn/src/naive/convolution3d/opr_impl.h +++ b/dnn/src/naive/convolution3d/opr_impl.h @@ -22,6 +22,9 @@ public: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, size_t workspace_limit_in_bytes, @@ -44,6 +47,9 @@ public: std::vector get_all_algorithms( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, @@ -66,6 +72,9 @@ public: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, diff --git a/dnn/src/naive/deformable_conv/opr_impl.h b/dnn/src/naive/deformable_conv/opr_impl.h index 7b9bd0a5..949b4a77 100644 --- a/dnn/src/naive/deformable_conv/opr_impl.h +++ b/dnn/src/naive/deformable_conv/opr_impl.h @@ -25,6 +25,12 @@ public: const TensorLayout& /* dst */) override { return std::vector(); }; + std::vector get_all_algorithms_safe( + const TensorLayout& /* im */, const TensorLayout& /* filter */, + const TensorLayout& /* offset */, const TensorLayout& /* mask */, + const TensorLayout& /* dst */) override { + return std::vector(); + }; Algorithm* get_algorithm_heuristic( const TensorLayout& /* src */, const TensorLayout& /* filter */, @@ -67,6 +73,13 @@ public: const TensorLayout& /* filter_grad */) override { return std::vector(); }; + + std::vector get_all_algorithms_safe( + const TensorLayout& /* im */, const TensorLayout& /* offset */, + const TensorLayout& /* mask */, const TensorLayout& /* out_grad */, + const TensorLayout& /* filter_grad */) override { + return std::vector(); + }; Algorithm* get_algorithm_heuristic( const TensorLayout& /* im */, const 
TensorLayout& /* offset */, @@ -112,6 +125,16 @@ public: return std::vector(); }; + std::vector get_all_algorithms_safe( + const TensorLayout& /* im */, const TensorLayout& /* filter */, + const TensorLayout& /* offset */, const TensorLayout& /* mask */, + const TensorLayout& /* out_grad */, + const TensorLayout& /* im_grad */, + const TensorLayout& /* offset_grad */, + const TensorLayout& /* mask_grad */) override { + return std::vector(); + }; + Algorithm* get_algorithm_heuristic( const TensorLayout& /* im */, const TensorLayout& /* filter */, const TensorLayout& /* offset */, const TensorLayout& /* mask */, diff --git a/dnn/src/naive/local_share/opr_impl.cpp b/dnn/src/naive/local_share/opr_impl.cpp index 75f7e7fb..518bf676 100644 --- a/dnn/src/naive/local_share/opr_impl.cpp +++ b/dnn/src/naive/local_share/opr_impl.cpp @@ -159,6 +159,13 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&, return {static_cast(handle())->default_local_share_fwd_algo()}; } +std::vector +LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle())->default_local_share_fwd_algo()}; +} + LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( const TensorLayout& /* src */, const TensorLayout& /* diff */, const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, @@ -187,6 +194,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, ->default_local_share_bwd_data_algo()}; } +std::vector +LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle()) + ->default_local_share_bwd_data_algo()}; +} + LocalShareBackwardData::Algorithm* LocalShareBackwardDataImpl::get_algorithm_heuristic( const TensorLayout& /* filter */, const TensorLayout& /* diff */, @@ -216,6 +231,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, ->default_local_share_bwd_filter_algo()}; } +std::vector +LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle()) + ->default_local_share_bwd_filter_algo()}; +} + LocalShareBackwardFilter::Algorithm* LocalShareBackwardFilterImpl::get_algorithm_heuristic( const TensorLayout& /* src */, const TensorLayout& /* diff */, diff --git a/dnn/src/naive/local_share/opr_impl.h b/dnn/src/naive/local_share/opr_impl.h index cea5b5c5..4da11449 100644 --- a/dnn/src/naive/local_share/opr_impl.h +++ b/dnn/src/naive/local_share/opr_impl.h @@ -30,6 +30,10 @@ public: const TensorLayout& /*src*/, const TensorLayout& /*filter*/, const TensorLayout& /*dst*/) override; + std::vector get_all_algorithms_safe( + const TensorLayout& /*src*/, const TensorLayout& /*filter*/, + const TensorLayout& /*dst*/) override; + Algorithm* get_algorithm_heuristic( const TensorLayout& /*src*/, const TensorLayout& /*filter*/, const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, @@ -55,6 +59,10 @@ public: const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) override; + std::vector get_all_algorithms_safe( + const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, + const TensorLayout& /*grad*/) override; + Algorithm* get_algorithm_heuristic( const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, @@ -75,11 +83,14 @@ public: const TensorLayout&) override { return 
0; } - std::vector get_all_algorithms( const TensorLayout& /*src*/, const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) override; + std::vector get_all_algorithms_safe( + const TensorLayout& /*src*/, const TensorLayout& /*diff*/, + const TensorLayout& /*grad*/) override; + Algorithm* get_algorithm_heuristic( const TensorLayout& /*src*/, const TensorLayout& /*diff*/, const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, diff --git a/dnn/src/naive/matrix_mul/opr_impl.cpp b/dnn/src/naive/matrix_mul/opr_impl.cpp index 65623dfd..505a08ac 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/matrix_mul/opr_impl.cpp @@ -88,6 +88,13 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, return {static_cast(handle())->default_matmul_fwd_algo()}; } +std::vector +MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/, + const TensorLayout& /*B*/, + const TensorLayout& /*C*/) { + return {static_cast(handle())->default_matmul_fwd_algo()}; +} + MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, diff --git a/dnn/src/naive/matrix_mul/opr_impl.h b/dnn/src/naive/matrix_mul/opr_impl.h index 022a96ab..2e9c7c72 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.h +++ b/dnn/src/naive/matrix_mul/opr_impl.h @@ -29,6 +29,10 @@ public: const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/) override; + std::vector get_all_algorithms_safe( + const TensorLayout& /*A*/, const TensorLayout& /*B*/, + const TensorLayout& /*C*/) override; + Algorithm* get_algorithm_heuristic( const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, diff --git a/dnn/src/naive/pooling/opr_impl.cpp b/dnn/src/naive/pooling/opr_impl.cpp index 2c8dd35d..ed3c11e4 100644 --- a/dnn/src/naive/pooling/opr_impl.cpp +++ b/dnn/src/naive/pooling/opr_impl.cpp @@ -603,6 +603,10 @@ std::vector PoolingForwardImpl::get_all_algorithms( const TensorLayout&, const TensorLayout&) { return {static_cast(handle())->default_pooling_fwd_algo()}; } +std::vector PoolingForwardImpl::get_all_algorithms_safe( + const TensorLayout&, const TensorLayout&) { + return {static_cast(handle())->default_pooling_fwd_algo()}; +} Algorithm* PoolingForwardImpl::get_algorithm_heuristic( const TensorLayout& /*src*/, const TensorLayout& /*dst*/, @@ -626,6 +630,11 @@ std::vector PoolingBackwardImpl::get_all_algorithms( const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { return {static_cast(handle())->default_pooling_bwd_algo()}; } +std::vector PoolingBackwardImpl::get_all_algorithms_safe( + const TensorLayout& /*src*/, const TensorLayout& /*dst*/, + const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { + return {static_cast(handle())->default_pooling_bwd_algo()}; +} Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( const TensorLayout& /*src*/, const TensorLayout& /*dst*/, diff --git a/dnn/src/naive/pooling/opr_impl.h b/dnn/src/naive/pooling/opr_impl.h index fe34fbf7..323efcce 100644 --- a/dnn/src/naive/pooling/opr_impl.h +++ b/dnn/src/naive/pooling/opr_impl.h @@ -35,6 +35,8 @@ class PoolingForwardImpl: public PoolingForward { std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& dst) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& 
src, const TensorLayout& dst, @@ -60,6 +62,9 @@ public: std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, const TensorLayout& grad) override; + std::vector get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& dst, diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp b/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp index b1a7c16c..b5ba31e7 100644 --- a/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp @@ -29,6 +29,14 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, return megdnn::get_all_algorithms(args); } +std::vector +BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, + const TensorLayout& B, + const TensorLayout& C) { + AlgoBase::SizeArgs args{this, A, B, C}; + return megdnn::get_all_algorithms_safe(args); +} + BatchedMatrixMulForwardImpl::Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.h b/dnn/src/rocm/batched_matrix_mul/opr_impl.h index 24f74b82..7d61db24 100644 --- a/dnn/src/rocm/batched_matrix_mul/opr_impl.h +++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.h @@ -35,6 +35,9 @@ private: std::vector get_all_algorithms( const TensorLayout& /*A*/, const TensorLayout& /*B*/, const TensorLayout& /*C*/) override; + std::vector get_all_algorithms_safe( + const TensorLayout& /*A*/, const TensorLayout& /*B*/, + const TensorLayout& /*C*/) override; Algorithm* get_algorithm_heuristic( const TensorLayout& /*A*/, const TensorLayout& /*B*/, diff --git a/dnn/src/rocm/convolution/opr_impl.cpp b/dnn/src/rocm/convolution/opr_impl.cpp index b7d79039..fdde9047 100644 --- a/dnn/src/rocm/convolution/opr_impl.cpp +++ b/dnn/src/rocm/convolution/opr_impl.cpp @@ -109,6 +109,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, {this, src, filter, dst}); } +std::vector +ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst) { + return megdnn::get_all_algorithms_safe( + {this, src, filter, dst}); +} + size_t ConvolutionForwardImpl::get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, const PreprocessedFilter*) { @@ -162,6 +170,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, {this, filter, diff, grad}); } +std::vector +ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms_safe( + {this, filter, diff, grad}); +} + ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::get_algorithm_heuristic( const TensorLayout& filter, const TensorLayout& diff, @@ -243,6 +259,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, {this, src, diff, grad}); } +std::vector +ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms_safe( + {this, src, diff, grad}); +} + ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl::get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& diff, diff --git 
diff --git a/dnn/src/rocm/convolution/opr_impl.h b/dnn/src/rocm/convolution/opr_impl.h
index 8aaca2c8..93281e74 100644
--- a/dnn/src/rocm/convolution/opr_impl.h
+++ b/dnn/src/rocm/convolution/opr_impl.h
@@ -74,6 +74,9 @@ private:
    std::vector get_all_algorithms(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst) override;
+    std::vector get_all_algorithms_safe(
+            const TensorLayout& src, const TensorLayout& filter,
+            const TensorLayout& dst) override;
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& src, const TensorLayout& filter,
            const TensorLayout& dst, size_t workspace_limit_in_bytes,
@@ -123,6 +126,9 @@ private:
    std::vector get_all_algorithms(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad) override;
+    std::vector get_all_algorithms_safe(
+            const TensorLayout& filter, const TensorLayout& diff,
+            const TensorLayout& grad) override;
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& filter, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_limit_in_bytes,
@@ -172,6 +178,9 @@ private:
    std::vector get_all_algorithms(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad) override;
+    std::vector get_all_algorithms_safe(
+            const TensorLayout& src, const TensorLayout& diff,
+            const TensorLayout& grad) override;
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& src, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_limit_in_bytes,
diff --git a/dnn/src/rocm/matrix_mul/opr_impl.cpp b/dnn/src/rocm/matrix_mul/opr_impl.cpp
index e423d538..caa01c14 100644
--- a/dnn/src/rocm/matrix_mul/opr_impl.cpp
+++ b/dnn/src/rocm/matrix_mul/opr_impl.cpp
@@ -27,6 +27,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A,
    return megdnn::get_all_algorithms(args);
}
+std::vector
+MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A,
+                                              const TensorLayout& B,
+                                              const TensorLayout& C) {
+    AlgoBase::SizeArgs args{this, A, B, C};
+    return megdnn::get_all_algorithms_safe(args);
+}
+
MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic(
        const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
        size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
diff --git a/dnn/src/rocm/matrix_mul/opr_impl.h b/dnn/src/rocm/matrix_mul/opr_impl.h
index 2e0e7394..c06ac3ce 100644
--- a/dnn/src/rocm/matrix_mul/opr_impl.h
+++ b/dnn/src/rocm/matrix_mul/opr_impl.h
@@ -36,6 +36,10 @@ private:
            const TensorLayout& /*A*/, const TensorLayout& /*B*/,
            const TensorLayout& /*C*/) override;
+    std::vector get_all_algorithms_safe(
+            const TensorLayout& /*A*/, const TensorLayout& /*B*/,
+            const TensorLayout& /*C*/) override;
+
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& /*A*/, const TensorLayout& /*B*/,
            const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/,
diff --git a/dnn/src/rocm/pooling/opr_impl.cpp b/dnn/src/rocm/pooling/opr_impl.cpp
index 074a1dbc..879cf3c0 100644
--- a/dnn/src/rocm/pooling/opr_impl.cpp
+++ b/dnn/src/rocm/pooling/opr_impl.cpp
@@ -25,12 +25,16 @@ size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const char* PoolingForwardImpl::get_algorithm_set_name() const {
    return "ROCM_POOLING_FORWARD";
}
-
std::vector PoolingForwardImpl::get_all_algorithms(const TensorLayout& src,
                                                    const TensorLayout& dst) {
    return megdnn::get_all_algorithms({this, src, dst});
}
+std::vector
+PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src,
+                                            const TensorLayout& dst) {
+    return megdnn::get_all_algorithms_safe({this, src, dst});
+}
PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& dst,
@@ -82,6 +86,13 @@ std::vector PoolingBackwardImpl::get_all_algorithms(
            {this, src, dst, diff, grad});
}
+std::vector PoolingBackwardImpl::get_all_algorithms_safe(
+        const TensorLayout& src, const TensorLayout& dst,
+        const TensorLayout& diff, const TensorLayout& grad) {
+    return megdnn::get_all_algorithms_safe(
+            {this, src, dst, diff, grad});
+}
+
Algorithm* PoolingBackwardImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& dst,
        const TensorLayout& diff, const TensorLayout& grad,
diff --git a/dnn/src/rocm/pooling/opr_impl.h b/dnn/src/rocm/pooling/opr_impl.h
index 57be4c20..796a41f8 100644
--- a/dnn/src/rocm/pooling/opr_impl.h
+++ b/dnn/src/rocm/pooling/opr_impl.h
@@ -46,6 +46,8 @@ class PoolingForwardImpl final: public PoolingForward {
protected:
    std::vector get_all_algorithms(
            const TensorLayout& src, const TensorLayout& dst) override;
+    std::vector get_all_algorithms_safe(
+            const TensorLayout& src, const TensorLayout& dst) override;
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
@@ -93,6 +95,9 @@ class PoolingBackwardImpl final: public PoolingBackward {
    std::vector get_all_algorithms(
            const TensorLayout& src, const TensorLayout& dst,
            const TensorLayout& diff, const TensorLayout& grad) override;
+    std::vector get_all_algorithms_safe(
+            const TensorLayout& src, const TensorLayout& dst,
+            const TensorLayout& diff, const TensorLayout& grad) override;
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& src, const TensorLayout& dst,
            const TensorLayout& diff, const TensorLayout& grad,
diff --git a/dnn/src/x86/pooling/opr_impl.cpp b/dnn/src/x86/pooling/opr_impl.cpp
index 6178999b..3e14b637 100644
--- a/dnn/src/x86/pooling/opr_impl.cpp
+++ b/dnn/src/x86/pooling/opr_impl.cpp
@@ -74,11 +74,14 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
        return fallback_worksapce;
    }
}
-
std::vector PoolingImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& dst) {
    return megdnn::get_all_algorithms({this, src, dst});
}
+std::vector PoolingImpl::get_all_algorithms_safe(
+        const TensorLayout& src, const TensorLayout& dst) {
+    return megdnn::get_all_algorithms_safe({this, src, dst});
+}
Algorithm* PoolingImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& dst,
diff --git a/dnn/src/x86/pooling/opr_impl.h b/dnn/src/x86/pooling/opr_impl.h
index 6d0a43c2..cfb15b88 100644
--- a/dnn/src/x86/pooling/opr_impl.h
+++ b/dnn/src/x86/pooling/opr_impl.h
@@ -63,6 +63,8 @@ public:
protected:
    std::vector get_all_algorithms(
            const TensorLayout& src, const TensorLayout& dst) override;
+    std::vector get_all_algorithms_safe(
+            const TensorLayout& src, const TensorLayout& dst) override;
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
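The remaining hunks update the test utilities and the graph-level chooser. They enumerate through the *_info_safe variants so that a configuration with no usable algorithm fails right at the enumeration point instead of silently running an empty loop. A condensed sketch of the usual test-side loop (identifiers are illustrative; the real loops live in benchmarker.h, checker.h and convolution.cpp below):

    // Illustrative only: pin each available algorithm in turn and run the opr.
    for (auto&& info : opr->get_all_algorithms_info_safe(src, dst)) {
        opr->execution_policy().algo = info.desc;  // select this algorithm
        run_and_check(opr);                        // hypothetical helper
    }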
diff --git a/dnn/test/common/accuracy_shake_checker.h b/dnn/test/common/accuracy_shake_checker.h
index b5429187..840cb2c7 100644
--- a/dnn/test/common/accuracy_shake_checker.h
+++ b/dnn/test/common/accuracy_shake_checker.h
@@ -164,7 +164,7 @@ public:
    }
    std::vector ret;
    megdnn_assert(layouts.size() == OprTrait::arity);
-    auto vec = AlgoProxy::arity>::get_all_algorithms_info(
+    auto vec = AlgoProxy::arity>::get_all_algorithms_info_safe(
            opr, layouts);
    for (auto algo_info : vec) {
        if (!(algo_info.attribute &
diff --git a/dnn/test/common/benchmarker.h b/dnn/test/common/benchmarker.h
index 3a706526..9c3f2639 100644
--- a/dnn/test/common/benchmarker.h
+++ b/dnn/test/common/benchmarker.h
@@ -377,7 +377,7 @@ float algo_benchmark(Benchmarker& benchmark, TensorLayoutArray layouts,
    auto opr = benchmark.opr();
    opr->param() = benchmark.param();
    proxy.deduce_layout(opr, layouts);
-    auto algos = OprAlgoProxy::get_all_algorithms_info(opr, layouts);
+    auto algos = OprAlgoProxy::get_all_algorithms_info_safe(opr, layouts);
    float min_used = std::numeric_limits::max();
    bool execed = false;
    for (auto i : algos) {
diff --git a/dnn/test/common/checker.h b/dnn/test/common/checker.h
index 48337100..f7413135 100644
--- a/dnn/test/common/checker.h
+++ b/dnn/test/common/checker.h
@@ -514,7 +514,7 @@ struct ExecutionPolicyAlgoName {
 * \brief a callable to check that given algorithm is used for heuristic
 * \param require_algo if its value is true, then requires
 * get_algorithm_heuristic() to return the expected algo; otherwise the
- * expected algo must exist in get_all_algorithms() and it would be set to
+ * expected algo must exist in get_all_algorithms_safe() and it would be set to
 * be used
 */
template >
@@ -536,7 +536,7 @@ public:
        opr->param() = Algorithm::deserialize_read_pod(param);
        for (auto algo_info :
-             AlgoProxy::arity>::get_all_algorithms_info(
+             AlgoProxy::arity>::get_all_algorithms_info_safe(
                     opr.get(), layouts)) {
            if (std::regex_match(
                        algo_info.desc.name,
diff --git a/dnn/test/common/convolution.cpp b/dnn/test/common/convolution.cpp
index 8b228112..29739961 100644
--- a/dnn/test/common/convolution.cpp
+++ b/dnn/test/common/convolution.cpp
@@ -695,7 +695,7 @@ Checker checker(handle);
    float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW);
    UniformFloatRNG rng(scale, 2 * scale);
    checker.set_rng(0, &rng).set_rng(1, &rng);
-    for (auto algo : opr->get_all_algorithms_info(ily, fly, oly)) {
+    for (auto algo : opr->get_all_algorithms_info_safe(ily, fly, oly)) {
        used_algos.insert(algo.desc);
        opr->execution_policy().algo = algo.desc;
@@ -720,7 +720,7 @@ Checker checker(handle);
    opr->param() = param;
    std::string param_str;
    Algorithm::serialize_write_pod(opr->param(), param_str);
-    for (auto algo : opr->get_all_algorithms_info(fly, oly, ily)) {
+    for (auto algo : opr->get_all_algorithms_info_safe(fly, oly, ily)) {
        used_algos_bwd_data.insert(algo.desc);
        opr->execution_policy().algo = algo.desc;
        construct_sub_execution_policy_heuristic<
@@ -747,7 +747,7 @@ Checker checker(handle);
    opr->param() = param;
    std::string param_str;
    Algorithm::serialize_write_pod(opr->param(), param_str);
-    for (auto algo : opr->get_all_algorithms_info(ily, oly, fly)) {
+    for (auto algo : opr->get_all_algorithms_info_safe(ily, oly, fly)) {
        used_algos_bwd_flt.insert(algo.desc);
        opr->execution_policy().algo = algo.desc;
        construct_sub_execution_policy_heuristic<
diff --git a/dnn/test/common/opr_algo_proxy.h b/dnn/test/common/opr_algo_proxy.h
index 3ee9f746..962905e3 100644
--- a/dnn/test/common/opr_algo_proxy.h
+++ b/dnn/test/common/opr_algo_proxy.h
@@ -25,9 +25,9 @@ struct AlgoProxy;
template                                                                      \
struct AlgoProxy {                                                            \
    static std::vector                                                        \
-    get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) {    \
+    get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \
        megdnn_assert(layouts.size() == arity);                               \
-        return opr->get_all_algorithms_info(LAYOUTS);                        \
+        return opr->get_all_algorithms_info_safe(LAYOUTS);                   \
    }                                                                         \
    static typename Opr::AlgorithmInfo get_algorithm_info_heuristic(          \
            Opr* opr, const TensorLayoutArray& layouts) {                     \
@@ -80,9 +80,9 @@ DEF_ALGO_PROXY(8);
template <>                                                                   \
struct AlgoProxy {                                                            \
    static std::vector                                                        \
-    get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) {    \
+    get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \
        megdnn_assert(layouts.size() == arity);                               \
-        return opr->get_all_algorithms_info(LAYOUTS);                        \
+        return opr->get_all_algorithms_info_safe(LAYOUTS);                   \
    }                                                                         \
    static typename Opr::AlgorithmInfo get_algorithm_info_heuristic(          \
            Opr* opr, const TensorLayoutArray& layouts) {                     \
diff --git a/dnn/test/common/opr_proxy.h b/dnn/test/common/opr_proxy.h
index 81fc5dbe..c822a373 100644
--- a/dnn/test/common/opr_proxy.h
+++ b/dnn/test/common/opr_proxy.h
@@ -288,7 +288,7 @@ struct OprProxyProfilingBase
                Algorithm::deserialize_read_pod(param);
        std::vector ret;
-        for (auto algo_info : AlgoProxy::get_all_algorithms_info(
+        for (auto algo_info : AlgoProxy::get_all_algorithms_info_safe(
                     opr.get(), layouts)) {
            Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc);
            std::vector&& sub_items =
@@ -367,7 +367,7 @@ struct OprProxyProfilingBase
                megdnn_log("Find best algo %s in cache", algo->name());
                return;
            }
-        for (auto algo : AlgoProxy::get_all_algorithms_info(
+        for (auto algo : AlgoProxy::get_all_algorithms_info_safe(
                     opr.get(), layouts)) {
            //! construct execution_policy
            opr->execution_policy().algo = algo.desc;
@@ -492,7 +492,7 @@ struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase {
        if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) {
            size_t min_time = std::numeric_limits::max();
            for (auto algo :
-                 AlgoProxy::get_all_algorithms_info(opr, layouts)) {
+                 AlgoProxy::get_all_algorithms_info_safe(opr, layouts)) {
                opr->execution_policy().algo = algo.desc;
                auto preprocess_tensors =
diff --git a/dnn/test/cuda/cutlass_matmul.cpp b/dnn/test/cuda/cutlass_matmul.cpp
index 913a6ec6..3a41e2aa 100644
--- a/dnn/test/cuda/cutlass_matmul.cpp
+++ b/dnn/test/cuda/cutlass_matmul.cpp
@@ -84,7 +84,7 @@ void test_multibatchsize(
    auto opr_reference = handle_cuda->create_operator();
    {
        opr_reference->execution_policy().algo.reset();
-        for (auto i : opr_reference->get_all_algorithms_info(
+        for (auto i : opr_reference->get_all_algorithms_info_safe(
                     A_tensor.layout(), B_tensor.layout(),
                     C_tensor.layout())) {
            if (std::regex_match(
@@ -113,7 +113,7 @@ void test_multibatchsize(
            {{}, {}, C_tensor_prime.tensornd_host()});
    {
        opr_reference->execution_policy().algo.reset();
-        for (auto i : opr_reference->get_all_algorithms_info(
+        for (auto i : opr_reference->get_all_algorithms_info_safe(
                     A_tensor_prime.layout(), B_tensor.layout(),
                     C_tensor_batch.layout())) {
            if (std::regex_match(
diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp
index e0e0810f..a4cbb4ad 100644
--- a/src/core/test/graph/misc.cpp
+++ b/src/core/test/graph/misc.cpp
@@ -1938,7 +1938,7 @@ typename megdnn::ExecutionPolicy try_find_any_weight_preprocess_algo(
            return {};
        }
    }
-    for (auto&& algo : dnn_op->get_all_algorithms_info(
+    for (auto&& algo : dnn_op->get_all_algorithms_info_safe(
                 std::forward(args)...)) {
        dnn_op->execution_policy().algo = algo.desc;
        auto layouts = dnn_op->deduce_preprocessed_filter_layout(
@@ -1972,7 +1972,7 @@ typename megdnn::ExecutionPolicy try_find_any_bias_preprocess_algo(
            return {};
        }
    }
-    for (auto&& algo : dnn_op->get_all_algorithms_info(
+    for (auto&& algo : dnn_op->get_all_algorithms_info_safe(
                 std::forward(args)...)) {
        dnn_op->execution_policy().algo = algo.desc;
        auto layouts = dnn_op->deduce_preprocessed_filter_layout(
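The last two files cover the places where the guarantee is consumed outside the DNN tests: AlgoChooser::get_all_candidates() (FastRun profiling) switches to the safe enumeration, and the mocked convolution operator has to stub the new methods because they are pure virtual on MultiAlgoOpr. Any other subclass or test double needs the same treatment; a trivial stub (hypothetical, not the GMock operator shown below) could be:

    // Hypothetical stub: satisfy the new pure-virtual by reusing the unsafe
    // list, which this particular stub guarantees to be non-empty by construction.
    std::vector<Algorithm*> StubMatrixMul::get_all_algorithms_safe(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
        return get_all_algorithms(A, B, C);
    }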
diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp
index ffefee25..8f1306eb 100644
--- a/src/opr/impl/search_policy/algo_chooser.cpp
+++ b/src/opr/impl/search_policy/algo_chooser.cpp
@@ -805,7 +805,7 @@ std::vector::ImplAlgo>
AlgoChooser::AlgoChooserHelper::get_all_candidates() const {
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates")))
    auto heu = choose_by_heuristic(m_execution_policy.strategy);
-    auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...),
+    auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info_safe(args...),
                       m_fastrun_layouts);
    bool found = false;
    for (size_t i = 0; i < ret.size(); ++i) {
diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp
index 1b7c8ac3..009c3d8e 100644
--- a/src/opr/test/dnn/convolution.cpp
+++ b/src/opr/test/dnn/convolution.cpp
@@ -2473,6 +2473,11 @@ public:
                 std::vector(const TensorLayout& p0,
                             const TensorLayout& p1,
                             const TensorLayout& p2));
+
+    MOCK_METHOD3(get_all_algorithms_info_safe,
+                 std::vector(const TensorLayout& p0,
+                             const TensorLayout& p1,
+                             const TensorLayout& p2));
    MOCK_METHOD6(get_algorithm_info_heuristic,
                 AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1,
                               const TensorLayout& p2,
@@ -2484,6 +2489,11 @@ public:
                 std::vector(const TensorLayout& p0,
                             const TensorLayout& p1,
                             const TensorLayout& p2));
+
+    MOCK_METHOD3(get_all_algorithms_safe,
+                 std::vector(const TensorLayout& p0,
+                             const TensorLayout& p1,
+                             const TensorLayout& p2));
    MOCK_METHOD6(get_algorithm_heuristic,
                 Algorithm*(const TensorLayout& p0, const TensorLayout& p1,
                            const TensorLayout& p2,