diff --git a/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp b/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp index 94f6908b..07dbd544 100644 --- a/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp +++ b/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp @@ -221,6 +221,8 @@ void DynamicProgrammingSolver::Impl::analyze_edges( edges[cur].push_back(ov); edge2idx[cur].emplace(ov, idx++); } + if (cur == 0) + return; cur--; for (const auto& opr : reverse_adaptor(topo)) { for (const auto& i : opr->input()) { diff --git a/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp b/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp index 88da239c..e7039c99 100644 --- a/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp +++ b/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp @@ -19,19 +19,41 @@ using namespace mgb; using namespace gopt; using namespace opr; +namespace { +using OprFormat = SolverBase::OprFormat; +template +bool check_format_aware_opr_valid(const OperatorNodeBase* opr_, OprFormat opr_format) { + auto&& opr = opr_->cast_final_safe(); + return opr.param().format == opr_format; +} +} // namespace + /* =================== ProfilingBasedSolverSolver ======================*/ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr profiler) : m_profiler{std::move(profiler)} { - static const ThinHashSet format_aware_oprs = { -#define cb(_Opr) _Opr::typeinfo() - cb(Convolution), cb(ConvBiasForward), cb(ConvolutionBackwardData), - cb(PoolingForward), cb(WarpPerspective), cb(Resize), - }; + static const ThinHashMap< + Typeinfo*, + thin_function> + format_aware_opr_validators = { +#define cb(t) \ + {opr::t::typeinfo(), std::bind( \ + check_format_aware_opr_valid, \ + std::placeholders::_1, std::placeholders::_2)} + cb(Convolution), + cb(ConvBiasForward), + cb(ConvolutionBackwardData), + cb(PoolingForward), + cb(WarpPerspective), + cb(Resize), + }; - m_graph_partition_filter = [](const GraphPartition& partition) { + m_problem_filter = [](const Problem& problem) { + auto&& base_opr_format = problem.attribute().base_opr_format; bool has_format_aware_opr = false; - for (auto&& opr : partition.all_oprs()) { - if (!has_format_aware_opr && format_aware_oprs.count(opr->dyn_typeinfo())) { + for (auto&& opr : problem.graph_partition().all_oprs()) { + auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo()); + if (iter != format_aware_opr_validators.end() && + iter->second(opr, base_opr_format)) { has_format_aware_opr = true; break; } @@ -42,8 +64,7 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr profile ProfilingBasedSolver::Solution ProfilingBasedSolver::solve( const Problem& problem) const { - const auto& partition = problem.graph_partition(); - if (!m_graph_partition_filter(partition)) + if (!m_problem_filter(problem)) return Solution{}; return do_solve(problem); } diff --git a/src/gopt/include/megbrain/gopt/solver.h b/src/gopt/include/megbrain/gopt/solver.h index 2eed1abe..4fa36960 100644 --- a/src/gopt/include/megbrain/gopt/solver.h +++ b/src/gopt/include/megbrain/gopt/solver.h @@ -49,18 +49,16 @@ public: */ class ProfilingBasedSolver : public SolverBase { public: - using GraphPartitionFilter = - thin_function; + using ProblemFilter = thin_function; ProfilingBasedSolver(std::unique_ptr profiler); /*! * \note some graph partition (for example, graph partition without format * aware operators like conv, deconv, warp, resize etc.) will be filtered by - * the GraphPartitionFilter, which can reduce the profiling time. */ + * the ProblemFilter, which can reduce the profiling time. */ ProfilingBasedSolver( - std::unique_ptr profiler, - GraphPartitionFilter graph_partition_filter) + std::unique_ptr profiler, ProblemFilter problem_filter) : m_profiler{std::move(profiler)}, - m_graph_partition_filter{std::move(graph_partition_filter)} {} + m_problem_filter{std::move(problem_filter)} {} virtual ~ProfilingBasedSolver() = default; Solution solve(const Problem& problem) const override; virtual Solution do_solve(const Problem& problem) const = 0; @@ -69,7 +67,7 @@ protected: std::unique_ptr m_profiler; private: - GraphPartitionFilter m_graph_partition_filter; + ProblemFilter m_problem_filter; }; /*! @@ -81,10 +79,8 @@ public: DynamicProgrammingSolver(std::unique_ptr profiler) : ProfilingBasedSolver(std::move(profiler)){}; DynamicProgrammingSolver( - std::unique_ptr profiler, - GraphPartitionFilter graph_partition_filter) - : ProfilingBasedSolver( - std::move(profiler), std::move(graph_partition_filter)){}; + std::unique_ptr profiler, ProblemFilter problem_filter) + : ProfilingBasedSolver(std::move(profiler), std::move(problem_filter)){}; ~DynamicProgrammingSolver() noexcept = default; Solution do_solve(const Problem& problem) const override; bool can_solve(const Problem& problem) const override;