Browse Source

fix(mgb/opr): fix fastrun workspace limit for imperative rt

GitOrigin-RevId: bd69a82d4c
release-1.7
Megvii Engine Team 3 years ago
parent
commit
a09a2b730d
3 changed files with 79 additions and 3 deletions
  1. +53
    -1
      src/opr/impl/internal/megdnn_opr_wrapper.cpp
  2. +10
    -2
      src/opr/impl/search_policy/algo_chooser.cpp
  3. +16
    -0
      src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h

+ 53
- 1
src/opr/impl/internal/megdnn_opr_wrapper.cpp View File

@@ -254,7 +254,8 @@ WorkspaceLimitGetter::Impl* WorkspaceLimitGetter::get_impl(ComputingGraph* graph
size_t WorkspaceLimitGetter::get_workspace_limit(
ComputingGraph* graph, CompNode cn, size_t old_limit) {
if (graph->options().imperative_proxy_graph) {
return old_limit;
auto impl = WorkspaceLimitHook::get_impl(graph);
return impl(cn, old_limit);
}
if (!graph->options().seq_opt.enable_mem_reuse_alloc)
return old_limit;
@@ -419,4 +420,55 @@ void MegDNNOprHolderBwdStaticInfer::mixin_update_node_prop(
}
}

/* ================== WorkspaceLimitHook ================== */
MGB_TYPEINFO_OBJ_IMPL(WorkspaceLimitHook);

#if MGB_BUILD_SLIM_SERVING && !MGB_CUDA
void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) {
mgb_assert(false);
}

const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
const {
mgb_assert(false);
}

void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */,
GetWorkspaceLimitImpl /* impl */) {
mgb_assert(false);
}

const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
ComputingGraph* /* graph */) {
mgb_assert(false);
}
#else
void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) {
m_impl = std::move(impl);
}

const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
const {
return m_impl;
}

void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
GetWorkspaceLimitImpl impl) {
mgb_assert(graph->options().imperative_proxy_graph);
auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); };
graph->options()
.user_data.get_user_data_or_create<WorkspaceLimitHook>(maker)
->set_impl(impl);
}

const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
ComputingGraph* graph) {
mgb_assert(graph->options().imperative_proxy_graph);
auto container =
graph->options().user_data.get_user_data<WorkspaceLimitHook>();
mgb_assert(container.second == 1);
return container.first[0]->get_impl();
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 10
- 2
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -621,8 +621,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
if (prof.empty())
return {{}, rst};

size_t workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
auto target_attr = extract_algo_attribute(selected_strategy);
bool skip_by_negative = false;
bool skip_by_workspace = false;
for (auto&& i : prof) {
auto attr_of_algo = static_cast<megdnn::Algorithm::Attribute>(i.attribute);
bool contain_attr_all_positive =
@@ -631,13 +634,18 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
static_cast<bool>(attr_of_algo & target_attr.second);
if (contain_attr_all_positive) {
if (!contain_attr_any_negative) {
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
return {algo_desc, rst};
if (i.workspace <= workspace_limit) {
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
return {algo_desc, rst};
}
skip_by_workspace = true;
} else {
skip_by_negative = true;
}
}
}
if (skip_by_workspace)
return {};

std::string layouts_str =
format_fixlayouts<Opr>(m_fastrun_layouts, arity_in, arity_out);


+ 16
- 0
src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h View File

@@ -316,6 +316,22 @@ protected:
typename Super::NodeProp* do_make_node_prop() const override;
};

class WorkspaceLimitHook final : public UserDataContainer::UserData {
MGB_TYPEINFO_OBJ_DECL;

public:
using GetWorkspaceLimitImpl = thin_function<size_t(CompNode, size_t)>;
WorkspaceLimitHook() = default;
~WorkspaceLimitHook() = default;
static void set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl);
static const GetWorkspaceLimitImpl& get_impl(ComputingGraph* graph);

private:
void set_impl(GetWorkspaceLimitImpl impl);
const GetWorkspaceLimitImpl& get_impl() const;
GetWorkspaceLimitImpl m_impl;
};

} // namespace intl
} // namespace opr
} // namespace mgb


Loading…
Cancel
Save