Browse Source

fix(mgb): fix fastrun compnode

GitOrigin-RevId: 8db93facb9
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
5d9ac970ab
3 changed files with 7 additions and 4 deletions
  1. +2
    -1
      src/opr/impl/search_policy/algo_chooser.cpp
  2. +4
    -2
      src/opr/impl/search_policy/profiler.cpp
  3. +1
    -1
      src/opr/include/megbrain/opr/search_policy/profiler.h

+ 2
- 1
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -853,7 +853,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo(
src.to_string().c_str());
param.dtypes[i] = src.dtype.enumv();
}
param.comp_node_loc = m_cn.locator();
param.comp_node_physical = m_cn.locator();
param.comp_node_logical = m_cn.locator_logical();
mgb_assert(param.shapes.size() == m_fastrun_layouts.size());
for (size_t i = 0; i < param.shapes.size(); ++i)
param.shapes[i] = m_fastrun_layouts[i];


+ 4
- 2
src/opr/impl/search_policy/profiler.cpp View File

@@ -222,7 +222,8 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
mgb_assert(miopen_algo_search_enabled, "MIOpen algo search not enabled");
#endif
auto&& param = raw_param.as_single_pod<Param>();
CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
CompNode cn =
CompNode::load(param.comp_node_physical, param.comp_node_logical);
auto megdnn_opr = intl::create_megdnn_opr<Opr>(cn);
std::array<TensorLayout, arity> layouts;

@@ -395,7 +396,8 @@ void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) {
megcore::enableMIOpenAlgoSearch(true);
#endif
auto&& param = raw_param.as_single_pod<Param>();
CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
CompNode cn =
CompNode::load(param.comp_node_physical, param.comp_node_logical);
// wait for cuda init, so its time does not get accounted in timeout
cn.sync();
MIDOUT_E


+ 1
- 1
src/opr/include/megbrain/opr/search_policy/profiler.h View File

@@ -122,7 +122,7 @@ public:
ExecutionPolicyBlob execution_policy;
size_t workspace;
megdnn::DTypeEnum dtypes[arity];
CompNode::Locator comp_node_loc;
CompNode::Locator comp_node_physical, comp_node_logical;
TensorShapeArray shapes;
typename Opr::Param opr_param;
bool allow_weight_preprocess;


Loading…
Cancel
Save