Browse Source

fix(mgb): replace if_constexpr with runtime function to avoid potential

bug

GitOrigin-RevId: 27fe093d50
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
88898e63a5
2 changed files with 69 additions and 30 deletions
  1. +60
    -30
      src/opr/impl/search_policy/profiler.cpp
  2. +9
    -0
      src/opr/include/megbrain/opr/search_policy/profiler.h

+ 60
- 30
src/opr/impl/search_policy/profiler.cpp View File

@@ -155,6 +155,64 @@ double TimedProfiler<Opr>::init_timeout_setting() {
std::tuple_cat(__VA_ARGS__))

template <typename Opr>
void TimedProfiler<Opr>::preprocess(const TensorLayoutArray&,
const megdnn::SmallVector<DeviceTensorND>&,
intl::UniqPtrWithCN<Opr>&,
megdnn::Workspace&,
std::array<TensorLayout, arity>&,
std::array<DeviceTensorND, arity_in>&,
PreprocessFilter<Opr>&) {
// Opr is neither convbias nor convolution.This function do nothing.
}

//! convbias
template <>
void TimedProfiler<megdnn::ConvBias>::preprocess(
const TensorLayoutArray& preprocessed_layout,
const SmallVector<DeviceTensorND>& flt_val,
intl::UniqPtrWithCN<megdnn::ConvBias>& megdnn_opr,
megdnn::Workspace& mdn_workspace,
std::array<TensorLayout, arity>& layouts,
std::array<DeviceTensorND, arity_in>& inp_val,
PreprocessFilter<megdnn::ConvBias>& prep_flt) {
if (!preprocessed_layout.empty()) {
auto&& pf = prep_flt;
pf.algorithm_id = nullptr;
pf.tensors.resize(flt_val.size());
for (size_t i = 0; i < flt_val.size(); i++) {
pf.tensors[i] = flt_val[i].as_megdnn();
}
APPLY(megdnn_opr->exec_preprocess(args..., &pf, mdn_workspace),
std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn(),
inp_val[2].as_megdnn()),
array_skip<arity_in - 1>(layouts));
}
}

//! convolution
template <>
void TimedProfiler<megdnn::ConvolutionForward>::preprocess(
const TensorLayoutArray& preprocessed_layout,
const megdnn::SmallVector<DeviceTensorND>& flt_val,
intl::UniqPtrWithCN<megdnn::ConvolutionForward>& megdnn_opr,
megdnn::Workspace& mdn_workspace,
std::array<TensorLayout, arity>& layouts,
std::array<DeviceTensorND, arity_in>& inp_val,
PreprocessFilter<megdnn::ConvolutionForward>& prep_flt) {
if (!preprocessed_layout.empty()) {
auto&& pf = prep_flt;
pf.algorithm_id = nullptr;
pf.tensors.resize(flt_val.size());
for (size_t i = 0; i < flt_val.size(); i++) {
pf.tensors[i] = flt_val[i].as_megdnn();
}
APPLY(megdnn_opr->exec_preprocess(args..., &pf, mdn_workspace),
std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn()),
array_skip<2>(layouts));
}
}

template <typename Opr>
typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
const TParam& raw_param) {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_impl")))
@@ -258,36 +316,8 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
}

PreprocessFilter<Opr> prep_flt;
if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
if (!preprocessed_layout.empty()) {
auto&& pf = _(prep_flt);
pf.algorithm_id = nullptr;
pf.tensors.resize(flt_val.size());
for (size_t i = 0; i < flt_val.size(); i++) {
pf.tensors[i] = flt_val[i].as_megdnn();
}
if_constexpr<opr_contain_bias<Opr>()>(
//! convbias
[&](auto __) {
APPLY(__(megdnn_opr)
->exec_preprocess(args..., &pf,
mdn_workspace),
std::forward_as_tuple(layouts[0],
inp_val[1].as_megdnn(),
inp_val[2].as_megdnn()),
array_skip<arity_in - 1>(layouts));
},
//! Convolution
[&](auto __) {
APPLY(__(megdnn_opr)
->exec_preprocess(args..., &pf,
mdn_workspace),
std::forward_as_tuple(layouts[0],
inp_val[1].as_megdnn()),
array_skip<arity_in>(layouts));
});
}
});
preprocess(preprocessed_layout, flt_val, megdnn_opr, mdn_workspace, layouts,
inp_val, prep_flt);

RealTimer timer;
auto ev_start = cn.create_event(CompNode::Event::NEED_TIMER),


+ 9
- 0
src/opr/include/megbrain/opr/search_policy/profiler.h View File

@@ -16,6 +16,8 @@
#include "megbrain/utils/timer.h"
#include "megbrain/system.h"
#include "megbrain/comp_node.h"
#include "megbrain/tensor.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "megdnn/basic_types.h"
#include "megdnn/oprs.h"
@@ -149,6 +151,13 @@ private:
static const double timeout_setting;

static double init_timeout_setting();
static void preprocess(const megdnn::TensorLayoutArray& preprocessed_layout,
const SmallVector<DeviceTensorND>& flt_val,
intl::UniqPtrWithCN<Opr>& megdnn_opr,
megdnn::Workspace& mdn_workspace,
std::array<TensorLayout, arity>& layouts,
std::array<DeviceTensorND, arity_in>& inp_val,
PreprocessFilter<Opr>& prep_flt);
static TResult prof_impl(const TParam& raw_param);
static void prof_init_device(const TParam& raw_param);
};


Loading…
Cancel
Save