GitOrigin-RevId: 6610516650
tags/v1.6.0-rc1
@@ -24,5 +24,24 @@
 #define setenv(name,value,overwrite) _putenv_s(name,value)
 #endif
+namespace megdnn {
+/*!
+ * \brief whether there is an algorithm from algo_pack() that is available for
+ * the current size
+ */
+template <class Opr, typename... Args>
+bool has_available_algo(Opr* opr, Args&&... args) {
+    const typename Opr::AlgoBase::SizeArgs size_args(
+            opr, std::forward<Args>(args)...);
+    for (auto i : Opr::algo_pack().all_algos) {
+        if (i->is_available(size_args)) {
+            return true;
+        }
+    }
+    return false;
+}
+}
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
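The new helper builds SizeArgs inside the function by perfect-forwarding whatever trailing arguments the caller passes, so call sites no longer spell out the aggregate themselves. Below is a minimal, self-contained sketch of that forwarding pattern; FakeOpr, BatchOneAlgo, and the int-valued size fields are hypothetical stand-ins for the real MegDNN operator, algorithm, and layout types, not code from this change.

#include <utility>
#include <vector>

// Hypothetical stand-ins for an operator, its algorithm base class, and the
// size-argument struct that the real MegDNN operator classes provide.
struct FakeOpr {
    struct AlgoBase {
        struct SizeArgs {
            FakeOpr* opr;
            int batch, channels;
            SizeArgs(FakeOpr* o, int b, int c) : opr(o), batch(b), channels(c) {}
        };
        virtual bool is_available(const SizeArgs& args) const = 0;
        virtual ~AlgoBase() = default;
    };
    struct AlgoPack {
        std::vector<AlgoBase*> all_algos;
    };
    static AlgoPack& algo_pack() {
        static AlgoPack pack;
        return pack;
    }
};

// Same shape as the helper added above: construct SizeArgs once from the
// forwarded arguments, then ask every registered algorithm whether it can run.
template <class Opr, typename... Args>
bool has_available_algo(Opr* opr, Args&&... args) {
    const typename Opr::AlgoBase::SizeArgs size_args(
            opr, std::forward<Args>(args)...);
    for (auto i : Opr::algo_pack().all_algos) {
        if (i->is_available(size_args)) {
            return true;
        }
    }
    return false;
}

// One concrete algorithm that only handles batch size 1.
struct BatchOneAlgo final : FakeOpr::AlgoBase {
    bool is_available(const SizeArgs& args) const override {
        return args.batch == 1;
    }
};

int main() {
    static BatchOneAlgo algo;
    FakeOpr::algo_pack().all_algos.push_back(&algo);
    FakeOpr opr;
    // Equivalent to constructing SizeArgs{&opr, 1, 64} at the call site.
    return has_available_algo(&opr, 1, 64) ? 0 : 1;
}

Because the arguments are forwarded straight into the SizeArgs constructor, passing the wrong number or kind of arguments at a call site fails at compile time instead of depending on a manually kept-in-sync brace initializer.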
@@ -16,6 +16,7 @@
 #include <utility>
 #include <vector>
+#include "megdnn/common.h"
 #include "utils.h"
 namespace megdnn {
@@ -75,21 +76,6 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms(
 }
 /*!
- * \brief whether there is an algorithm from algo_pack() that is available for
- * current size
- */
-template <class Opr>
-bool has_available_algo(
-        const typename Opr::AlgoBase::SizeArgs& args) {
-    for (auto i : Opr::algo_pack().all_algos) {
-        if (i->is_available(args)) {
-            return true;
-        }
-    }
-    return false;
-}
-/*!
  * \brief a helper function to get an algorithm match attribute. If require a
  * algorithm with specified attribute, and the given algorithm match that
  * attribute, return the given algorithm. Otherwise return nullptr
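The comment above describes get_algo_match_attribute, which hands back the given algorithm only when it carries a required attribute. Its definition is not part of this diff; the snippet below is only a rough, hypothetical sketch of that contract, with an assumed Attribute bitmask and Algo::contains query standing in for the real MegDNN Algorithm/Attribute types.

#include <cstdint>

// Assumed attribute bitmask and algorithm type; the real definitions are not
// shown in this diff.
enum class Attribute : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0 };

struct Algo {
    Attribute attrs = Attribute::DEFAULT;
    bool contains(Attribute required) const {
        return (static_cast<uint32_t>(attrs) & static_cast<uint32_t>(required)) ==
               static_cast<uint32_t>(required);
    }
};

// Contract described by the comment: return the algorithm if it matches the
// required attribute, otherwise nullptr.
Algo* get_algo_match_attribute(Algo* algo, Attribute required) {
    return (algo != nullptr && algo->contains(required)) ? algo : nullptr;
}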
@@ -134,13 +134,6 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<ConvBiasForwardImpl*>(config.second.get()),
-            config.first[0],
-            config.first[1],
-            config.first[2],
-            config.first[3],
-            config.first[4]};
     bool is_relayout_ok = true;
     if (args.dst_layout->dtype.enumv() != DTypeEnum::Float32) {
         is_relayout_ok = relayout_format::RelayoutFormatFast::usable(
@@ -148,7 +141,11 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
                 RelayoutFormat::Param::Mode::NCHW4_NCHW);
     }
-    return is_relayout_ok && has_available_algo<ConvBiasForwardImpl>(sub_args);
+    return is_relayout_ok &&
+           has_available_algo<ConvBiasForwardImpl>(
+                   static_cast<ConvBiasForwardImpl*>(config.second.get()),
+                   config.first[0], config.first[1], config.first[2],
+                   config.first[3], config.first[4]);
 }
 WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle(
@@ -107,16 +107,12 @@ bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
     auto conv_args = args;
     conv_args.dst_layout = &dst_layout;
    auto config = prepare_sub_opr(conv_args);
-    AlgoBase::SizeArgs sub_args{
+    bool ret = has_available_algo<ConvBiasForwardImpl>(
             static_cast<ConvBiasForwardImpl*>(config.second.get()),
-            config.first[0],
-            config.first[1],
-            config.first[2],
-            config.first[3],
-            config.first[4]};
-    bool ret = has_available_algo<ConvBiasForwardImpl>(sub_args);
-    return ret;
+            config.first[0], config.first[1], config.first[2], config.first[3],
+            config.first[4]);
+    return ret;
 }
 WorkspaceBundle ConvBiasForwardImpl::AlgoGroupConvGeneral::get_workspace_bundle(
@@ -82,11 +82,10 @@ bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available(
     }
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<ConvolutionBackwardDataImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<ConvolutionBackwardDataImpl>(sub_args);
+    return has_available_algo<ConvolutionBackwardDataImpl>(
+            static_cast<ConvolutionBackwardDataImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 WorkspaceBundle
@@ -78,11 +78,10 @@ bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
     }
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<ConvolutionBackwardFilterImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<ConvolutionBackwardFilterImpl>(sub_args);
+    return has_available_algo<ConvolutionBackwardFilterImpl>(
+            static_cast<ConvolutionBackwardFilterImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 WorkspaceBundle
@@ -73,11 +73,10 @@ bool Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::is_available(
     }
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<Convolution3DBackwardDataImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<Convolution3DBackwardDataImpl>(sub_args);
+    return has_available_algo<Convolution3DBackwardDataImpl>(
+            static_cast<Convolution3DBackwardDataImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 WorkspaceBundle
@@ -77,11 +77,10 @@ bool Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
     }
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<Convolution3DBackwardFilterImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<Convolution3DBackwardFilterImpl>(sub_args);
+    return has_available_algo<Convolution3DBackwardFilterImpl>(
+            static_cast<Convolution3DBackwardFilterImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 WorkspaceBundle
@@ -80,11 +80,10 @@ bool Convolution3DForwardImpl::AlgoGroupConvGeneral::is_available(
     }
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
-            static_cast<Convolution3DForwardImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-    return has_available_algo<Convolution3DForwardImpl>(sub_args);
+    return has_available_algo<Convolution3DForwardImpl>(
+            static_cast<Convolution3DForwardImpl*>(config.second.get()),
+            config.first[0], config.first[1], config.first[2]);
 }
 WorkspaceBundle