Browse Source

feat(imperative): impl hashable for SendRecv and add virtual input for Recv

GitOrigin-RevId: 5e8c27ac81
release-1.1
Megvii Engine Team 4 years ago
parent
commit
064c774f05
4 changed files with 89 additions and 1 deletions
  1. +38
    -1
      imperative/src/impl/ops/io_remote.cpp
  2. +14
    -0
      imperative/src/include/megbrain/imperative/ops/io_remote.h
  3. +26
    -0
      src/opr-mm/impl/io_remote.cpp
  4. +11
    -0
      src/opr-mm/include/megbrain/opr/io_remote.h

+ 38
- 1
imperative/src/impl/ops/io_remote.cpp View File

@@ -46,7 +46,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
ssprintf("%s:%d", recv.addr.data(), recv.port));
auto&& graph = inputs[0]->owner_graph();
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
recv.key, *graph, group_client, OperatorNodeConfig{recv.cn},
recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn},
recv.shape, recv.dtype));
}

@@ -60,6 +60,43 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv)
} // anonymous namespace
#endif // MGB_ENABLE_OPR_MM

// Structural equality for RemoteSend: `another` is cast to RemoteSend and the
// two ops are considered the same when their as_tuple() values compare equal.
// NOTE(review): presumably the caller guarantees that `another` really is a
// RemoteSend before cast_final<>() is reached -- confirm.
bool RemoteSend::is_same_st(const Hashable& another) const {
    const auto& rhs = another.cast_final<RemoteSend>();
    return as_tuple() == rhs.as_tuple();
}

size_t RemoteSend::hash() const{
XXHash xxhash;
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(addr);
append(port);
append(rank_to);
return xxhash.digest();
}

// Structural equality for RemoteRecv: `another` is cast to RemoteRecv and the
// two ops are considered the same when their as_tuple() values compare equal.
// NOTE(review): presumably the caller guarantees that `another` really is a
// RemoteRecv before cast_final<>() is reached -- confirm.
bool RemoteRecv::is_same_st(const Hashable& another) const {
    const auto& rhs = another.cast_final<RemoteRecv>();
    return as_tuple() == rhs.as_tuple();
}

size_t RemoteRecv::hash() const{
XXHash xxhash;
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(addr);
append(port);
append(rank_from);
append(cn.to_string());
append(dtype.handle());
append(shape.to_string());
return xxhash.digest();
}

// Per-class boilerplate for the two op definitions, emitted by a project macro
// that is defined elsewhere.
// NOTE(review): presumably this supplies the dynamic type information that
// cast_final<>() relies on -- confirm against the macro definition.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);



+ 14
- 0
imperative/src/include/megbrain/imperative/ops/io_remote.h View File

@@ -31,6 +31,13 @@ public:
std::string addr;
uint32_t port;
uint32_t rank_to;

// Hash over this op's fields (key, addr, port, rank_to); defined out of line.
size_t hash() const override;
// Equality against another Hashable, decided by comparing as_tuple() values;
// defined out of line.
bool is_same_st(const Hashable& another) const override;

// Copies of the fields that identify this op, packed for comparison:
// (key, addr, port, rank_to).
auto as_tuple() const {
    return std::make_tuple(key, addr, port, rank_to);
}
};

class RemoteRecv : public OpDefImplBase<RemoteRecv> {
@@ -55,6 +62,13 @@ public:
CompNode cn;
TensorShape shape;
DType dtype;

// Hash over this op's fields (key, addr, port, rank_from, comp node, dtype,
// shape); defined out of line.
size_t hash() const override;
// Equality against another Hashable, decided by comparing as_tuple() values;
// defined out of line.
bool is_same_st(const Hashable& another) const override;

// Copies of the fields that identify this op, packed for comparison:
// (key, addr, port, rank_from, cn, dtype, shape-as-string).
// The shape takes part through shape.to_string() rather than as a TensorShape.
auto as_tuple() const {
    return std::make_tuple(
            key, addr, port, rank_from, cn, dtype, shape.to_string());
}
};

} // namespace imperative


+ 26
- 0
src/opr-mm/impl/io_remote.cpp View File

@@ -151,6 +151,23 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
add_equivalence_component<ScalarHash<void*>>(this);
}

// Constructs a RemoteRecv operator that has `var` as an input.
//
// key          -- stored in m_key.
// var          -- registered as the operator's only input.
//                 NOTE(review): the commit title calls this a "virtual input";
//                 presumably it exists only to create a graph dependency and
//                 its value is not read -- confirm.
// graph        -- owner graph, forwarded to the base class.
// group_client -- stored in m_group_client.
// config       -- operator config, forwarded to the base class.
// shape, dtype -- stored in m_shape / m_dtype; dtype is also set on the output.
RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
                       std::shared_ptr<GroupClient> group_client,
                       const OperatorNodeConfig& config,
                       const TensorShape& shape, DType dtype)
        : Super(&graph, config, "remote_recv", {}),
          m_shape(shape),
          m_dtype(dtype) {
    m_group_client = group_client;
    m_key = key;

    add_input({var});

    // Single output carrying the received dtype, flagged NO_MEM_RECLAIM and
    // DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC.
    auto out_var = add_output(None);
    out_var->dtype(dtype)
            .add_flag(VarNode::Flag::NO_MEM_RECLAIM)
            .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);

    // The equivalence component is keyed on this object's own address.
    add_equivalence_component<ScalarHash<void*>>(this);
}

SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
@@ -160,6 +177,15 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
return opr->output(0);
}

// Factory for the RemoteRecv overload that takes an input var: builds the
// operator from var.node() and the remaining arguments, inserts it into
// `graph`, and returns output 0 of the operator that insert_opr() hands back.
SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
                           std::shared_ptr<GroupClient> group_client,
                           const OperatorNodeConfig& config,
                           const TensorShape& shape, DType dtype) {
    auto input_node = var.node();
    auto recv_opr = graph.insert_opr(std::make_unique<RemoteRecv>(
            key, input_node, graph, group_client, config, shape, dtype));
    return recv_opr->output(0);
}

void RemoteRecv::scn_do_execute() {
if (!m_init) {
auto&& comp_node = output(0)->comp_node();


+ 11
- 0
src/opr-mm/include/megbrain/opr/io_remote.h View File

@@ -77,12 +77,23 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);

// Constructor overload that additionally takes `var`, which the implementation
// registers as an input of the operator. The remaining parameters match the
// overload without `var`.
RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);

// Factory overload without an input var; returns the result as a SymbolVar.
static SymbolVar make(
const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);

// Factory overload taking an input `var`: the implementation constructs a
// RemoteRecv from var.node(), inserts it into `graph`, and returns output 0 of
// the inserted operator.
static SymbolVar make(
const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);

private:
const TensorShape m_shape;
const DType m_dtype;


Loading…
Cancel
Save