
refactor(mgb): refactor has-usable-algo function for global optimizer

GitOrigin-RevId: 6610516650

Megvii Engine Team, 3 years ago
tags/v1.6.0-rc1
parent commit: d195fdec71

9 changed files with 45 additions and 52 deletions
  1. dnn/include/megdnn/common.h                                 +19   -0
  2. dnn/src/common/algo_chooser.h                                 +1  -15
  3. dnn/src/cuda/conv_bias/conv_nchwqs8.cpp                       +5   -8
  4. dnn/src/cuda/conv_bias/group_conv.cpp                         +5   -9
  5. dnn/src/cuda/convolution/backward_data/group_conv.cpp         +3   -4
  6. dnn/src/cuda/convolution/backward_filter/group_conv.cpp       +3   -4
  7. dnn/src/cuda/convolution3d/backward_data/group_conv.cpp       +3   -4
  8. dnn/src/cuda/convolution3d/backward_filter/group_conv.cpp     +3   -4
  9. dnn/src/cuda/convolution3d/forward/group_conv.cpp             +3   -4

dnn/include/megdnn/common.h    (+19, -0)

@@ -24,5 +24,24 @@
 #define setenv(name,value,overwrite) _putenv_s(name,value)
 #endif
 
+namespace megdnn {
+
+/*!
+ * \brief whether there is an algorithm from algo_pack() that is available for
+ * current size
+ */
+template <class Opr, typename... Args>
+bool has_available_algo(Opr* opr, Args&&... args) {
+    const typename Opr::AlgoBase::SizeArgs size_args(
+            opr, std::forward<Args>(args)...);
+    for (auto i : Opr::algo_pack().all_algos) {
+        if (i->is_available(size_args)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+}
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
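The helper added above takes the operator pointer together with the raw layout arguments and builds the SizeArgs itself, which is what lets the call sites below drop their hand-written sub_args objects; because the arguments are a forwarded parameter pack, the same template serves a conv-bias sub-operator (five layouts) and a plain convolution sub-operator (three layouts). A minimal, self-contained sketch of that pattern follows. ToyOpr, SmallBatchAlgo and the integer batch argument are invented for illustration and are not MegDNN types; only the shape of has_available_algo mirrors the code in this commit.

    #include <utility>
    #include <vector>

    // Toy operator exposing the interface the template expects: a nested
    // AlgoBase with SizeArgs and is_available(), plus a static algo_pack()
    // whose all_algos lists the candidate algorithms.
    struct ToyOpr {
        struct AlgoBase {
            struct SizeArgs {
                int batch;
                SizeArgs(ToyOpr*, int b) : batch(b) {}
            };
            virtual ~AlgoBase() = default;
            virtual bool is_available(const SizeArgs& args) const = 0;
        };

        // An algorithm that only handles small batches.
        struct SmallBatchAlgo final : AlgoBase {
            bool is_available(const SizeArgs& args) const override {
                return args.batch <= 4;
            }
        };

        struct Pack {
            std::vector<AlgoBase*> all_algos;
        };

        static Pack& algo_pack() {
            static SmallBatchAlgo small;
            static Pack pack{{&small}};
            return pack;
        }
    };

    // Same shape as the helper added in dnn/include/megdnn/common.h:
    // build SizeArgs once from the forwarded arguments, then probe every
    // algorithm in the pack.
    template <class Opr, typename... Args>
    bool has_available_algo(Opr* opr, Args&&... args) {
        const typename Opr::AlgoBase::SizeArgs size_args(
                opr, std::forward<Args>(args)...);
        for (auto i : Opr::algo_pack().all_algos) {
            if (i->is_available(size_args)) {
                return true;
            }
        }
        return false;
    }

    int main() {
        ToyOpr opr;
        bool small_ok = has_available_algo(&opr, 2);   // true: batch 2 fits
        bool large_ok = has_available_algo(&opr, 16);  // false: no algo accepts it
        return (small_ok && !large_ok) ? 0 : 1;
    }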

dnn/src/common/algo_chooser.h    (+1, -15)

@@ -16,6 +16,7 @@
 #include <utility>
 #include <vector>
 
+#include "megdnn/common.h"
 #include "utils.h"
 
 namespace megdnn {
@@ -75,21 +76,6 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms(
 }
 
 /*!
- * \brief whether there is an algorithm from algo_pack() that is available for
- * current size
- */
-template <class Opr>
-bool has_available_algo(
-        const typename Opr::AlgoBase::SizeArgs& args) {
-    for (auto i : Opr::algo_pack().all_algos) {
-        if (i->is_available(args)) {
-            return true;
-        }
-    }
-    return false;
-}
-
-/*!
  * \brief a helper function to get an algorithm match attribute. If require a
  * algorithm with specified attribute, and the given algorithm match that
  * attribute, return the given algorithm. Otherwise return nullptr


dnn/src/cuda/conv_bias/conv_nchwqs8.cpp    (+5, -8)

@@ -134,13 +134,6 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
 
     auto config = prepare_sub_opr(args);
 
-    AlgoBase::SizeArgs sub_args{
-            static_cast<ConvBiasForwardImpl*>(config.second.get()),
-            config.first[0],
-            config.first[1],
-            config.first[2],
-            config.first[3],
-            config.first[4]};
     bool is_relayout_ok = true;
     if (args.dst_layout->dtype.enumv() != DTypeEnum::Float32) {
         is_relayout_ok = relayout_format::RelayoutFormatFast::usable(
@@ -148,7 +141,11 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
                 RelayoutFormat::Param::Mode::NCHW4_NCHW);
     }
 
-    return is_relayout_ok && has_available_algo<ConvBiasForwardImpl>(sub_args);
+    return is_relayout_ok &&
+           has_available_algo<ConvBiasForwardImpl>(
+                   static_cast<ConvBiasForwardImpl*>(config.second.get()),
+                   config.first[0], config.first[1], config.first[2],
+                   config.first[3], config.first[4]);
 }
 
 WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle(


dnn/src/cuda/conv_bias/group_conv.cpp    (+5, -9)

@@ -107,16 +107,12 @@ bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available(
     auto conv_args = args;
     conv_args.dst_layout = &dst_layout;
     auto config = prepare_sub_opr(conv_args);
-    AlgoBase::SizeArgs sub_args{
+
+    bool ret = has_available_algo<ConvBiasForwardImpl>(
             static_cast<ConvBiasForwardImpl*>(config.second.get()),
-            config.first[0],
-            config.first[1],
-            config.first[2],
-            config.first[3],
-            config.first[4]};
-
-    bool ret = has_available_algo<ConvBiasForwardImpl>(sub_args);
-    return ret;
+            config.first[0], config.first[1], config.first[2], config.first[3],
+            config.first[4]);
+    return ret;
 }
 
 WorkspaceBundle ConvBiasForwardImpl::AlgoGroupConvGeneral::get_workspace_bundle(


dnn/src/cuda/convolution/backward_data/group_conv.cpp    (+3, -4)

@@ -82,11 +82,10 @@ bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
+
+    return has_available_algo<ConvolutionBackwardDataImpl>(
             static_cast<ConvolutionBackwardDataImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-
-    return has_available_algo<ConvolutionBackwardDataImpl>(sub_args);
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle


dnn/src/cuda/convolution/backward_filter/group_conv.cpp    (+3, -4)

@@ -78,11 +78,10 @@ bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
+
+    return has_available_algo<ConvolutionBackwardFilterImpl>(
             static_cast<ConvolutionBackwardFilterImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-
-    return has_available_algo<ConvolutionBackwardFilterImpl>(sub_args);
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle


dnn/src/cuda/convolution3d/backward_data/group_conv.cpp    (+3, -4)

@@ -73,11 +73,10 @@ bool Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
+
+    return has_available_algo<Convolution3DBackwardDataImpl>(
             static_cast<Convolution3DBackwardDataImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-
-    return has_available_algo<Convolution3DBackwardDataImpl>(sub_args);
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle


dnn/src/cuda/convolution3d/backward_filter/group_conv.cpp    (+3, -4)

@@ -77,11 +77,10 @@ bool Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
+
+    return has_available_algo<Convolution3DBackwardFilterImpl>(
             static_cast<Convolution3DBackwardFilterImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-
-    return has_available_algo<Convolution3DBackwardFilterImpl>(sub_args);
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle


dnn/src/cuda/convolution3d/forward/group_conv.cpp    (+3, -4)

@@ -80,11 +80,10 @@ bool Convolution3DForwardImpl::AlgoGroupConvGeneral::is_available(
     }
 
     auto config = prepare_sub_opr(args);
-    AlgoBase::SizeArgs sub_args{
+
+    return has_available_algo<Convolution3DForwardImpl>(
             static_cast<Convolution3DForwardImpl*>(config.second.get()),
-            config.first[0], config.first[1], config.first[2]};
-
-    return has_available_algo<Convolution3DForwardImpl>(sub_args);
+            config.first[0], config.first[1], config.first[2]);
 }
 
 WorkspaceBundle

