Browse Source

test(mgb/opr-mm): add io_remote test

GitOrigin-RevId: c47b6156fe
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
cd8ab9e3a6
1 changed files with 25 additions and 0 deletions
  1. +25
    -0
      src/opr-mm/test/io_remote.cpp

+ 25
- 0
src/opr-mm/test/io_remote.cpp View File

@@ -59,7 +59,31 @@ const auto recv_tag = RemoteIOBase::Type::RECV;
} // anonymous namespace

TEST(TestOprIORemote, Identity) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");

HostTensorGenerator<> gen;
auto host_x = gen({28, 28});
HostTensorND host_y;

auto client = std::make_shared<MockGroupClient>();
auto graph = ComputingGraph::make();

auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
client, {cn1}, host_x->shape(),
host_x->dtype());

auto func = graph->compile({{xr, {}}, make_callback_copy(y, host_y)});

func->execute();

MGB_ASSERT_TENSOR_EQ(*host_x, host_y);
}

TEST(TestOprIORemote, IdentityMultiThread) {
auto cns = load_multiple_xpus(2);
HostTensorGenerator<> gen;
auto host_x = gen({2, 3}, cns[1]);
@@ -67,6 +91,7 @@ TEST(TestOprIORemote, Identity) {
auto client = std::make_shared<MockGroupClient>();

auto sender = [&]() {
auto graph = ComputingGraph::make();
sys::set_thread_name("sender");
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);


Loading…
Cancel
Save