|
|
@@ -15,6 +15,7 @@ |
|
|
|
#include "../op_trait.h" |
|
|
|
#include "../proxy_graph_detail.h" |
|
|
|
#include "megbrain/opr/mm_handler.h" |
|
|
|
#include "megbrain/utils/hash.h" |
|
|
|
#endif // MGB_ENABLE_OPR_MM |
|
|
|
|
|
|
|
#include "megbrain/imperative/ops/collective_comm.h" |
|
|
@@ -52,6 +53,45 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) |
|
|
|
.apply_on_var_node(apply_on_var_node) |
|
|
|
.fallback(); |
|
|
|
} // anonymous namespace |
|
|
|
|
|
|
|
|
|
|
|
bool CollectiveComm::is_same_st(const Hashable& another) const{ |
|
|
|
auto* comm_opr = another.try_cast_final<CollectiveComm>(); |
|
|
|
if(!comm_opr){ |
|
|
|
return false; |
|
|
|
} |
|
|
|
return as_tuple() == comm_opr->as_tuple(); |
|
|
|
} |
|
|
|
|
|
|
|
size_t CollectiveComm::hash() const{ |
|
|
|
XXHash xxhash{}; |
|
|
|
auto append = [&xxhash](auto field){ |
|
|
|
auto hash_val = HashTrait<decltype(field)>::eval(field); |
|
|
|
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val)); |
|
|
|
}; |
|
|
|
append(key); |
|
|
|
append(nr_devices); |
|
|
|
append(rank); |
|
|
|
append(is_root); |
|
|
|
append(local_grad); |
|
|
|
append(addr); |
|
|
|
append(port); |
|
|
|
append(mode); |
|
|
|
append(backend); |
|
|
|
append(comp_node); |
|
|
|
return xxhash.digest(); |
|
|
|
} |
|
|
|
|
|
|
|
#else |
|
|
|
|
|
|
|
bool CollectiveComm::is_same_st(const Hashable& another) const{ |
|
|
|
return OpDef::is_same_st(another); |
|
|
|
} |
|
|
|
|
|
|
|
size_t CollectiveComm::hash() const{ |
|
|
|
return OpDef::hash(); |
|
|
|
} |
|
|
|
|
|
|
|
#endif // MGB_ENABLE_OPR_MM |
|
|
|
|
|
|
|
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); |
|
|
|