GitOrigin-RevId: 5d47ed263f
release-1.1
@@ -483,13 +483,13 @@ void init_graph_rt(py::module m) {
             },
             py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
-    auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) {
+    auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false, bool prefer_host_value = false) {
         SymbolVarArray sinputs;
         for (auto i : inputs) {
             sinputs.emplace_back(i);
         }
         static_assert(!std::is_reference<decltype(callback)>::value);
-        opr::OutputCallback::Param param{std::move(callback), borrow};
+        opr::OutputCallback::Param param{std::move(callback), borrow, prefer_host_value};
         auto output = opr::OutputCallback::make(std::move(param), sinputs);
         return output.node();
     };
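Note: the binding lambda above just threads the new `prefer_host_value` flag into `opr::OutputCallback::Param` via aggregate initialization. A minimal sketch of a call site inside this function, assuming `var` is some `cg::VarNode*` already in the graph (the callback body and `var` are hypothetical):

    // Hypothetical use of the output_callback lambda defined above.
    auto cb = [](DeviceTensorND tensor) { /* consume the tensor */ };
    std::vector<cg::VarNode*> vars{var};
    cg::VarNode* out = output_callback(
            std::move(cb), std::move(vars),
            /*borrow=*/true, /*prefer_host_value=*/true);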
@@ -519,7 +519,7 @@ void init_graph_rt(py::module m) {
             hv_with_event.second->record();
             p->set(std::move(hv_with_event));
         };
-        return output_callback(std::move(f), std::move(inputs), true);
+        return output_callback(std::move(f), std::move(inputs), true, true);
     });
     m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
@@ -144,13 +144,24 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const {
     prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP);
     SmallVector<NodeProp::DepType> dep_types(input().size(),
                                              NodeProp::DepType::DEV_COMP_ORDER);
-    dep_types[0] = NodeProp::DepType::DEV_VALUE;
+    using IT = cg::static_infer::InferType;
+    auto host_value_avail = [&]() -> bool {
+        auto inp = input(0);
+        auto it = owner_graph()->static_infer_manager().get_infer_type(inp).value;
+        return it & (IT::CONST | IT::RT_STATIC | IT::MISSING_INP);
+    };
+    m_use_host_value = m_param.prefer_host_value && host_value_avail();
+    dep_types[0] = m_use_host_value ? NodeProp::DepType::HOST_VALUE : NodeProp::DepType::DEV_VALUE;
     prop->reset_dep_type(input(), dep_types);
     return prop;
 }

 void OutputCallback::scn_do_execute() {
-    m_param.callback(input(0)->dev_tensor());
+    if (m_use_host_value) {
+        m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0)));
+    } else {
+        m_param.callback(input(0)->dev_tensor());
+    }
 }

 cg::OperatorNodeBase* OutputCallback::shallow_copy(
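Note: `do_make_node_prop` downgrades the dependency on `input(0)` from `DEV_VALUE` to `HOST_VALUE` when the flag is set and the static infer manager reports the value as inferable (any of the `CONST`, `RT_STATIC`, or `MISSING_INP` bits), and `scn_do_execute` then reads `infer_value(input(0))` instead of the device tensor. The decision is cached in `m_use_host_value`, which has to be `mutable` because `do_make_node_prop()` is `const`. A stripped-down, self-contained sketch of that cache-in-a-const-hook pattern (all names illustrative, not MegBrain API):

    // Illustrative pattern only: a const query hook caches a decision
    // that the non-const execution hook later branches on.
    class CachedDecision {
        mutable bool m_use_host_value = false;
    public:
        void query() const {
            // e.g. infer-type bits say a host value is available
            m_use_host_value = true;
        }
        void execute() {
            if (m_use_host_value) { /* read inferred host value */ }
            else                  { /* read device tensor */ }
        }
    };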
@@ -60,7 +60,8 @@ public:
     using callback_t = thin_function<void(DeviceTensorND)>;
     struct Param {
         callback_t callback;
-        bool borrow = false;
+        bool borrow = false;             // do not obtain shared ownership on DeviceTensorND
+        bool prefer_host_value = false;  // use host value when possible
     };
     OutputCallback(Param param,
                    const VarNodeArray& inputs,
@@ -81,6 +82,7 @@ protected:
     NodeProp* do_make_node_prop() const override;

 private:
     Param m_param;
+    mutable bool m_use_host_value;
 };
 MGB_DEFINE_OPR_CLASS(NopCallback, cg::OperatorNodeBase) // {
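Note: both `Param` fields are plain aggregate members with default initializers, so existing call sites that brace-initialize only the callback (or callback plus `borrow`) keep compiling, with `prefer_host_value` defaulting to `false`. Sketch of the three initialization forms (`cb` is a placeholder callback):

    opr::OutputCallback::Param p1{cb};              // borrow=false, prefer_host_value=false
    opr::OutputCallback::Param p2{cb, true};        // borrow=true
    opr::OutputCallback::Param p3{cb, true, true};  // borrow + prefer host value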
@@ -13,6 +13,7 @@
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/utility.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/test/helper.h"

 using namespace mgb;
@@ -50,6 +51,27 @@ TEST(TestOprUtility, OutputCallback) {
     MGB_ASSERT_TENSOR_EQ(hy, *hx);
 }

+TEST(TestOprUtility, OutputCallbackPreferHost) {
+    HostTensorGenerator<> gen;
+    auto hx = gen({2, 3});
+    auto graph = ComputingGraph::make();
+    auto x = opr::Host2DeviceCopy::make(*graph, hx);
+    x = opr::GetVarShape::make(x);
+    HostTensorND hy;
+    auto callback = [&hy](DeviceTensorND dv) { hy.copy_from(dv); };
+    opr::OutputCallback::Param param{callback};
+    param.prefer_host_value = true;
+    auto dummy = opr::OutputCallback::make(param, x);
+    auto y = opr::VirtualDep::make({x, dummy});
+    ComputingGraph::OutputSpec outspec{{y, [](DeviceTensorND&) {}}};
+    auto func = graph->compile(outspec);
+    func->execute();
+    ASSERT_TRUE(hy.comp_node() == CompNode::default_cpu());
+    ASSERT_EQ(hy.ptr<int>()[0], 2);
+    ASSERT_EQ(hy.ptr<int>()[1], 3);
+}
+
 TEST(TestOprUtility, NopCallback) {
     HostTensorGenerator<> gen;
     auto hx = gen({2, 3});
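Note on the new test: `GetVarShape` (hence the added `tensor_manip.h` include) produces a value the static infer manager can compute, so with `prefer_host_value = true` the callback receives the inferred tensor on `CompNode::default_cpu()` rather than a device tensor, which is exactly what `ASSERT_TRUE(hy.comp_node() == CompNode::default_cpu())` checks; `VirtualDep` keeps the callback operator in the compiled graph even though it is not an output. For contrast, a hypothetical variant with the flag left at its default (not part of the commit):

    // Without the flag, the callback sees input(0)'s device tensor,
    // so hy.comp_node() would match x's comp node instead.
    opr::OutputCallback::Param param_dev{callback};  // prefer_host_value = false
    auto dummy_dev = opr::OutputCallback::make(param_dev, x);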