|
|
@@ -13,6 +13,7 @@ |
|
|
|
#include "megbrain/opr/io.h" |
|
|
|
#include "megbrain/opr/basic_arith.h" |
|
|
|
#include "megbrain/opr/utility.h" |
|
|
|
#include "megbrain/opr/tensor_manip.h" |
|
|
|
#include "megbrain/test/helper.h" |
|
|
|
|
|
|
|
using namespace mgb; |
|
|
@@ -50,6 +51,27 @@ TEST(TestOprUtility, OutputCallback) { |
|
|
|
MGB_ASSERT_TENSOR_EQ(hy, *hx); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprUtility, OutputCallbackPreferHost) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
auto hx = gen({2, 3}); |
|
|
|
auto graph = ComputingGraph::make(); |
|
|
|
auto x = opr::Host2DeviceCopy::make(*graph, hx); |
|
|
|
x = opr::GetVarShape::make(x); |
|
|
|
HostTensorND hy; |
|
|
|
auto callback = [&hy](DeviceTensorND dv) {hy.copy_from(dv);}; |
|
|
|
opr::OutputCallback::Param param{callback}; |
|
|
|
param.prefer_host_value = true; |
|
|
|
auto dummy = opr::OutputCallback::make(param, x); |
|
|
|
auto y = opr::VirtualDep::make({x, dummy}); |
|
|
|
|
|
|
|
ComputingGraph::OutputSpec outspec{{y, [](DeviceTensorND&){}}}; |
|
|
|
auto func = graph->compile(outspec); |
|
|
|
func->execute(); |
|
|
|
ASSERT_TRUE(hy.comp_node() == CompNode::default_cpu()); |
|
|
|
ASSERT_EQ(hy.ptr<int>()[0], 2); |
|
|
|
ASSERT_EQ(hy.ptr<int>()[1], 3); |
|
|
|
} |
|
|
|
|
|
|
|
TEST(TestOprUtility, NopCallback) { |
|
|
|
HostTensorGenerator<> gen; |
|
|
|
auto hx = gen({2, 3}); |
|
|
|