From 09d2b7c3fe25d72f4dd70452ed671f4b49cf4b7a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 9 Jun 2020 19:16:57 +0800 Subject: [PATCH] fix(core): make the semantics of instance id clear and correct GitOrigin-RevId: 2232195c50e02482c8737ff70ab9d1c709bda6ee --- src/core/impl/graph/operator_node.cpp | 4 +-- src/core/impl/graph/seq_sublinear_memory.cpp | 11 +++++--- src/core/include/megbrain/graph/operator_node.h | 35 +++++++++++++++-------- src/core/test/graph/misc.cpp | 37 +++++++++++++++++++++++++ src/gopt/impl/misc.cpp | 2 +- src/gopt/test/misc.cpp | 2 +- src/jit/impl/internal_graph.cpp | 4 +-- 7 files changed, 74 insertions(+), 21 deletions(-) diff --git a/src/core/impl/graph/operator_node.cpp b/src/core/impl/graph/operator_node.cpp index 1872e2fa..3be33fcb 100644 --- a/src/core/impl/graph/operator_node.cpp +++ b/src/core/impl/graph/operator_node.cpp @@ -485,14 +485,14 @@ OperatorNodeConfig& OperatorNodeConfig::comp_node_arr( size_t OperatorNodeConfig::hash() const { return hash_pair_combine( - hash_pair_combine(mgb::hash(m_instance_id), mgb::hash(m_comp_node)), + hash_pair_combine(m_instance_id_hashed, mgb::hash(m_comp_node)), mgb::hash(m_output_dtype.handle())); } bool OperatorNodeConfig::is_same_st(const Hashable &rhs_) const { auto &&rhs = static_cast(rhs_); return m_comp_node == rhs.m_comp_node && - m_instance_id == rhs.m_instance_id && + m_instance_id_hashed == rhs.m_instance_id_hashed && m_output_dtype == rhs.m_output_dtype; } diff --git a/src/core/impl/graph/seq_sublinear_memory.cpp b/src/core/impl/graph/seq_sublinear_memory.cpp index 16fc4d2f..59750d80 100644 --- a/src/core/impl/graph/seq_sublinear_memory.cpp +++ b/src/core/impl/graph/seq_sublinear_memory.cpp @@ -1225,14 +1225,17 @@ bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) { OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( OperatorNodeBase* opr, bool recomp) { auto config = opr->config(); - // set operator instance id to bybass the shallow copy's cache if + // update operator instance id to bybass the shallow copy's cache if // it's a dup-opr-copying due to discarding. - // Don't set instance id(nullptr) if it's a recomp-opr-copying, because: + // Don't update instance id by `this` pointer if it's a recomp-opr-copying + // because: // 0) recomp-opr would be copied iff its input vars is changed // 1) some pair of recomp-opr and dup-opr have the same inputs, params // and config, we use instance id to differentiate them. - config.name(opr->name() + (recomp ? ":recomp" : ":dup")) - .instance_id(recomp ? nullptr : this); + config.name(opr->name() + (recomp ? ":recomp" : ":dup")); + if (!recomp) { + config.update_instance_id(this); + } // Note: if all outputs of op were placed on the same comp_node, since its // stream maybe changed during seq_comp_node_opt, output's comp_node has diff --git a/src/core/include/megbrain/graph/operator_node.h b/src/core/include/megbrain/graph/operator_node.h index e8c8c43e..73fa06b7 100644 --- a/src/core/include/megbrain/graph/operator_node.h +++ b/src/core/include/megbrain/graph/operator_node.h @@ -70,24 +70,36 @@ class OperatorNodeConfig final: public Hashable { } /*! - * \brief set instance id + * \brief update instance ID * - * Instance id is used to differentiate multiple instances of the same - * operator (with same inputs, params and config), so the deduplication - * system can be bypassed. + * Instance ID is a hashed value used to differentiate multiple + * instances of the same operator (with same inputs, params and + * config), so the deduplication system can be bypassed. * - * Currently only used for sublinear memory optimization. + * This method always updates underlying instance_id. */ - OperatorNodeConfig& instance_id(const void *id) { - m_instance_id = id; + template + OperatorNodeConfig& update_instance_id(const T& p) { + static_assert(std::is_pointer::value, + "update_instance_id can only accept a pointer"); + m_instance_id_hashed = hash_pair_combine( + m_instance_id_hashed, mgb::hash(p)); return *this; } /*! - * \brief get current instance ID + * \brief reset instance ID to the initial value */ - const void* instance_id() const { - return m_instance_id; + OperatorNodeConfig& reset_instance_id() { + m_instance_id_hashed = sm_initial_instance_id; + return *this; + } + + /*! + * \brief get current hashed instance ID + */ + size_t instance_id() const { + return m_instance_id_hashed; } /*! @@ -133,9 +145,10 @@ class OperatorNodeConfig final: public Hashable { bool is_same_st(const Hashable &rhs) const override; private: + static constexpr size_t sm_initial_instance_id = 1333331; Maybe m_name; CompNodeArray m_comp_node; - const void *m_instance_id = nullptr; + size_t m_instance_id_hashed = sm_initial_instance_id; DType m_output_dtype; }; diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp index 54f0567c..e1ce41fd 100644 --- a/src/core/test/graph/misc.cpp +++ b/src/core/test/graph/misc.cpp @@ -1777,4 +1777,41 @@ TEST(TestGraph, In2OutOpStreamPropagate) { } } +TEST(TestGraph, OperatorNodeConfigInstanceID) { + OperatorNodeConfig config0, config1; + void *p0 = &config0, *p1 = &config1; + { // set and reset + ASSERT_EQ(config0.instance_id(), config1.instance_id()); + config0.update_instance_id(p0); + ASSERT_NE(config0.instance_id(), config1.instance_id()); + config0.reset_instance_id(); + ASSERT_EQ(config0.instance_id(), config1.instance_id()); + } + { // set to the same pointer + config0.reset_instance_id(); + config0.update_instance_id(p1); + config1.reset_instance_id(); + config1.update_instance_id(p1); + ASSERT_EQ(config0.instance_id(), config1.instance_id()); + } + { // check update semantics + config0.reset_instance_id(); + config0.update_instance_id(p0); + config1.reset_instance_id(); + config1.update_instance_id(p1); + ASSERT_NE(config0.instance_id(), config1.instance_id()); + config0.update_instance_id(p1); + ASSERT_NE(config0.instance_id(), config1.instance_id()); + } + { // set in different order + config0.reset_instance_id(); + config0.update_instance_id(p1); + config0.update_instance_id(p0); + config1.reset_instance_id(); + config1.update_instance_id(p0); + config1.update_instance_id(p1); + ASSERT_NE(config0.instance_id(), config1.instance_id()); + } +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/misc.cpp b/src/gopt/impl/misc.cpp index f5dce790..136b0227 100644 --- a/src/gopt/impl/misc.cpp +++ b/src/gopt/impl/misc.cpp @@ -361,7 +361,7 @@ void RecompTypeCvtPass::apply(OptState& opt) const { size_t prev_step = iter->second; if (step - prev_step > m_threshold) { OperatorNodeConfig config = opr->config(); - config.instance_id(opr); + config.update_instance_id(opr); opt.call_with_opr(typecvt, [&]{ auto new_typecvt = opr::TypeCvt::make( diff --git a/src/gopt/test/misc.cpp b/src/gopt/test/misc.cpp index ba0637b5..9f7ef088 100644 --- a/src/gopt/test/misc.cpp +++ b/src/gopt/test/misc.cpp @@ -261,7 +261,7 @@ TEST_PASS(RecompTypeCvtPass, Basic) { } auto for_pass = f + x_fp32; OperatorNodeConfig config = x_fp32.node()->owner_opr()->config(); - config.instance_id(for_pass.node()->owner_opr()); + config.update_instance_id(for_pass.node()->owner_opr()); auto expected = f + opr::TypeCvt::make(sin_x, dtype::Float32(), config); diff --git a/src/jit/impl/internal_graph.cpp b/src/jit/impl/internal_graph.cpp index 4c257170..e491b340 100644 --- a/src/jit/impl/internal_graph.cpp +++ b/src/jit/impl/internal_graph.cpp @@ -92,8 +92,8 @@ VarNode* InternalGraphGenerator::replace_graph_by_placeholder() { auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, const VarNodeArray& inputs) { OperatorNodeConfig config = opr->config(); - // remove instance_id. - config.instance_id(nullptr); + // reset instance_id. + config.reset_instance_id(); if (auto imm = gopt::try_cast_as_op(opr)) { HostTensorND hval{cpu_default}; hval.copy_from(imm->value()).sync();