GitOrigin-RevId: 326fdebb0c
release-1.7
@@ -221,6 +221,8 @@ void DynamicProgrammingSolver::Impl::analyze_edges( | |||||
edges[cur].push_back(ov); | edges[cur].push_back(ov); | ||||
edge2idx[cur].emplace(ov, idx++); | edge2idx[cur].emplace(ov, idx++); | ||||
} | } | ||||
if (cur == 0) | |||||
return; | |||||
cur--; | cur--; | ||||
for (const auto& opr : reverse_adaptor(topo)) { | for (const auto& opr : reverse_adaptor(topo)) { | ||||
for (const auto& i : opr->input()) { | for (const auto& i : opr->input()) { | ||||
@@ -19,19 +19,41 @@ using namespace mgb; | |||||
using namespace gopt; | using namespace gopt; | ||||
using namespace opr; | using namespace opr; | ||||
namespace { | |||||
using OprFormat = SolverBase::OprFormat; | |||||
template <typename Opr> | |||||
bool check_format_aware_opr_valid(const OperatorNodeBase* opr_, OprFormat opr_format) { | |||||
auto&& opr = opr_->cast_final_safe<Opr>(); | |||||
return opr.param().format == opr_format; | |||||
} | |||||
} // namespace | |||||
/* =================== ProfilingBasedSolverSolver ======================*/ | /* =================== ProfilingBasedSolverSolver ======================*/ | ||||
ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler) | ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler) | ||||
: m_profiler{std::move(profiler)} { | : m_profiler{std::move(profiler)} { | ||||
static const ThinHashSet<Typeinfo*> format_aware_oprs = { | |||||
#define cb(_Opr) _Opr::typeinfo() | |||||
cb(Convolution), cb(ConvBiasForward), cb(ConvolutionBackwardData), | |||||
cb(PoolingForward), cb(WarpPerspective), cb(Resize), | |||||
}; | |||||
static const ThinHashMap< | |||||
Typeinfo*, | |||||
thin_function<bool(const OperatorNodeBase*, OprFormat opr_format)>> | |||||
format_aware_opr_validators = { | |||||
#define cb(t) \ | |||||
{opr::t::typeinfo(), std::bind( \ | |||||
check_format_aware_opr_valid<opr::t>, \ | |||||
std::placeholders::_1, std::placeholders::_2)} | |||||
cb(Convolution), | |||||
cb(ConvBiasForward), | |||||
cb(ConvolutionBackwardData), | |||||
cb(PoolingForward), | |||||
cb(WarpPerspective), | |||||
cb(Resize), | |||||
}; | |||||
m_graph_partition_filter = [](const GraphPartition& partition) { | |||||
m_problem_filter = [](const Problem& problem) { | |||||
auto&& base_opr_format = problem.attribute().base_opr_format; | |||||
bool has_format_aware_opr = false; | bool has_format_aware_opr = false; | ||||
for (auto&& opr : partition.all_oprs()) { | |||||
if (!has_format_aware_opr && format_aware_oprs.count(opr->dyn_typeinfo())) { | |||||
for (auto&& opr : problem.graph_partition().all_oprs()) { | |||||
auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo()); | |||||
if (iter != format_aware_opr_validators.end() && | |||||
iter->second(opr, base_opr_format)) { | |||||
has_format_aware_opr = true; | has_format_aware_opr = true; | ||||
break; | break; | ||||
} | } | ||||
@@ -42,8 +64,7 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profile | |||||
ProfilingBasedSolver::Solution ProfilingBasedSolver::solve( | ProfilingBasedSolver::Solution ProfilingBasedSolver::solve( | ||||
const Problem& problem) const { | const Problem& problem) const { | ||||
const auto& partition = problem.graph_partition(); | |||||
if (!m_graph_partition_filter(partition)) | |||||
if (!m_problem_filter(problem)) | |||||
return Solution{}; | return Solution{}; | ||||
return do_solve(problem); | return do_solve(problem); | ||||
} | } | ||||
@@ -49,18 +49,16 @@ public: | |||||
*/ | */ | ||||
class ProfilingBasedSolver : public SolverBase { | class ProfilingBasedSolver : public SolverBase { | ||||
public: | public: | ||||
using GraphPartitionFilter = | |||||
thin_function<bool(const GraphPartition& graph_partition)>; | |||||
using ProblemFilter = thin_function<bool(const Problem&)>; | |||||
ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler); | ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler); | ||||
/*! | /*! | ||||
* \note some graph partition (for example, graph partition without format | * \note some graph partition (for example, graph partition without format | ||||
* aware operators like conv, deconv, warp, resize etc.) will be filtered by | * aware operators like conv, deconv, warp, resize etc.) will be filtered by | ||||
* the GraphPartitionFilter, which can reduce the profiling time. */ | |||||
* the ProblemFilter, which can reduce the profiling time. */ | |||||
ProfilingBasedSolver( | ProfilingBasedSolver( | ||||
std::unique_ptr<ProfilerBase> profiler, | |||||
GraphPartitionFilter graph_partition_filter) | |||||
std::unique_ptr<ProfilerBase> profiler, ProblemFilter problem_filter) | |||||
: m_profiler{std::move(profiler)}, | : m_profiler{std::move(profiler)}, | ||||
m_graph_partition_filter{std::move(graph_partition_filter)} {} | |||||
m_problem_filter{std::move(problem_filter)} {} | |||||
virtual ~ProfilingBasedSolver() = default; | virtual ~ProfilingBasedSolver() = default; | ||||
Solution solve(const Problem& problem) const override; | Solution solve(const Problem& problem) const override; | ||||
virtual Solution do_solve(const Problem& problem) const = 0; | virtual Solution do_solve(const Problem& problem) const = 0; | ||||
@@ -69,7 +67,7 @@ protected: | |||||
std::unique_ptr<ProfilerBase> m_profiler; | std::unique_ptr<ProfilerBase> m_profiler; | ||||
private: | private: | ||||
GraphPartitionFilter m_graph_partition_filter; | |||||
ProblemFilter m_problem_filter; | |||||
}; | }; | ||||
/*! | /*! | ||||
@@ -81,10 +79,8 @@ public: | |||||
DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler) | DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler) | ||||
: ProfilingBasedSolver(std::move(profiler)){}; | : ProfilingBasedSolver(std::move(profiler)){}; | ||||
DynamicProgrammingSolver( | DynamicProgrammingSolver( | ||||
std::unique_ptr<ProfilerBase> profiler, | |||||
GraphPartitionFilter graph_partition_filter) | |||||
: ProfilingBasedSolver( | |||||
std::move(profiler), std::move(graph_partition_filter)){}; | |||||
std::unique_ptr<ProfilerBase> profiler, ProblemFilter problem_filter) | |||||
: ProfilingBasedSolver(std::move(profiler), std::move(problem_filter)){}; | |||||
~DynamicProgrammingSolver() noexcept = default; | ~DynamicProgrammingSolver() noexcept = default; | ||||
Solution do_solve(const Problem& problem) const override; | Solution do_solve(const Problem& problem) const override; | ||||
bool can_solve(const Problem& problem) const override; | bool can_solve(const Problem& problem) const override; | ||||