|
|
@@ -254,7 +254,8 @@ WorkspaceLimitGetter::Impl* WorkspaceLimitGetter::get_impl(ComputingGraph* graph |
|
|
|
size_t WorkspaceLimitGetter::get_workspace_limit( |
|
|
|
ComputingGraph* graph, CompNode cn, size_t old_limit) { |
|
|
|
if (graph->options().imperative_proxy_graph) { |
|
|
|
return old_limit; |
|
|
|
auto impl = WorkspaceLimitHook::get_impl(graph); |
|
|
|
return impl(cn, old_limit); |
|
|
|
} |
|
|
|
if (!graph->options().seq_opt.enable_mem_reuse_alloc) |
|
|
|
return old_limit; |
|
|
@@ -419,4 +420,55 @@ void MegDNNOprHolderBwdStaticInfer::mixin_update_node_prop( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/* ================== WorkspaceLimitHook ================== */ |
|
|
|
MGB_TYPEINFO_OBJ_IMPL(WorkspaceLimitHook); |
|
|
|
|
|
|
|
#if MGB_BUILD_SLIM_SERVING && !MGB_CUDA |
|
|
|
void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) { |
|
|
|
mgb_assert(false); |
|
|
|
} |
|
|
|
|
|
|
|
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() |
|
|
|
const { |
|
|
|
mgb_assert(false); |
|
|
|
} |
|
|
|
|
|
|
|
void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */, |
|
|
|
GetWorkspaceLimitImpl /* impl */) { |
|
|
|
mgb_assert(false); |
|
|
|
} |
|
|
|
|
|
|
|
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( |
|
|
|
ComputingGraph* /* graph */) { |
|
|
|
mgb_assert(false); |
|
|
|
} |
|
|
|
#else |
|
|
|
void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) { |
|
|
|
m_impl = std::move(impl); |
|
|
|
} |
|
|
|
|
|
|
|
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() |
|
|
|
const { |
|
|
|
return m_impl; |
|
|
|
} |
|
|
|
|
|
|
|
void WorkspaceLimitHook::set_impl(ComputingGraph* graph, |
|
|
|
GetWorkspaceLimitImpl impl) { |
|
|
|
mgb_assert(graph->options().imperative_proxy_graph); |
|
|
|
auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); }; |
|
|
|
graph->options() |
|
|
|
.user_data.get_user_data_or_create<WorkspaceLimitHook>(maker) |
|
|
|
->set_impl(impl); |
|
|
|
} |
|
|
|
|
|
|
|
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( |
|
|
|
ComputingGraph* graph) { |
|
|
|
mgb_assert(graph->options().imperative_proxy_graph); |
|
|
|
auto container = |
|
|
|
graph->options().user_data.get_user_data<WorkspaceLimitHook>(); |
|
|
|
mgb_assert(container.second == 1); |
|
|
|
return container.first[0]->get_impl(); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |