GitOrigin-RevId: 326fdebb0c
release-1.7
@@ -221,6 +221,8 @@ void DynamicProgrammingSolver::Impl::analyze_edges( | |||
edges[cur].push_back(ov); | |||
edge2idx[cur].emplace(ov, idx++); | |||
} | |||
if (cur == 0) | |||
return; | |||
cur--; | |||
for (const auto& opr : reverse_adaptor(topo)) { | |||
for (const auto& i : opr->input()) { | |||
@@ -19,19 +19,41 @@ using namespace mgb; | |||
using namespace gopt; | |||
using namespace opr; | |||
namespace { | |||
using OprFormat = SolverBase::OprFormat; | |||
template <typename Opr> | |||
bool check_format_aware_opr_valid(const OperatorNodeBase* opr_, OprFormat opr_format) { | |||
auto&& opr = opr_->cast_final_safe<Opr>(); | |||
return opr.param().format == opr_format; | |||
} | |||
} // namespace | |||
/* =================== ProfilingBasedSolverSolver ======================*/ | |||
ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler) | |||
: m_profiler{std::move(profiler)} { | |||
static const ThinHashSet<Typeinfo*> format_aware_oprs = { | |||
#define cb(_Opr) _Opr::typeinfo() | |||
cb(Convolution), cb(ConvBiasForward), cb(ConvolutionBackwardData), | |||
cb(PoolingForward), cb(WarpPerspective), cb(Resize), | |||
}; | |||
static const ThinHashMap< | |||
Typeinfo*, | |||
thin_function<bool(const OperatorNodeBase*, OprFormat opr_format)>> | |||
format_aware_opr_validators = { | |||
#define cb(t) \ | |||
{opr::t::typeinfo(), std::bind( \ | |||
check_format_aware_opr_valid<opr::t>, \ | |||
std::placeholders::_1, std::placeholders::_2)} | |||
cb(Convolution), | |||
cb(ConvBiasForward), | |||
cb(ConvolutionBackwardData), | |||
cb(PoolingForward), | |||
cb(WarpPerspective), | |||
cb(Resize), | |||
}; | |||
m_graph_partition_filter = [](const GraphPartition& partition) { | |||
m_problem_filter = [](const Problem& problem) { | |||
auto&& base_opr_format = problem.attribute().base_opr_format; | |||
bool has_format_aware_opr = false; | |||
for (auto&& opr : partition.all_oprs()) { | |||
if (!has_format_aware_opr && format_aware_oprs.count(opr->dyn_typeinfo())) { | |||
for (auto&& opr : problem.graph_partition().all_oprs()) { | |||
auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo()); | |||
if (iter != format_aware_opr_validators.end() && | |||
iter->second(opr, base_opr_format)) { | |||
has_format_aware_opr = true; | |||
break; | |||
} | |||
@@ -42,8 +64,7 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profile | |||
ProfilingBasedSolver::Solution ProfilingBasedSolver::solve( | |||
const Problem& problem) const { | |||
const auto& partition = problem.graph_partition(); | |||
if (!m_graph_partition_filter(partition)) | |||
if (!m_problem_filter(problem)) | |||
return Solution{}; | |||
return do_solve(problem); | |||
} | |||
@@ -49,18 +49,16 @@ public: | |||
*/ | |||
class ProfilingBasedSolver : public SolverBase { | |||
public: | |||
using GraphPartitionFilter = | |||
thin_function<bool(const GraphPartition& graph_partition)>; | |||
using ProblemFilter = thin_function<bool(const Problem&)>; | |||
ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler); | |||
/*! | |||
* \note some graph partition (for example, graph partition without format | |||
* aware operators like conv, deconv, warp, resize etc.) will be filtered by | |||
* the GraphPartitionFilter, which can reduce the profiling time. */ | |||
* the ProblemFilter, which can reduce the profiling time. */ | |||
ProfilingBasedSolver( | |||
std::unique_ptr<ProfilerBase> profiler, | |||
GraphPartitionFilter graph_partition_filter) | |||
std::unique_ptr<ProfilerBase> profiler, ProblemFilter problem_filter) | |||
: m_profiler{std::move(profiler)}, | |||
m_graph_partition_filter{std::move(graph_partition_filter)} {} | |||
m_problem_filter{std::move(problem_filter)} {} | |||
virtual ~ProfilingBasedSolver() = default; | |||
Solution solve(const Problem& problem) const override; | |||
virtual Solution do_solve(const Problem& problem) const = 0; | |||
@@ -69,7 +67,7 @@ protected: | |||
std::unique_ptr<ProfilerBase> m_profiler; | |||
private: | |||
GraphPartitionFilter m_graph_partition_filter; | |||
ProblemFilter m_problem_filter; | |||
}; | |||
/*! | |||
@@ -81,10 +79,8 @@ public: | |||
DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler) | |||
: ProfilingBasedSolver(std::move(profiler)){}; | |||
DynamicProgrammingSolver( | |||
std::unique_ptr<ProfilerBase> profiler, | |||
GraphPartitionFilter graph_partition_filter) | |||
: ProfilingBasedSolver( | |||
std::move(profiler), std::move(graph_partition_filter)){}; | |||
std::unique_ptr<ProfilerBase> profiler, ProblemFilter problem_filter) | |||
: ProfilingBasedSolver(std::move(profiler), std::move(problem_filter)){}; | |||
~DynamicProgrammingSolver() noexcept = default; | |||
Solution do_solve(const Problem& problem) const override; | |||
bool can_solve(const Problem& problem) const override; | |||