Browse Source

feat(mge/graph): add `modify_opr_algo_strategy_inplace` for fast-run

GitOrigin-RevId: 034cf58b2a
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
594fa722bd
2 changed files with 29 additions and 0 deletions
  1. +13
    -0
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +16
    -0
      imperative/python/src/graph_rt.cpp

+ 13
- 0
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -288,6 +288,19 @@ def optimize_for_inference(dest_vars, **kwargs):
return _wrap(res_vars)


def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
"""
C++ graph version of :func:`~.set_execution_strategy`. Used to inplacely modify
dumped graph's fast-run strategy.

:param dest_vars: list of output vars in the computing graph.
:param strategy: fast-run algorithms strategy.

"""
dest_vars = _unwrap(dest_vars)
_imperative_rt.modify_opr_algo_strategy_inplace(dest_vars, strategy)


CompGraphDumpResult = collections.namedtuple(
"CompGraphDumpResult",
[


+ 16
- 0
imperative/python/src/graph_rt.cpp View File

@@ -32,6 +32,7 @@ namespace ser = mgb::serialization;

using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;

namespace {
class _CompGraphProfilerImpl {
@@ -257,6 +258,21 @@ void init_graph_rt(py::module m) {
return vars;
});

m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars, const std::string& strategy) {
_AlgoStrategy stg;
const std::unordered_map<std::string,std::function<void()>> m{
{"HEURISTIC", [&](){ stg = _AlgoStrategy::HEURISTIC; }},
{"HEURISTIC_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::HEURISTIC_REPRODUCIBLE; }},
{"PROFILE", [&](){ stg = _AlgoStrategy::PROFILE; }},
{"PROFILE_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::PROFILE_REPRODUCIBLE; }},
{"PROFILE_HEURISTIC", [&](){ stg = _AlgoStrategy::PROFILE_HEURISTIC; }},
};
auto it = m.find(strategy);
mgb_assert(it != m.end(), "Invalid strategy string!");
it->second();
mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, stg);
});

m.def("get_info_for_strip", [](const std::vector<VarNode*>& dest_vars) {
std::unordered_set<const char*> opr_types, dtype_names, elemwise_modes;
auto on_opr = [&](cg::OperatorNodeBase *opr) {


Loading…
Cancel
Save