Browse Source

fix(core): make the semantics of instance id clear and correct

GitOrigin-RevId: 2232195c50
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
09d2b7c3fe
7 changed files with 74 additions and 21 deletions
  1. +2
    -2
      src/core/impl/graph/operator_node.cpp
  2. +7
    -4
      src/core/impl/graph/seq_sublinear_memory.cpp
  3. +24
    -11
      src/core/include/megbrain/graph/operator_node.h
  4. +37
    -0
      src/core/test/graph/misc.cpp
  5. +1
    -1
      src/gopt/impl/misc.cpp
  6. +1
    -1
      src/gopt/test/misc.cpp
  7. +2
    -2
      src/jit/impl/internal_graph.cpp

+ 2
- 2
src/core/impl/graph/operator_node.cpp View File

@@ -485,14 +485,14 @@ OperatorNodeConfig& OperatorNodeConfig::comp_node_arr(


size_t OperatorNodeConfig::hash() const { size_t OperatorNodeConfig::hash() const {
return hash_pair_combine( return hash_pair_combine(
hash_pair_combine(mgb::hash(m_instance_id), mgb::hash(m_comp_node)),
hash_pair_combine(m_instance_id_hashed, mgb::hash(m_comp_node)),
mgb::hash(m_output_dtype.handle())); mgb::hash(m_output_dtype.handle()));
} }


bool OperatorNodeConfig::is_same_st(const Hashable &rhs_) const { bool OperatorNodeConfig::is_same_st(const Hashable &rhs_) const {
auto &&rhs = static_cast<const OperatorNodeConfig&>(rhs_); auto &&rhs = static_cast<const OperatorNodeConfig&>(rhs_);
return m_comp_node == rhs.m_comp_node && return m_comp_node == rhs.m_comp_node &&
m_instance_id == rhs.m_instance_id &&
m_instance_id_hashed == rhs.m_instance_id_hashed &&
m_output_dtype == rhs.m_output_dtype; m_output_dtype == rhs.m_output_dtype;
} }




+ 7
- 4
src/core/impl/graph/seq_sublinear_memory.cpp View File

@@ -1225,14 +1225,17 @@ bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) {
OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs(
OperatorNodeBase* opr, bool recomp) { OperatorNodeBase* opr, bool recomp) {
auto config = opr->config(); auto config = opr->config();
// set operator instance id to bybass the shallow copy's cache if
// update operator instance id to bybass the shallow copy's cache if
// it's a dup-opr-copying due to discarding. // it's a dup-opr-copying due to discarding.
// Don't set instance id(nullptr) if it's a recomp-opr-copying, because:
// Don't update instance id by `this` pointer if it's a recomp-opr-copying
// because:
// 0) recomp-opr would be copied iff its input vars is changed // 0) recomp-opr would be copied iff its input vars is changed
// 1) some pair of recomp-opr and dup-opr have the same inputs, params // 1) some pair of recomp-opr and dup-opr have the same inputs, params
// and config, we use instance id to differentiate them. // and config, we use instance id to differentiate them.
config.name(opr->name() + (recomp ? ":recomp" : ":dup"))
.instance_id(recomp ? nullptr : this);
config.name(opr->name() + (recomp ? ":recomp" : ":dup"));
if (!recomp) {
config.update_instance_id(this);
}


// Note: if all outputs of op were placed on the same comp_node, since its // Note: if all outputs of op were placed on the same comp_node, since its
// stream maybe changed during seq_comp_node_opt, output's comp_node has // stream maybe changed during seq_comp_node_opt, output's comp_node has


+ 24
- 11
src/core/include/megbrain/graph/operator_node.h View File

@@ -70,24 +70,36 @@ class OperatorNodeConfig final: public Hashable {
} }


/*! /*!
* \brief set instance id
* \brief update instance ID
* *
* Instance id is used to differentiate multiple instances of the same
* operator (with same inputs, params and config), so the deduplication
* system can be bypassed.
* Instance ID is a hashed value used to differentiate multiple
* instances of the same operator (with same inputs, params and
* config), so the deduplication system can be bypassed.
* *
* Currently only used for sublinear memory optimization.
* This method always updates underlying instance_id.
*/ */
OperatorNodeConfig& instance_id(const void *id) {
m_instance_id = id;
template<typename T>
OperatorNodeConfig& update_instance_id(const T& p) {
static_assert(std::is_pointer<T>::value,
"update_instance_id can only accept a pointer");
m_instance_id_hashed = hash_pair_combine(
m_instance_id_hashed, mgb::hash(p));
return *this; return *this;
} }


/*! /*!
* \brief get current instance ID
* \brief reset instance ID to the initial value
*/ */
const void* instance_id() const {
return m_instance_id;
OperatorNodeConfig& reset_instance_id() {
m_instance_id_hashed = sm_initial_instance_id;
return *this;
}

/*!
* \brief get current hashed instance ID
*/
size_t instance_id() const {
return m_instance_id_hashed;
} }


/*! /*!
@@ -133,9 +145,10 @@ class OperatorNodeConfig final: public Hashable {
bool is_same_st(const Hashable &rhs) const override; bool is_same_st(const Hashable &rhs) const override;


private: private:
static constexpr size_t sm_initial_instance_id = 1333331;
Maybe<std::string> m_name; Maybe<std::string> m_name;
CompNodeArray m_comp_node; CompNodeArray m_comp_node;
const void *m_instance_id = nullptr;
size_t m_instance_id_hashed = sm_initial_instance_id;
DType m_output_dtype; DType m_output_dtype;
}; };




+ 37
- 0
src/core/test/graph/misc.cpp View File

@@ -1777,4 +1777,41 @@ TEST(TestGraph, In2OutOpStreamPropagate) {
} }
} }


TEST(TestGraph, OperatorNodeConfigInstanceID) {
OperatorNodeConfig config0, config1;
void *p0 = &config0, *p1 = &config1;
{ // set and reset
ASSERT_EQ(config0.instance_id(), config1.instance_id());
config0.update_instance_id(p0);
ASSERT_NE(config0.instance_id(), config1.instance_id());
config0.reset_instance_id();
ASSERT_EQ(config0.instance_id(), config1.instance_id());
}
{ // set to the same pointer
config0.reset_instance_id();
config0.update_instance_id(p1);
config1.reset_instance_id();
config1.update_instance_id(p1);
ASSERT_EQ(config0.instance_id(), config1.instance_id());
}
{ // check update semantics
config0.reset_instance_id();
config0.update_instance_id(p0);
config1.reset_instance_id();
config1.update_instance_id(p1);
ASSERT_NE(config0.instance_id(), config1.instance_id());
config0.update_instance_id(p1);
ASSERT_NE(config0.instance_id(), config1.instance_id());
}
{ // set in different order
config0.reset_instance_id();
config0.update_instance_id(p1);
config0.update_instance_id(p0);
config1.reset_instance_id();
config1.update_instance_id(p0);
config1.update_instance_id(p1);
ASSERT_NE(config0.instance_id(), config1.instance_id());
}
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 1
src/gopt/impl/misc.cpp View File

@@ -361,7 +361,7 @@ void RecompTypeCvtPass::apply(OptState& opt) const {
size_t prev_step = iter->second; size_t prev_step = iter->second;
if (step - prev_step > m_threshold) { if (step - prev_step > m_threshold) {
OperatorNodeConfig config = opr->config(); OperatorNodeConfig config = opr->config();
config.instance_id(opr);
config.update_instance_id(opr);
opt.call_with_opr(typecvt, [&]{ opt.call_with_opr(typecvt, [&]{
auto new_typecvt = auto new_typecvt =
opr::TypeCvt::make( opr::TypeCvt::make(


+ 1
- 1
src/gopt/test/misc.cpp View File

@@ -261,7 +261,7 @@ TEST_PASS(RecompTypeCvtPass, Basic) {
} }
auto for_pass = f + x_fp32; auto for_pass = f + x_fp32;
OperatorNodeConfig config = x_fp32.node()->owner_opr()->config(); OperatorNodeConfig config = x_fp32.node()->owner_opr()->config();
config.instance_id(for_pass.node()->owner_opr());
config.update_instance_id(for_pass.node()->owner_opr());
auto expected = f + opr::TypeCvt::make(sin_x, dtype::Float32(), auto expected = f + opr::TypeCvt::make(sin_x, dtype::Float32(),
config); config);




+ 2
- 2
src/jit/impl/internal_graph.cpp View File

@@ -92,8 +92,8 @@ VarNode* InternalGraphGenerator::replace_graph_by_placeholder() {
auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr,
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
OperatorNodeConfig config = opr->config(); OperatorNodeConfig config = opr->config();
// remove instance_id.
config.instance_id(nullptr);
// reset instance_id.
config.reset_instance_id();
if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) { if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
HostTensorND hval{cpu_default}; HostTensorND hval{cpu_default};
hval.copy_from(imm->value()).sync(); hval.copy_from(imm->value()).sync();


Loading…
Cancel
Save