@@ -14,6 +14,8 @@
using namespace mgb::cg;

MGB_TYPEINFO_OBJ_IMPL(OutputVarsUserData);

GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph):
        m_owner_graph{owner_graph}
{

@@ -563,6 +563,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
    std::unordered_map<CallbackCallerKey, CallbackCallerVal,
                       CallbackCallerKey::Hash>
            opr2vars;
    using F = VarNode::Flag;
    if (dest_vars[0]->owner_graph()->options().force_output_dynamic_alloc) {
        // require dynamic allocation for every output var and forbid
        // reclaiming its memory, so results stay readable after execution
        for (auto&& i : dest_vars) {
            if (!i->contain_flag(F::NO_SYS_MEM_ALLOC |
                                 F::NO_SYS_STATIC_MEM_ALLOC)) {
                mgb_assert(
                        !i->contain_flag(
                                F::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC),
                        "cannot force graph output dynamic alloc with "
                        "DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s",
                        i->cname());
                i->add_flag(F::NO_SYS_STATIC_MEM_ALLOC);
            }
            i->add_flag(F::NO_MEM_RECLAIM);
        }
    }
    for (size_t i = 0; i < out_spec.size(); ++i) {
        auto&& cb = out_spec[i].second;
        if (cb) {

@@ -641,13 +657,14 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
    init_opr_seq();
#endif  // MGB_ENABLE_SUBLINEAR
    return {std::move(extra_info), opr_seq};
    return {std::move(extra_info), opr_seq, std::move(dest_vars)};
}

std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
        CompileState state) {
    auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
    comp_seq->extra_info = std::move(state.extra_info);
    comp_seq->set_output_vars(state.dest_vars);
    auto opr_seq = state.opr_seq;
    auto&& cmpnt = components();

@@ -38,6 +38,7 @@ class ComputingGraphImpl final : public ComputingGraph {
        //! extra info that must be set in the ComputingSequence
        CompSeqExtraInfo extra_info;
        const OprNodeArray* opr_seq = nullptr;
        VarNodeArray dest_vars;
    };

    struct CallbackCallerKey {

@@ -67,9 +67,10 @@ namespace static_infer {
};

using GraphError = mgb::GraphError;
class VarNode;
class OperatorNodeBase;
class ComputingGraph;
using VarNodeArray = mgb::SmallVector<VarNode*>;

/*!
 * \brief Base class for a node in the graph.
 *
@@ -102,6 +103,17 @@ class GraphNodeBase: public json::Serializable, public NonCopyableObj {
    }
};

class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

private:
    VarNodeArray m_output_vars;

public:
    void set_output_vars(VarNodeArray vars) { m_output_vars = std::move(vars); }

    const VarNodeArray& get_output_vars() const { return m_output_vars; }
};

/*!
 * \brief an object that executes asynchronously
 */
@@ -165,6 +177,19 @@ class AsyncExecutable : public json::Serializable,
    UserDataContainer& user_data() {
        return m_user_data;
    }

    void set_output_vars(const VarNodeArray& vars) {
        std::shared_ptr<OutputVarsUserData> ud =
                std::make_shared<OutputVarsUserData>();
        ud->set_output_vars(vars);
        m_user_data.add_user_data(ud);
    }

    const VarNodeArray& get_output_vars() const {
        auto output_vars_pair =
                m_user_data.get_user_data<OutputVarsUserData>();
        return (*(output_vars_pair.first))->get_output_vars();
    }
};
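
A minimal caller-side sketch of the new accessor (the graph, the output vars x and y, and the host tensor are assumptions for illustration, not part of this diff; compile_commit() stores the vars via set_output_vars(), so callers only need get_output_vars()):

    // compile with empty callbacks; results are read back through the var nodes
    auto func = graph->compile({{x, nullptr}, {y, nullptr}});
    func->execute().wait();
    auto&& out_vars = func->get_output_vars();  // VarNodeArray, in output-spec order
    HostTensorND host_x;
    host_x.copy_from(out_vars[0]->dev_tensor()).sync();

Reading dev_tensor() after execution only works when the graph keeps its outputs alive, which is what the force_output_dynamic_alloc option below provides.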
@@ -399,6 +399,12 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
        //! force dynamic memory alloc for all vars
        bool force_dynamic_alloc = false;

        /*!
         * force dynamic memory alloc for output vars that are used as
         * CallbackCaller inputs when compile() is called
         */
        bool force_output_dynamic_alloc = false;

        //! whether to perform var sanity check on first run
        bool var_sanity_check_first_run = true;
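
A hedged sketch of how the option is intended to be switched on (graph construction and the output vars are assumed; with the flag set, compile_prepare() above marks the output vars NO_SYS_STATIC_MEM_ALLOC and NO_MEM_RECLAIM):

    auto graph = ComputingGraph::make();
    graph->options().force_output_dynamic_alloc = true;
    // output-spec entries may now use null callbacks; results stay readable
    // through AsyncExecutable::get_output_vars() after execute().wait()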
@@ -657,6 +663,7 @@ SymbolVar SymbolVar::insert_single_output_opr(Args &&...args) const {
            std::make_unique<Node>(std::forward<Args>(args)...))->output(0);
}

} // namespace cg
} // namespace mgb

@@ -34,7 +34,7 @@ namespace static_infer {
class StaticInferManagerImpl;
}

class VarNode;
class VarDevMemDefragmenter;
class EagerEvalManager;

@@ -685,7 +685,6 @@ bool VarNode::contain_flag(Flag flag) const {
    return static_cast<bool>(m_flag & flag);
}

using VarNodeArray = mgb::SmallVector<VarNode*>;
using VarNodeSet = ThinHashSet<VarNode*>;

DType MemAllocPlan::dtype() const {

@@ -2287,4 +2287,39 @@ TEST(TestGraph, CallbackCaller) {
    }
}

TEST(TestGraph, DynamicOutput) {
    using namespace opr;
    REQUIRE_GPU(1);
    auto cn0 = CompNode::load("gpu0");
    constexpr size_t C1 = 20, C2 = 20;
    constexpr size_t C = C1 + C2;
    HostTensorGenerator<> gen;
    auto host_opr0 = gen({C}, cn0);
    auto graph = ComputingGraph::make();
    graph->options().force_output_dynamic_alloc = true;
    SymbolVar opr0 = opr::Host2DeviceCopy::make(*graph, host_opr0);
    auto spl_0 = opr::Split::make(
            opr0, Split::Options::make_partition(opr0, 0, {C1, C2}));
    auto sum = opr::add(spl_0[1], spl_0[1]);
    HostTensorND expect_sum, expect_spl_0_0, result_sum, result_spl_0_0;
    auto func1 = graph->compile({make_callback_copy(sum, expect_sum),
                                 make_callback_copy(spl_0[0], expect_spl_0_0)});
    func1->execute().wait();
    auto func2 = graph->compile({{sum, nullptr}, {spl_0[0], nullptr}});
    auto&& dest_vars = func2->get_output_vars();
    func2->execute().wait();
    result_sum.copy_from(dest_vars[0]->dev_tensor()).sync();
    MGB_ASSERT_TENSOR_NEAR(expect_sum, result_sum, 1e-4);
    result_spl_0_0.copy_from(dest_vars[1]->dev_tensor()).sync();
    MGB_ASSERT_TENSOR_NEAR(expect_spl_0_0, result_spl_0_0, 1e-4);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}