|
|
@@ -32,6 +32,7 @@ namespace ser = mgb::serialization; |
|
|
|
|
|
|
|
using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; |
|
|
|
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; |
|
|
|
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; |
|
|
|
|
|
|
|
namespace { |
|
|
|
class _CompGraphProfilerImpl { |
|
|
@@ -257,6 +258,21 @@ void init_graph_rt(py::module m) { |
|
|
|
return vars; |
|
|
|
}); |
|
|
|
|
|
|
|
m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars, const std::string& strategy) { |
|
|
|
_AlgoStrategy stg; |
|
|
|
const std::unordered_map<std::string,std::function<void()>> m{ |
|
|
|
{"HEURISTIC", [&](){ stg = _AlgoStrategy::HEURISTIC; }}, |
|
|
|
{"HEURISTIC_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::HEURISTIC_REPRODUCIBLE; }}, |
|
|
|
{"PROFILE", [&](){ stg = _AlgoStrategy::PROFILE; }}, |
|
|
|
{"PROFILE_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::PROFILE_REPRODUCIBLE; }}, |
|
|
|
{"PROFILE_HEURISTIC", [&](){ stg = _AlgoStrategy::PROFILE_HEURISTIC; }}, |
|
|
|
}; |
|
|
|
auto it = m.find(strategy); |
|
|
|
mgb_assert(it != m.end(), "Invalid strategy string!"); |
|
|
|
it->second(); |
|
|
|
mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, stg); |
|
|
|
}); |
|
|
|
|
|
|
|
m.def("get_info_for_strip", [](const std::vector<VarNode*>& dest_vars) { |
|
|
|
std::unordered_set<const char*> opr_types, dtype_names, elemwise_modes; |
|
|
|
auto on_opr = [&](cg::OperatorNodeBase *opr) { |
|
|
|