Browse Source

fix(mge/imperative): impl hashable for coll-comm

GitOrigin-RevId: 76ab16a89b
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
3bbfef3009
3 changed files with 51 additions and 2 deletions
  1. +40
    -0
      imperative/src/impl/ops/collective_comm.cpp
  2. +9
    -0
      imperative/src/include/megbrain/imperative/ops/collective_comm.h
  3. +2
    -2
      imperative/src/include/megbrain/imperative/ops/opr_attr.h

+ 40
- 0
imperative/src/impl/ops/collective_comm.cpp View File

@@ -15,6 +15,7 @@
#include "../op_trait.h"
#include "../proxy_graph_detail.h"
#include "megbrain/opr/mm_handler.h"
#include "megbrain/utils/hash.h"
#endif // MGB_ENABLE_OPR_MM

#include "megbrain/imperative/ops/collective_comm.h"
@@ -52,6 +53,45 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // anonymous namespace


bool CollectiveComm::is_same_st(const Hashable& another) const{
auto* comm_opr = another.try_cast_final<CollectiveComm>();
if(!comm_opr){
return false;
}
return as_tuple() == comm_opr->as_tuple();
}

size_t CollectiveComm::hash() const{
XXHash xxhash{};
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(nr_devices);
append(rank);
append(is_root);
append(local_grad);
append(addr);
append(port);
append(mode);
append(backend);
append(comp_node);
return xxhash.digest();
}

#else

bool CollectiveComm::is_same_st(const Hashable& another) const{
return OpDef::is_same_st(another);
}

size_t CollectiveComm::hash() const{
return OpDef::hash();
}

#endif // MGB_ENABLE_OPR_MM

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);


+ 9
- 0
imperative/src/include/megbrain/imperative/ops/collective_comm.h View File

@@ -52,6 +52,15 @@ public:
DType dtype;
std::string backend;
std::string comp_node;

size_t hash() const override;

bool is_same_st(const Hashable& another) const override;
auto as_tuple() const{
return std::tuple(key, nr_devices, rank, is_root,
local_grad, addr, port, mode, dtype,
backend, comp_node);
}
};

} // namespace imperative


+ 2
- 2
imperative/src/include/megbrain/imperative/ops/opr_attr.h View File

@@ -45,8 +45,8 @@ public:

std::string repr() const;

bool is_same_st(const Hashable& rhs) const;
size_t hash() const;
bool is_same_st(const Hashable& rhs) const override;
size_t hash() const override;
};

} // namespace imperative


Loading…
Cancel
Save