GitOrigin-RevId: 2232195c50
tags/v0.5.0
@@ -485,14 +485,14 @@ OperatorNodeConfig& OperatorNodeConfig::comp_node_arr( | |||||
size_t OperatorNodeConfig::hash() const { | size_t OperatorNodeConfig::hash() const { | ||||
return hash_pair_combine( | return hash_pair_combine( | ||||
hash_pair_combine(mgb::hash(m_instance_id), mgb::hash(m_comp_node)), | |||||
hash_pair_combine(m_instance_id_hashed, mgb::hash(m_comp_node)), | |||||
mgb::hash(m_output_dtype.handle())); | mgb::hash(m_output_dtype.handle())); | ||||
} | } | ||||
bool OperatorNodeConfig::is_same_st(const Hashable &rhs_) const { | bool OperatorNodeConfig::is_same_st(const Hashable &rhs_) const { | ||||
auto &&rhs = static_cast<const OperatorNodeConfig&>(rhs_); | auto &&rhs = static_cast<const OperatorNodeConfig&>(rhs_); | ||||
return m_comp_node == rhs.m_comp_node && | return m_comp_node == rhs.m_comp_node && | ||||
m_instance_id == rhs.m_instance_id && | |||||
m_instance_id_hashed == rhs.m_instance_id_hashed && | |||||
m_output_dtype == rhs.m_output_dtype; | m_output_dtype == rhs.m_output_dtype; | ||||
} | } | ||||
@@ -1225,14 +1225,17 @@ bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) { | |||||
OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( | OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( | ||||
OperatorNodeBase* opr, bool recomp) { | OperatorNodeBase* opr, bool recomp) { | ||||
auto config = opr->config(); | auto config = opr->config(); | ||||
// set operator instance id to bybass the shallow copy's cache if | |||||
// update operator instance id to bypass the shallow copy's cache if | |||||
// it's a dup-opr-copying due to discarding. | // it's a dup-opr-copying due to discarding. | ||||
// Don't set instance id(nullptr) if it's a recomp-opr-copying, because: | |||||
// Don't update instance id by `this` pointer if it's a recomp-opr-copying | |||||
// because: | |||||
// 0) recomp-opr would be copied iff its input vars are changed | // 0) recomp-opr would be copied iff its input vars are changed | ||||
// 1) some pair of recomp-opr and dup-opr have the same inputs, params | // 1) some pair of recomp-opr and dup-opr have the same inputs, params | ||||
// and config, we use instance id to differentiate them. | // and config, we use instance id to differentiate them. | ||||
config.name(opr->name() + (recomp ? ":recomp" : ":dup")) | |||||
.instance_id(recomp ? nullptr : this); | |||||
config.name(opr->name() + (recomp ? ":recomp" : ":dup")); | |||||
if (!recomp) { | |||||
config.update_instance_id(this); | |||||
} | |||||
// Note: if all outputs of op were placed on the same comp_node, since its | // Note: if all outputs of op were placed on the same comp_node, since its | ||||
// stream may be changed during seq_comp_node_opt, output's comp_node has | // stream may be changed during seq_comp_node_opt, output's comp_node has | ||||
@@ -70,24 +70,36 @@ class OperatorNodeConfig final: public Hashable { | |||||
} | } | ||||
/*! | /*! | ||||
* \brief set instance id | |||||
* \brief update instance ID | |||||
* | * | ||||
* Instance id is used to differentiate multiple instances of the same | |||||
* operator (with same inputs, params and config), so the deduplication | |||||
* system can be bypassed. | |||||
* Instance ID is a hashed value used to differentiate multiple | |||||
* instances of the same operator (with same inputs, params and | |||||
* config), so the deduplication system can be bypassed. | |||||
* | * | ||||
* Currently only used for sublinear memory optimization. | |||||
* This method always updates underlying instance_id. | |||||
*/ | */ | ||||
OperatorNodeConfig& instance_id(const void *id) { | |||||
m_instance_id = id; | |||||
template<typename T> | |||||
OperatorNodeConfig& update_instance_id(const T& p) { | |||||
static_assert(std::is_pointer<T>::value, | |||||
"update_instance_id can only accept a pointer"); | |||||
m_instance_id_hashed = hash_pair_combine( | |||||
m_instance_id_hashed, mgb::hash(p)); | |||||
return *this; | return *this; | ||||
} | } | ||||
/*! | /*! | ||||
* \brief get current instance ID | |||||
* \brief reset instance ID to the initial value | |||||
*/ | */ | ||||
const void* instance_id() const { | |||||
return m_instance_id; | |||||
OperatorNodeConfig& reset_instance_id() { | |||||
m_instance_id_hashed = sm_initial_instance_id; | |||||
return *this; | |||||
} | |||||
/*! | |||||
* \brief get current hashed instance ID | |||||
*/ | |||||
size_t instance_id() const { | |||||
return m_instance_id_hashed; | |||||
} | } | ||||
/*! | /*! | ||||
@@ -133,9 +145,10 @@ class OperatorNodeConfig final: public Hashable { | |||||
bool is_same_st(const Hashable &rhs) const override; | bool is_same_st(const Hashable &rhs) const override; | ||||
private: | private: | ||||
static constexpr size_t sm_initial_instance_id = 1333331; | |||||
Maybe<std::string> m_name; | Maybe<std::string> m_name; | ||||
CompNodeArray m_comp_node; | CompNodeArray m_comp_node; | ||||
const void *m_instance_id = nullptr; | |||||
size_t m_instance_id_hashed = sm_initial_instance_id; | |||||
DType m_output_dtype; | DType m_output_dtype; | ||||
}; | }; | ||||
@@ -1777,4 +1777,41 @@ TEST(TestGraph, In2OutOpStreamPropagate) { | |||||
} | } | ||||
} | } | ||||
TEST(TestGraph, OperatorNodeConfigInstanceID) { | |||||
OperatorNodeConfig config0, config1; | |||||
void *p0 = &config0, *p1 = &config1; | |||||
{ // set and reset | |||||
ASSERT_EQ(config0.instance_id(), config1.instance_id()); | |||||
config0.update_instance_id(p0); | |||||
ASSERT_NE(config0.instance_id(), config1.instance_id()); | |||||
config0.reset_instance_id(); | |||||
ASSERT_EQ(config0.instance_id(), config1.instance_id()); | |||||
} | |||||
{ // set to the same pointer | |||||
config0.reset_instance_id(); | |||||
config0.update_instance_id(p1); | |||||
config1.reset_instance_id(); | |||||
config1.update_instance_id(p1); | |||||
ASSERT_EQ(config0.instance_id(), config1.instance_id()); | |||||
} | |||||
{ // check update semantics | |||||
config0.reset_instance_id(); | |||||
config0.update_instance_id(p0); | |||||
config1.reset_instance_id(); | |||||
config1.update_instance_id(p1); | |||||
ASSERT_NE(config0.instance_id(), config1.instance_id()); | |||||
config0.update_instance_id(p1); | |||||
ASSERT_NE(config0.instance_id(), config1.instance_id()); | |||||
} | |||||
{ // set in different order | |||||
config0.reset_instance_id(); | |||||
config0.update_instance_id(p1); | |||||
config0.update_instance_id(p0); | |||||
config1.reset_instance_id(); | |||||
config1.update_instance_id(p0); | |||||
config1.update_instance_id(p1); | |||||
ASSERT_NE(config0.instance_id(), config1.instance_id()); | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -361,7 +361,7 @@ void RecompTypeCvtPass::apply(OptState& opt) const { | |||||
size_t prev_step = iter->second; | size_t prev_step = iter->second; | ||||
if (step - prev_step > m_threshold) { | if (step - prev_step > m_threshold) { | ||||
OperatorNodeConfig config = opr->config(); | OperatorNodeConfig config = opr->config(); | ||||
config.instance_id(opr); | |||||
config.update_instance_id(opr); | |||||
opt.call_with_opr(typecvt, [&]{ | opt.call_with_opr(typecvt, [&]{ | ||||
auto new_typecvt = | auto new_typecvt = | ||||
opr::TypeCvt::make( | opr::TypeCvt::make( | ||||
@@ -261,7 +261,7 @@ TEST_PASS(RecompTypeCvtPass, Basic) { | |||||
} | } | ||||
auto for_pass = f + x_fp32; | auto for_pass = f + x_fp32; | ||||
OperatorNodeConfig config = x_fp32.node()->owner_opr()->config(); | OperatorNodeConfig config = x_fp32.node()->owner_opr()->config(); | ||||
config.instance_id(for_pass.node()->owner_opr()); | |||||
config.update_instance_id(for_pass.node()->owner_opr()); | |||||
auto expected = f + opr::TypeCvt::make(sin_x, dtype::Float32(), | auto expected = f + opr::TypeCvt::make(sin_x, dtype::Float32(), | ||||
config); | config); | ||||
@@ -92,8 +92,8 @@ VarNode* InternalGraphGenerator::replace_graph_by_placeholder() { | |||||
auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, | auto igraph_copy_opr_shallow = [cpu_default](OperatorNodeBase* opr, | ||||
const VarNodeArray& inputs) { | const VarNodeArray& inputs) { | ||||
OperatorNodeConfig config = opr->config(); | OperatorNodeConfig config = opr->config(); | ||||
// remove instance_id. | |||||
config.instance_id(nullptr); | |||||
// reset instance_id. | |||||
config.reset_instance_id(); | |||||
if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) { | if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) { | ||||
HostTensorND hval{cpu_default}; | HostTensorND hval{cpu_default}; | ||||
hval.copy_from(imm->value()).sync(); | hval.copy_from(imm->value()).sync(); | ||||