Browse Source

fix(imperative/fastrun): set workspace limit for imperative rt

GitOrigin-RevId: 474dc691a3
release-1.7
Megvii Engine Team 3 years ago
parent
commit
d9a9d9d49e
5 changed files with 34 additions and 10 deletions
  1. +13
    -0
      imperative/src/impl/proxy_graph.cpp
  2. +5
    -0
      imperative/src/impl/proxy_graph.h
  3. +7
    -0
      src/core/impl/comp_node/cuda/comp_node.cpp
  4. +3
    -0
      src/core/include/megbrain/comp_node.h
  5. +6
    -10
      src/opr/impl/internal/megdnn_opr_wrapper.cpp

+ 13
- 0
imperative/src/impl/proxy_graph.cpp View File

@@ -15,6 +15,7 @@
#include "megbrain/graph/static_infer.h" #include "megbrain/graph/static_infer.h"
#include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h" #include "megbrain/opr/utility.h"
@@ -509,6 +510,8 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
const OpDef& opdef, const SmallVector<Tensor*>& inputs) { const OpDef& opdef, const SmallVector<Tensor*>& inputs) {
SmallVector<LogicalTensorDesc> ret; SmallVector<LogicalTensorDesc> ret;
CUR_OPR_GUARD(get_proxy_opr(opdef, inputs)); CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
::mgb::opr::intl::WorkspaceLimitHook::set_impl(
m_graph.get(), ProxyGraph::get_workspace_limit);
do_shape_infer(true); do_shape_infer(true);
for (auto&& i : m_cur_opr->usable_output()) { for (auto&& i : m_cur_opr->usable_output()) {
mgb_assert(i->dtype().valid() && i->comp_node().valid()); mgb_assert(i->dtype().valid() && i->comp_node().valid());
@@ -547,6 +550,14 @@ void ProxyGraph::init_output_tensor(
// get proxy opr // get proxy opr
auto proxy = m_cur_opr; auto proxy = m_cur_opr;


auto get_workspace_size = [=](CompNode cn, size_t old_limit) {
size_t limit = 0;
for (auto&& var : workspaces) {
limit += var->dtype().size(var->shape().total_nr_elems());
}
return limit;
};
::mgb::opr::intl::WorkspaceLimitHook::set_impl(m_graph.get(), get_workspace_size);
do_shape_infer(true); do_shape_infer(true);


size_t j = 0; size_t j = 0;
@@ -640,6 +651,8 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
const SmallVector<MemoryDesc>& inputs_mems) { const SmallVector<MemoryDesc>& inputs_mems) {
auto opr = get_proxy_opr(def, inputs_tensors); auto opr = get_proxy_opr(def, inputs_tensors);
CUR_OPR_GUARD(opr); CUR_OPR_GUARD(opr);
::mgb::opr::intl::WorkspaceLimitHook::set_impl(
m_graph.get(), ProxyGraph::get_workspace_limit);
do_shape_infer(true); do_shape_infer(true);
SmallVector<MemoryDesc> outputs; SmallVector<MemoryDesc> outputs;
SmallVector<MemoryDesc> workspaces; SmallVector<MemoryDesc> workspaces;


+ 5
- 0
imperative/src/impl/proxy_graph.h View File

@@ -27,6 +27,11 @@ public:
static std::unique_ptr<MegBrainError> get_async_error() { static std::unique_ptr<MegBrainError> get_async_error() {
return std::move(tm_async_error); return std::move(tm_async_error);
} }
static size_t get_workspace_limit(CompNode cn, size_t old_limit) {
size_t free = cn.get_free_mem();
size_t lmt = cn.get_max_block_size_available();
return std::max(lmt, free);
}


/********************** Physical Tensor API **********************/ /********************** Physical Tensor API **********************/




+ 7
- 0
src/core/impl/comp_node/cuda/comp_node.cpp View File

@@ -273,6 +273,13 @@ public:
activate(); activate();
return m_mem_alloc->get_max_block_size_available(); return m_mem_alloc->get_max_block_size_available();
} }

size_t get_free_mem() override {
m_env.cuda_env().activate();
size_t tot, free;
MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
return free;
}
#endif #endif


Locator locator() override { return m_locator; } Locator locator() override { return m_locator; }


+ 3
- 0
src/core/include/megbrain/comp_node.h View File

@@ -336,6 +336,8 @@ public:
size_t get_max_block_size_available() const { size_t get_max_block_size_available() const {
return m_impl->get_max_block_size_available(); return m_impl->get_max_block_size_available();
} }

size_t get_free_mem() const { return m_impl->get_free_mem(); }
#endif #endif


//! change to another stream on the same memory node //! change to another stream on the same memory node
@@ -519,6 +521,7 @@ protected:
} }
virtual size_t get_used_memory() { return 0; } virtual size_t get_used_memory() { return 0; }
virtual size_t get_max_block_size_available() { return 0; } virtual size_t get_max_block_size_available() { return 0; }
virtual size_t get_free_mem() { return 0; }
#endif #endif


virtual Locator locator() = 0; virtual Locator locator() = 0;


+ 6
- 10
src/opr/impl/internal/megdnn_opr_wrapper.cpp View File

@@ -428,13 +428,12 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) {
mgb_assert(false); mgb_assert(false);
} }


const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
const {
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const {
mgb_assert(false); mgb_assert(false);
} }


void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */,
GetWorkspaceLimitImpl /* impl */) {
void WorkspaceLimitHook::set_impl(
ComputingGraph* /* graph */, GetWorkspaceLimitImpl /* impl */) {
mgb_assert(false); mgb_assert(false);
} }


@@ -447,13 +446,11 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) {
m_impl = std::move(impl); m_impl = std::move(impl);
} }


const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
const {
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const {
return m_impl; return m_impl;
} }


void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
GetWorkspaceLimitImpl impl) {
void WorkspaceLimitHook::set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl) {
mgb_assert(graph->options().imperative_proxy_graph); mgb_assert(graph->options().imperative_proxy_graph);
auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); }; auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); };
graph->options() graph->options()
@@ -464,8 +461,7 @@ void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
ComputingGraph* graph) { ComputingGraph* graph) {
mgb_assert(graph->options().imperative_proxy_graph); mgb_assert(graph->options().imperative_proxy_graph);
auto container =
graph->options().user_data.get_user_data<WorkspaceLimitHook>();
auto container = graph->options().user_data.get_user_data<WorkspaceLimitHook>();
mgb_assert(container.second == 1); mgb_assert(container.second == 1);
return container.first[0]->get_impl(); return container.first[0]->get_impl();
} }


Loading…
Cancel
Save