GitOrigin-RevId: e1dac3c919
tags/v1.8.0
@@ -47,8 +47,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||||
auto group_client = std::make_shared<opr::GroupClientProxy>( | auto group_client = std::make_shared<opr::GroupClientProxy>( | ||||
ssprintf("%s:%d", recv.addr.data(), recv.port)); | ssprintf("%s:%d", recv.addr.data(), recv.port)); | ||||
auto&& graph = inputs[0]->owner_graph(); | auto&& graph = inputs[0]->owner_graph(); | ||||
mgb_assert(!recv.shape.empty()); | |||||
TensorShape shape; | |||||
for (auto&& dim : recv.shape) { | |||||
shape[shape.ndim++] = dim; | |||||
} | |||||
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | ||||
recv.key, inputs[0], *graph, group_client, config, recv.shape, recv.dtype, | |||||
recv.key, inputs[0], *graph, group_client, config, shape, recv.dtype, | |||||
recv.backend)); | recv.backend)); | ||||
} | } | ||||
@@ -42,7 +42,7 @@ TEST(TestImperative, IORemote) { | |||||
auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) { | auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) { | ||||
auto def = imperative::RemoteRecv::make( | auto def = imperative::RemoteRecv::make( | ||||
"io_remote_test", server_addr, port, 0, CompNode::load("gpu1"), | "io_remote_test", server_addr, port, 0, CompNode::load("gpu1"), | ||||
TensorShape{vector_size}, dtype::Float32(), "nccl"); | |||||
std::vector<int32_t>{(int32_t)vector_size}, dtype::Float32(), "nccl"); | |||||
auto inp = Tensor::make(*hnd); | auto inp = Tensor::make(*hnd); | ||||
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); | ||||
HostTensorND host_v; | HostTensorND host_v; | ||||
@@ -284,7 +284,7 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> { | |||||
MgbUI32Attr:$port, | MgbUI32Attr:$port, | ||||
MgbUI32Attr:$rank_from, | MgbUI32Attr:$rank_from, | ||||
MgbCompNodeAttr:$cn, | MgbCompNodeAttr:$cn, | ||||
MgbTensorShapeAttr:$shape, | |||||
MgbArrayAttr<MgbI32Attr>:$shape, | |||||
MgbDTypeAttr:$dtype, | MgbDTypeAttr:$dtype, | ||||
MgbStringAttr:$backend | MgbStringAttr:$backend | ||||
); | ); | ||||