GitOrigin-RevId: bd69a82d4c
release-1.7
@@ -254,7 +254,8 @@ WorkspaceLimitGetter::Impl* WorkspaceLimitGetter::get_impl(ComputingGraph* graph | |||||
size_t WorkspaceLimitGetter::get_workspace_limit( | size_t WorkspaceLimitGetter::get_workspace_limit( | ||||
ComputingGraph* graph, CompNode cn, size_t old_limit) { | ComputingGraph* graph, CompNode cn, size_t old_limit) { | ||||
if (graph->options().imperative_proxy_graph) { | if (graph->options().imperative_proxy_graph) { | ||||
return old_limit; | |||||
auto impl = WorkspaceLimitHook::get_impl(graph); | |||||
return impl(cn, old_limit); | |||||
} | } | ||||
if (!graph->options().seq_opt.enable_mem_reuse_alloc) | if (!graph->options().seq_opt.enable_mem_reuse_alloc) | ||||
return old_limit; | return old_limit; | ||||
@@ -419,4 +420,55 @@ void MegDNNOprHolderBwdStaticInfer::mixin_update_node_prop( | |||||
} | } | ||||
} | } | ||||
/* ================== WorkspaceLimitHook ================== */ | |||||
MGB_TYPEINFO_OBJ_IMPL(WorkspaceLimitHook); | |||||
#if MGB_BUILD_SLIM_SERVING && !MGB_CUDA | |||||
/*!
 * \brief Stub compiled when MGB_BUILD_SLIM_SERVING && !MGB_CUDA.
 *
 * The callback argument is ignored; every call fails the assertion.
 */
void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl) {
    mgb_assert(false);
}
/*!
 * \brief Stub compiled when MGB_BUILD_SLIM_SERVING && !MGB_CUDA; every call
 *      fails the assertion.
 *
 * NOTE(review): the function has a reference return type but no return
 * statement, so it relies on mgb_assert(false) never handing control back --
 * confirm against the macro definition.
 */
const WorkspaceLimitHook::GetWorkspaceLimitImpl&
WorkspaceLimitHook::get_impl() const {
    mgb_assert(false);
}
/*!
 * \brief Stub compiled when MGB_BUILD_SLIM_SERVING && !MGB_CUDA.
 *
 * Neither the graph nor the callback argument is used; every call fails the
 * assertion.
 */
void WorkspaceLimitHook::set_impl(ComputingGraph*, GetWorkspaceLimitImpl) {
    mgb_assert(false);
}
/*!
 * \brief Stub compiled when MGB_BUILD_SLIM_SERVING && !MGB_CUDA; the graph
 *      argument is unused and every call fails the assertion.
 *
 * NOTE(review): the function has a reference return type but no return
 * statement, so it relies on mgb_assert(false) never handing control back --
 * confirm against the macro definition.
 */
const WorkspaceLimitHook::GetWorkspaceLimitImpl&
WorkspaceLimitHook::get_impl(ComputingGraph*) {
    mgb_assert(false);
}
#else | |||||
/*!
 * \brief Replace the callback held by this hook instance.
 * \param impl new callback; taken by value and moved into the member
 */
void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) {
    this->m_impl = std::move(impl);
}
/*!
 * \brief Read access to the callback held by this hook instance.
 * \return const reference to the stored member; no check is made here on
 *      whether a callback was ever assigned
 */
const WorkspaceLimitHook::GetWorkspaceLimitImpl&
WorkspaceLimitHook::get_impl() const {
    return this->m_impl;
}
/*!
 * \brief Install \p impl as the workspace-limit callback of \p graph.
 *
 * The graph must have options().imperative_proxy_graph set (asserted). The
 * hook object is looked up in the graph options' user_data through
 * get_user_data_or_create, with a maker that builds a fresh
 * WorkspaceLimitHook, and the callback is then stored on that object.
 *
 * \param graph graph whose options carry the hook
 * \param impl callback to store; taken by value and moved onward
 */
void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
                                  GetWorkspaceLimitImpl impl) {
    mgb_assert(graph->options().imperative_proxy_graph);
    auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); };
    // impl is a by-value sink parameter: move it into the instance setter
    // rather than copying the callable a second time.
    graph->options()
            .user_data.get_user_data_or_create<WorkspaceLimitHook>(maker)
            ->set_impl(std::move(impl));
}
/*!
 * \brief Fetch the workspace-limit callback stored for \p graph.
 *
 * Asserts that options().imperative_proxy_graph is set and that the lookup in
 * the graph options' user_data reports exactly one WorkspaceLimitHook entry
 * (container.second == 1); the callback of that first entry is returned.
 *
 * \param graph graph whose options carry the hook
 * \return const reference to the callback held by the hook object
 */
const WorkspaceLimitHook::GetWorkspaceLimitImpl&
WorkspaceLimitHook::get_impl(ComputingGraph* graph) {
    mgb_assert(graph->options().imperative_proxy_graph);

    // the lookup result is used as (entries, count) via .first / .second
    auto container =
            graph->options().user_data.get_user_data<WorkspaceLimitHook>();
    mgb_assert(container.second == 1);

    return container.first[0]->get_impl();
}
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -621,8 +621,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( | |||||
if (prof.empty()) | if (prof.empty()) | ||||
return {{}, rst}; | return {{}, rst}; | ||||
size_t workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||||
auto target_attr = extract_algo_attribute(selected_strategy); | auto target_attr = extract_algo_attribute(selected_strategy); | ||||
bool skip_by_negative = false; | bool skip_by_negative = false; | ||||
bool skip_by_workspace = false; | |||||
for (auto&& i : prof) { | for (auto&& i : prof) { | ||||
auto attr_of_algo = static_cast<megdnn::Algorithm::Attribute>(i.attribute); | auto attr_of_algo = static_cast<megdnn::Algorithm::Attribute>(i.attribute); | ||||
bool contain_attr_all_positive = | bool contain_attr_all_positive = | ||||
@@ -631,13 +634,18 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( | |||||
static_cast<bool>(attr_of_algo & target_attr.second); | static_cast<bool>(attr_of_algo & target_attr.second); | ||||
if (contain_attr_all_positive) { | if (contain_attr_all_positive) { | ||||
if (!contain_attr_any_negative) { | if (!contain_attr_any_negative) { | ||||
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); | |||||
return {algo_desc, rst}; | |||||
if (i.workspace <= workspace_limit) { | |||||
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); | |||||
return {algo_desc, rst}; | |||||
} | |||||
skip_by_workspace = true; | |||||
} else { | } else { | ||||
skip_by_negative = true; | skip_by_negative = true; | ||||
} | } | ||||
} | } | ||||
} | } | ||||
if (skip_by_workspace) | |||||
return {}; | |||||
std::string layouts_str = | std::string layouts_str = | ||||
format_fixlayouts<Opr>(m_fastrun_layouts, arity_in, arity_out); | format_fixlayouts<Opr>(m_fastrun_layouts, arity_in, arity_out); | ||||
@@ -316,6 +316,22 @@ protected: | |||||
typename Super::NodeProp* do_make_node_prop() const override; | typename Super::NodeProp* do_make_node_prop() const override; | ||||
}; | }; | ||||
class WorkspaceLimitHook final : public UserDataContainer::UserData { | |||||
MGB_TYPEINFO_OBJ_DECL; | |||||
public: | |||||
using GetWorkspaceLimitImpl = thin_function<size_t(CompNode, size_t)>; | |||||
WorkspaceLimitHook() = default; | |||||
~WorkspaceLimitHook() = default; | |||||
static void set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl); | |||||
static const GetWorkspaceLimitImpl& get_impl(ComputingGraph* graph); | |||||
private: | |||||
void set_impl(GetWorkspaceLimitImpl impl); | |||||
const GetWorkspaceLimitImpl& get_impl() const; | |||||
GetWorkspaceLimitImpl m_impl; | |||||
}; | |||||
} // namespace intl | } // namespace intl | ||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||