diff --git a/src/core/impl/graph/bases.cpp b/src/core/impl/graph/bases.cpp
index 71451086..446713fa 100644
--- a/src/core/impl/graph/bases.cpp
+++ b/src/core/impl/graph/bases.cpp
@@ -14,6 +14,8 @@
 
 using namespace mgb::cg;
 
+MGB_TYPEINFO_OBJ_IMPL(OutputVarsUserData);
+
 GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph):
     m_owner_graph{owner_graph}
 {
diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp
index 288169b7..be9331c3 100644
--- a/src/core/impl/graph/cg_impl.cpp
+++ b/src/core/impl/graph/cg_impl.cpp
@@ -563,6 +563,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
     std::unordered_map<CallbackCallerKey, CallbackCallerVal,
                        CallbackCallerKey::Hash>
             opr2vars;
+    using F = VarNode::Flag;
+    if (dest_vars[0]->owner_graph()->options().force_output_dynamic_alloc) {
+        for (auto&& i : dest_vars) {
+            if (!i->contain_flag(F::NO_SYS_MEM_ALLOC |
+                                 F::NO_SYS_STATIC_MEM_ALLOC)) {
+                mgb_assert(
+                        !i->contain_flag(
+                                F::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC),
+                        "cannot force graph output dynamic alloc with "
+                        "DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s",
+                        i->cname());
+                i->add_flag(F::NO_SYS_STATIC_MEM_ALLOC);
+            }
+            i->add_flag(F::NO_MEM_RECLAIM);
+        }
+    }
     for (size_t i = 0; i < out_spec.size(); ++i) {
         auto&& cb = out_spec[i].second;
         if (cb) {
@@ -641,13 +657,14 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
     init_opr_seq();
 #endif  // MGB_ENABLE_SUBLINEAR
 
-    return {std::move(extra_info), opr_seq};
+    return {std::move(extra_info), opr_seq, std::move(dest_vars)};
 }
 
 std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
         CompileState state) {
     auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
     comp_seq->extra_info = std::move(state.extra_info);
+    comp_seq->set_output_vars(state.dest_vars);
     auto opr_seq = state.opr_seq;
     auto&& cmpnt = components();
 
diff --git a/src/core/impl/graph/cg_impl.h b/src/core/impl/graph/cg_impl.h
index 9d6634a9..fb23ffcd 100644
--- a/src/core/impl/graph/cg_impl.h
+++ b/src/core/impl/graph/cg_impl.h
@@ -38,6 +38,7 @@ class ComputingGraphImpl final : public ComputingGraph {
         //! extra info that must be set in the ComputingSequence
         CompSeqExtraInfo extra_info;
         const OprNodeArray* opr_seq = nullptr;
+        VarNodeArray dest_vars;
     };
 
     struct CallbackCallerKey {
diff --git a/src/core/include/megbrain/graph/bases.h b/src/core/include/megbrain/graph/bases.h
index 9dcbf2c7..f8ee265c 100644
--- a/src/core/include/megbrain/graph/bases.h
+++ b/src/core/include/megbrain/graph/bases.h
@@ -67,9 +67,10 @@ namespace static_infer {
 };
 
 using GraphError = mgb::GraphError;
+class VarNode;
 class OperatorNodeBase;
 class ComputingGraph;
-
+using VarNodeArray = mgb::SmallVector<VarNode*>;
 /*!
  * \brief Base class for a node in the graph.
  *
@@ -102,6 +103,17 @@ class GraphNodeBase: public json::Serializable, public NonCopyableObj {
         }
 };
 
+class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
+    MGB_TYPEINFO_OBJ_DECL;
+
+private:
+    VarNodeArray m_output_vars;
+
+public:
+    void set_output_vars(VarNodeArray vars) { m_output_vars = std::move(vars); }
+    const VarNodeArray& get_output_vars() const { return m_output_vars; }
+};
+
 /*!
  * \brief an object that executes asynchronously
  */
@@ -165,6 +177,19 @@ class AsyncExecutable : public json::Serializable,
 
     UserDataContainer& user_data() {
         return m_user_data;
     }
+
+    void set_output_vars(const VarNodeArray& vars) {
+        std::shared_ptr<OutputVarsUserData> ud =
+                std::make_shared<OutputVarsUserData>();
+        ud->set_output_vars(vars);
+        m_user_data.add_user_data(ud);
+    }
+
+    const VarNodeArray& get_output_vars() const {
+        auto output_vars_pair =
+                m_user_data.get_user_data<OutputVarsUserData>();
+        return (*(output_vars_pair.first))->get_output_vars();
+    }
 };
 
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index 004db94f..72821a45 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -399,6 +399,12 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! force dynamic memory alloc for all vars
             bool force_dynamic_alloc = false;
 
+            /*!
+             * force dynamic memory alloc for output vars that are used as
+             * CallbackCaller inputs when compile() is called
+             */
+            bool force_output_dynamic_alloc = false;
+
             //! whether to perform var sanity check on first run
             bool var_sanity_check_first_run = true;
 
@@ -657,6 +663,7 @@ SymbolVar SymbolVar::insert_single_output_opr(Args &&...args) const {
             std::make_unique<Node>(std::forward<Args>(args)...))->output(0);
 }
 
+
 } // namespace cg
 } // namespace mgb
 
diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h
index ef7c58ec..e1031959 100644
--- a/src/core/include/megbrain/graph/var_node.h
+++ b/src/core/include/megbrain/graph/var_node.h
@@ -34,7 +34,7 @@ namespace static_infer {
 class StaticInferManagerImpl;
 }
 
-class VarNode;
+
 class VarDevMemDefragmenter;
 class EagerEvalManager;
 
@@ -685,7 +685,6 @@ bool VarNode::contain_flag(Flag flag) const {
     return static_cast<bool>(m_flag & flag);
 }
 
-using VarNodeArray = mgb::SmallVector<VarNode*>;
 using VarNodeSet = ThinHashSet<VarNode*>;
 
 DType MemAllocPlan::dtype() const {
diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp
index e6b722a2..602a36c6 100644
--- a/src/core/test/graph/misc.cpp
+++ b/src/core/test/graph/misc.cpp
@@ -2287,4 +2287,39 @@ TEST(TestGraph, CallbackCaller) {
     }
 }
 
+TEST(TestGraph, DynamicOutput) {
+    using namespace opr;
+    REQUIRE_GPU(1);
+    auto cn0 = CompNode::load("gpu0");
+    constexpr size_t C1 = 20, C2 = 20;
+    constexpr size_t C = C1 + C2;
+    HostTensorGenerator<> gen;
+    auto host_opr0 = gen({C}, cn0);
+    auto graph = ComputingGraph::make();
+    graph->options().force_output_dynamic_alloc = true;
+    SymbolVar opr0 = opr::Host2DeviceCopy::make(*graph, host_opr0);
+
+    auto spl_0 = opr::Split::make(
+            opr0, Split::Options::make_partition(opr0, 0, {C1, C2}));
+
+    auto sum = opr::add(spl_0[1], spl_0[1]);
+
+    HostTensorND expect_sum, expect_spl_0_0, result_sum, result_spl_0_0;
+
+    auto func1 = graph->compile({make_callback_copy(sum, expect_sum),
+                                 make_callback_copy(spl_0[0], expect_spl_0_0)});
+
+    func1->execute().wait();
+
+    auto func2 = graph->compile({{sum, nullptr}, {spl_0[0], nullptr}});
+    auto&& dest_vars = func2->get_output_vars();
+
+    func2->execute().wait();
+
+    result_sum.copy_from(dest_vars[0]->dev_tensor()).sync();
+    MGB_ASSERT_TENSOR_NEAR(expect_sum, result_sum, 1e-4);
+    result_spl_0_0.copy_from(dest_vars[1]->dev_tensor()).sync();
+    MGB_ASSERT_TENSOR_NEAR(expect_spl_0_0, result_spl_0_0, 1e-4);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
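
Usage sketch (not part of the patch): with force_output_dynamic_alloc set, a graph can be compiled with null callbacks and its outputs read back through the new AsyncExecutable::get_output_vars() accessor, as exercised by TestGraph.DynamicOutput above. Here `graph` and `out` are assumed to be a pre-built ComputingGraph and SymbolVar; `host_result` is a scratch tensor introduced only for illustration.

    // assumes: std::shared_ptr<ComputingGraph> graph; SymbolVar out;
    graph->options().force_output_dynamic_alloc = true;

    // a null callback still registers `out` as a graph output
    auto func = graph->compile({{out, nullptr}});

    // the compiled sequence now remembers its output vars
    auto&& dest_vars = func->get_output_vars();
    func->execute().wait();

    // outputs were forced to dynamic storage and flagged NO_MEM_RECLAIM,
    // so their dev_tensor() contents remain valid after execution
    HostTensorND host_result;
    host_result.copy_from(dest_vars[0]->dev_tensor()).sync();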