@@ -14,6 +14,8 @@
using namespace mgb::cg;
MGB_TYPEINFO_OBJ_IMPL(OutputVarsUserData);
GraphNodeBase::GraphNodeBase(ComputingGraph *owner_graph):
    m_owner_graph{owner_graph}
{
@@ -563,6 +563,22 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
    std::unordered_map<CallbackCallerKey, CallbackCallerVal,
                       CallbackCallerKey::Hash>
            opr2vars;
    using F = VarNode::Flag;
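    // when force_output_dynamic_alloc is set, disable static memory
    // allocation for each output var (unless it already bypasses system
    // allocation) and keep its storage from being reclaimed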
    if (dest_vars[0]->owner_graph()->options().force_output_dynamic_alloc) {
        for (auto&& i : dest_vars) {
            if (!i->contain_flag(F::NO_SYS_MEM_ALLOC |
                                 F::NO_SYS_STATIC_MEM_ALLOC)) {
                mgb_assert(
                        !i->contain_flag(
                                F::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC),
                        "Can not force graph output dynamic alloc with "
                        "DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s",
                        i->cname());
                i->add_flag(F::NO_SYS_STATIC_MEM_ALLOC);
            }
            i->add_flag(F::NO_MEM_RECLAIM);
        }
    }
    for (size_t i = 0; i < out_spec.size(); ++i) {
        auto&& cb = out_spec[i].second;
        if (cb) {
@@ -641,13 +657,14 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
    init_opr_seq();
#endif // MGB_ENABLE_SUBLINEAR
    return {std::move(extra_info), opr_seq};
    return {std::move(extra_info), opr_seq, std::move(dest_vars)};
}
std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
        CompileState state) {
    auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
    comp_seq->extra_info = std::move(state.extra_info);
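    // record the output vars on the computing sequence so callers can query
    // them later via AsyncExecutable::get_output_vars()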
    comp_seq->set_output_vars(state.dest_vars);
    auto opr_seq = state.opr_seq;
    auto&& cmpnt = components();
@@ -38,6 +38,7 @@ class ComputingGraphImpl final : public ComputingGraph {
        //! extra info that must be set in the ComputingSequence
        CompSeqExtraInfo extra_info;
        const OprNodeArray* opr_seq = nullptr;
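        //! output vars in the out_spec, to be recorded on the ComputingSequence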
        VarNodeArray dest_vars;
    };
    struct CallbackCallerKey {
@@ -67,9 +67,10 @@ namespace static_infer {
};
using GraphError = mgb::GraphError;
class VarNode;
class OperatorNodeBase;
class ComputingGraph;
using VarNodeArray = mgb::SmallVector<VarNode*>;
/*!
 * \brief Base class for a node in the graph.
 *
@@ -102,6 +103,17 @@ class GraphNodeBase: public json::Serializable, public NonCopyableObj {
    }
};
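//! user data entry that records the output vars of a compiled AsyncExecutable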
class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;
private:
    VarNodeArray m_output_vars;
public:
    void set_output_vars(VarNodeArray vars) { m_output_vars = std::move(vars); }
    const VarNodeArray& get_output_vars() const { return m_output_vars; }
};
/*!
 * \brief an object that executes asynchronously
 */
@@ -165,6 +177,19 @@ class AsyncExecutable : public json::Serializable,
    UserDataContainer& user_data() {
        return m_user_data;
    }
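    //! store the output vars of this executable as user data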
    void set_output_vars(const VarNodeArray& vars) {
        std::shared_ptr<OutputVarsUserData> ud =
                std::make_shared<OutputVarsUserData>();
        ud->set_output_vars(vars);
        m_user_data.add_user_data(ud);
    }
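    //! get the output vars previously stored by set_output_vars()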
    const VarNodeArray& get_output_vars() const {
        auto output_vars_pair =
                m_user_data.get_user_data<OutputVarsUserData>();
        return (*(output_vars_pair.first))->get_output_vars();
    }
};
@@ -399,6 +399,12 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
        //! force dynamic memory alloc for all vars
        bool force_dynamic_alloc = false;
        /*!
         * force dynamic memory alloc for output vars that are used as
         * CallbackCaller inputs when compile() is called
         */
        bool force_output_dynamic_alloc = false;
        //! whether to perform var sanity check on first run
        bool var_sanity_check_first_run = true;
@@ -657,6 +663,7 @@ SymbolVar SymbolVar::insert_single_output_opr(Args &&...args) const {
            std::make_unique<Node>(std::forward<Args>(args)...))->output(0);
}
} // namespace cg
} // namespace mgb
@@ -34,7 +34,7 @@ namespace static_infer {
class StaticInferManagerImpl;
}
class VarNode;
class VarDevMemDefragmenter;
class EagerEvalManager;
@@ -685,7 +685,6 @@ bool VarNode::contain_flag(Flag flag) const {
    return static_cast<bool>(m_flag & flag);
}
using VarNodeArray = mgb::SmallVector<VarNode*>;
using VarNodeSet = ThinHashSet<VarNode*>;
DType MemAllocPlan::dtype() const {
@@ -2287,4 +2287,39 @@ TEST(TestGraph, CallbackCaller) {
    }
}
TEST(TestGraph, DynamicOutput) {
    using namespace opr;
    REQUIRE_GPU(1);
    auto cn0 = CompNode::load("gpu0");
    constexpr size_t C1 = 20, C2 = 20;
    constexpr size_t C = C1 + C2;
    HostTensorGenerator<> gen;
    auto host_opr0 = gen({C}, cn0);
    auto graph = ComputingGraph::make();
    graph->options().force_output_dynamic_alloc = true;
    SymbolVar opr0 = opr::Host2DeviceCopy::make(*graph, host_opr0);
    auto spl_0 = opr::Split::make(
            opr0, Split::Options::make_partition(opr0, 0, {C1, C2}));
    auto sum = opr::add(spl_0[1], spl_0[1]);
    HostTensorND expect_sum, expect_spl_0_0, result_sum, result_spl_0_0;
    auto func1 = graph->compile({make_callback_copy(sum, expect_sum),
                                 make_callback_copy(spl_0[0], expect_spl_0_0)});
    func1->execute().wait();
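    // compile again without callbacks; outputs are read back directly from
    // the vars returned by get_output_vars()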
    auto func2 = graph->compile({{sum, nullptr}, {spl_0[0], nullptr}});
    auto&& dest_vars = func2->get_output_vars();
    func2->execute().wait();
    result_sum.copy_from(dest_vars[0]->dev_tensor()).sync();
    MGB_ASSERT_TENSOR_NEAR(expect_sum, result_sum, 1e-4);
    result_spl_0_0.copy_from(dest_vars[1]->dev_tensor()).sync();
    MGB_ASSERT_TENSOR_NEAR(expect_spl_0_0, result_spl_0_0, 1e-4);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}