Browse Source

fix(mge/imperative): fix op specializations and tuple hash

GitOrigin-RevId: e3df93fd5b
release-1.2
Megvii Engine Team 4 years ago
parent
commit
e9c036ccba
2 changed files with 34 additions and 21 deletions
  1. +19
    -14
      imperative/src/impl/ops/autogen.cpp
  2. +15
    -7
      imperative/src/impl/ops/specializations.cpp

+ 19
- 14
imperative/src/impl/ops/autogen.cpp View File

@@ -18,23 +18,28 @@ using namespace megdnn;
// FIXME: remove this when mgb::hash support tuple_hash
namespace mgb {
namespace {
template<typename T, size_t ...Ns>
auto tail(T t, std::index_sequence<Ns...>) {
return std::make_tuple(std::get<Ns+1>(t)...);

struct HashWrapper {
size_t hash;
constexpr operator size_t() {return hash;}

constexpr HashWrapper operator+(HashWrapper rhs) {
// NOTE: use a + b + c + d, not a + (b + (c + d)) !!!
return {hash * 20141203 + rhs.hash};
}
};

template <typename... Args>
constexpr size_t hash_many(const Args&... args) {
return (... + HashWrapper{mgb::hash(args)});
}

} // anonymous namespace

template<typename T, typename ...Args>
class HashTrait<std::tuple<T, Args...>> {
constexpr static size_t length = sizeof...(Args);
public:
struct HashTrait<std::tuple<T, Args...>> {
static size_t eval(const std::tuple<T, Args...> &t) {
const T& val = std::get<0>(t);
if constexpr (!length) {
return mgb::hash(val);
} else {
return mgb::hash_pair_combine(mgb::hash(val),
mgb::hash(tail(t, std::make_index_sequence<length - 1>{})));
}
return std::apply(hash_many<T, Args...>, t);
}
};
} // namespace mgb
@@ -43,4 +48,4 @@ namespace mgb::imperative {

#include "./opdef.cpp.inl"

} // namespace mgb::imperative
} // namespace mgb::imperative

+ 15
- 7
imperative/src/impl/ops/specializations.cpp View File

@@ -158,7 +158,13 @@ auto apply_on_var_node(
}
}

OP_TRAIT_REG(Reduce, Reduce)
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Reduce>();
return Reduce::make(node->param());
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // reduce
@@ -439,12 +445,13 @@ OP_TRAIT_REG(GaussianRNG, GaussianRNG)
}} // gaussian_rng

namespace { namespace roi_align {
auto apply_on_var_node(
VarNodeArray apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIAlign&>(def);
mgb_assert(inputs.size() == 2);
return opr::ROIAlign::make(inputs[0], inputs[1], op.param());
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr();
return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIAlign, ROIAlign)
.apply_on_var_node(apply_on_var_node)
@@ -496,12 +503,13 @@ OP_TRAIT_REG(Eye, Eye)
}} // eye

namespace { namespace roi_pooling {
auto apply_on_var_node(
VarNodeArray apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIPooling&>(def);
mgb_assert(inputs.size() == 3);
return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param());
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr();
return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIPooling, ROIPooling)
.apply_on_var_node(apply_on_var_node)
@@ -620,11 +628,11 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const SVD&>(def);
mgb_assert(inputs.size() == 1);
return opr::SVD::make(inputs[0], op.param());
return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output();
}
OP_TRAIT_REG(SVD, SVD)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // svd

} // namespace mgb::imperative
} // namespace mgb::imperative

Loading…
Cancel
Save