Add a single-graph Identity test that sends from gpu0 and receives on gpu1 (requires
2 GPUs); the previous Identity body becomes IdentityMultiThread, whose sender lambda
now creates its own ComputingGraph.

diff --git a/src/opr-mm/test/io_remote.cpp b/src/opr-mm/test/io_remote.cpp
index c2ba7ad4..a25384fd 100644
--- a/src/opr-mm/test/io_remote.cpp
+++ b/src/opr-mm/test/io_remote.cpp
@@ -59,7 +59,31 @@ const auto recv_tag = RemoteIOBase::Type::RECV;
 } // anonymous namespace
 
 TEST(TestOprIORemote, Identity) {
+    REQUIRE_GPU(2);
+    auto cn0 = CompNode::load("gpu0");
+    auto cn1 = CompNode::load("gpu1");
+
+    HostTensorGenerator<> gen;
+    auto host_x = gen({28, 28});
+    HostTensorND host_y;
+
+    auto client = std::make_shared();
     auto graph = ComputingGraph::make();
+
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
+    auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
+    auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
+                                   client, {cn1}, host_x->shape(),
+                                   host_x->dtype());
+
+    auto func = graph->compile({{xr, {}}, make_callback_copy(y, host_y)});
+
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(*host_x, host_y);
+}
+
+TEST(TestOprIORemote, IdentityMultiThread) {
     auto cns = load_multiple_xpus(2);
     HostTensorGenerator<> gen;
     auto host_x = gen({2, 3}, cns[1]);
@@ -67,6 +91,7 @@ TEST(TestOprIORemote, Identity) {
 
     auto client = std::make_shared();
     auto sender = [&]() {
+        auto graph = ComputingGraph::make();
         sys::set_thread_name("sender");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
              xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);