Browse Source

feat(mgb): enable output dynamic memory alloc

GitOrigin-RevId: c809629034
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
f3378100f9
7 changed files with 90 additions and 4 deletions
  1. +2
    -0
      src/core/impl/graph/bases.cpp
  2. +18
    -1
      src/core/impl/graph/cg_impl.cpp
  3. +1
    -0
      src/core/impl/graph/cg_impl.h
  4. +26
    -1
      src/core/include/megbrain/graph/bases.h
  5. +7
    -0
      src/core/include/megbrain/graph/cg.h
  6. +1
    -2
      src/core/include/megbrain/graph/var_node.h
  7. +35
    -0
      src/core/test/graph/misc.cpp

+ 2
- 0
src/core/impl/graph/bases.cpp View File

@@ -14,6 +14,8 @@


using namespace mgb::cg; using namespace mgb::cg;


MGB_TYPEINFO_OBJ_IMPL(OutputVarsUserData);

GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph): GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph):
m_owner_graph{owner_graph} m_owner_graph{owner_graph}
{ {


+ 18
- 1
src/core/impl/graph/cg_impl.cpp View File

@@ -563,6 +563,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
std::unordered_map<CallbackCallerKey, CallbackCallerVal, std::unordered_map<CallbackCallerKey, CallbackCallerVal,
CallbackCallerKey::Hash> CallbackCallerKey::Hash>
opr2vars; opr2vars;
using F = VarNode::Flag;
if (dest_vars[0]->owner_graph()->options().force_output_dynamic_alloc) {
for (auto&& i : dest_vars) {
if (!i->contain_flag(F::NO_SYS_MEM_ALLOC |
F::NO_SYS_STATIC_MEM_ALLOC)) {
mgb_assert(
!i->contain_flag(
F::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC),
"Can not force graph output dynamic alloc with "
"DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s",
i->cname());
i->add_flag(F::NO_SYS_STATIC_MEM_ALLOC);
}
i->add_flag(F::NO_MEM_RECLAIM);
}
}
for (size_t i = 0; i < out_spec.size(); ++i) { for (size_t i = 0; i < out_spec.size(); ++i) {
auto&& cb = out_spec[i].second; auto&& cb = out_spec[i].second;
if (cb) { if (cb) {
@@ -641,13 +657,14 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
init_opr_seq(); init_opr_seq();
#endif // MGB_ENABLE_SUBLINEAR #endif // MGB_ENABLE_SUBLINEAR


return {std::move(extra_info), opr_seq};
return {std::move(extra_info), opr_seq, std::move(dest_vars)};
} }


std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit( std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
CompileState state) { CompileState state) {
auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this()); auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
comp_seq->extra_info = std::move(state.extra_info); comp_seq->extra_info = std::move(state.extra_info);
comp_seq->set_output_vars(state.dest_vars);
auto opr_seq = state.opr_seq; auto opr_seq = state.opr_seq;
auto&& cmpnt = components(); auto&& cmpnt = components();




+ 1
- 0
src/core/impl/graph/cg_impl.h View File

@@ -38,6 +38,7 @@ class ComputingGraphImpl final : public ComputingGraph {
//! extra info that must be set in the ComputingSequence //! extra info that must be set in the ComputingSequence
CompSeqExtraInfo extra_info; CompSeqExtraInfo extra_info;
const OprNodeArray* opr_seq = nullptr; const OprNodeArray* opr_seq = nullptr;
VarNodeArray dest_vars;
}; };


struct CallbackCallerKey { struct CallbackCallerKey {


+ 26
- 1
src/core/include/megbrain/graph/bases.h View File

@@ -67,9 +67,10 @@ namespace static_infer {
}; };


using GraphError = mgb::GraphError; using GraphError = mgb::GraphError;
class VarNode;
class OperatorNodeBase; class OperatorNodeBase;
class ComputingGraph; class ComputingGraph;
using VarNodeArray = mgb::SmallVector<VarNode*>;
/*! /*!
* \brief Base class for a node in the graph. * \brief Base class for a node in the graph.
* *
@@ -102,6 +103,17 @@ class GraphNodeBase: public json::Serializable, public NonCopyableObj {
} }
}; };


/*!
 * \brief user data attached to an AsyncExecutable recording the output vars
 *      that were passed to compile()
 *
 * Stored in the executable's UserDataContainer so callers can later query
 * the dest vars (e.g. to read their dev tensors directly when
 * force_output_dynamic_alloc is enabled).
 */
class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

private:
    VarNodeArray m_output_vars;

public:
    //! store the output vars; the array is taken by value and moved in
    void set_output_vars(VarNodeArray vars) { m_output_vars = std::move(vars); }

    //! access the previously stored output vars
    const VarNodeArray& get_output_vars() const { return m_output_vars; }
};

/*! /*!
* \brief an object that executes asynchronously * \brief an object that executes asynchronously
*/ */
@@ -165,6 +177,19 @@ class AsyncExecutable : public json::Serializable,
UserDataContainer& user_data() { UserDataContainer& user_data() {
return m_user_data; return m_user_data;
} }

void set_output_vars(const VarNodeArray& vars) {
std::shared_ptr<OutputVarsUserData> ud =
std::make_shared<OutputVarsUserData>();
ud->set_output_vars(vars);
m_user_data.add_user_data(ud);
}

//! fetch the output vars previously stored via set_output_vars()
//! NOTE(review): dereferences the first user-data entry unconditionally —
//! assumes set_output_vars() was called before; confirm at call sites
const VarNodeArray& get_output_vars() const {
    auto ud_pair = m_user_data.get_user_data<OutputVarsUserData>();
    return (*ud_pair.first)->get_output_vars();
}
}; };






+ 7
- 0
src/core/include/megbrain/graph/cg.h View File

@@ -399,6 +399,12 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
//! force dynamic memory alloc for all vars //! force dynamic memory alloc for all vars
bool force_dynamic_alloc = false; bool force_dynamic_alloc = false;


/*!
* force dynamic memory alloc for output vars which are used as
* CallbackCaller input when call compile() function
*/
bool force_output_dynamic_alloc = false;

//! whether to perform var sanity check on first run //! whether to perform var sanity check on first run
bool var_sanity_check_first_run = true; bool var_sanity_check_first_run = true;


@@ -657,6 +663,7 @@ SymbolVar SymbolVar::insert_single_output_opr(Args &&...args) const {
std::make_unique<Node>(std::forward<Args>(args)...))->output(0); std::make_unique<Node>(std::forward<Args>(args)...))->output(0);
} }



} // namespace cg } // namespace cg
} // namespace mgb } // namespace mgb




+ 1
- 2
src/core/include/megbrain/graph/var_node.h View File

@@ -34,7 +34,7 @@ namespace static_infer {
class StaticInferManagerImpl; class StaticInferManagerImpl;
} }


class VarNode;
class VarDevMemDefragmenter; class VarDevMemDefragmenter;
class EagerEvalManager; class EagerEvalManager;


@@ -685,7 +685,6 @@ bool VarNode::contain_flag(Flag flag) const {
return static_cast<bool>(m_flag & flag); return static_cast<bool>(m_flag & flag);
} }


using VarNodeArray = mgb::SmallVector<VarNode*>;
using VarNodeSet = ThinHashSet<VarNode*>; using VarNodeSet = ThinHashSet<VarNode*>;


DType MemAllocPlan::dtype() const { DType MemAllocPlan::dtype() const {


+ 35
- 0
src/core/test/graph/misc.cpp View File

@@ -2287,4 +2287,39 @@ TEST(TestGraph, CallbackCaller) {
} }
} }


//! verify that force_output_dynamic_alloc lets callers read compiled graph
//! outputs straight from the dest vars' dev tensors, without callbacks
TEST(TestGraph, DynamicOutput) {
    using namespace opr;
    REQUIRE_GPU(1);
    auto cn0 = CompNode::load("gpu0");
    constexpr size_t C1 = 20, C2 = 20;
    constexpr size_t C = C1 + C2;
    HostTensorGenerator<> gen;
    auto host_x = gen({C}, cn0);
    auto graph = ComputingGraph::make();
    graph->options().force_output_dynamic_alloc = true;
    SymbolVar x = opr::Host2DeviceCopy::make(*graph, host_x);

    // split along axis 0 so one output is a forwarded sub-tensor
    auto parts = opr::Split::make(
            x, Split::Options::make_partition(x, 0, {C1, C2}));
    auto sum = opr::add(parts[1], parts[1]);

    HostTensorND expect_sum, expect_part0, got_sum, got_part0;

    // reference run: capture expected values through callbacks
    auto func_ref = graph->compile(
            {make_callback_copy(sum, expect_sum),
             make_callback_copy(parts[0], expect_part0)});
    func_ref->execute().wait();

    // callback-free run: outputs must remain valid after execution because
    // of the forced dynamic allocation
    auto func_dyn = graph->compile({{sum, nullptr}, {parts[0], nullptr}});
    auto&& dest_vars = func_dyn->get_output_vars();
    func_dyn->execute().wait();

    got_sum.copy_from(dest_vars[0]->dev_tensor()).sync();
    MGB_ASSERT_TENSOR_NEAR(expect_sum, got_sum, 1e-4);
    got_part0.copy_from(dest_vars[1]->dev_tensor()).sync();
    MGB_ASSERT_TENSOR_NEAR(expect_part0, got_part0, 1e-4);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

Loading…
Cancel
Save