GitOrigin-RevId: f666f6d700
@@ -74,7 +74,7 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
         auto &&ins = m_varmap.insert({out0[i], {true, nullptr}});
         mgb_assert(ins.second || ins.first->second.first,
-                "opr output already replaced");
+                   "opr output already replaced");
         // handle repeated call on the same opr
         ins.first->second.second = out1[i];
         on_var_replaced(out0[i], out1[i], nullptr);
@@ -771,7 +771,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
 
 /* ================ ConstVarPropogateBase ================ */
-ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
+ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
         OperatorNodeBase *opr) {
     using ProfFlag = OperatorNodeBase::NodeProp::Flag;
     auto &&info = m_oprinfo[opr];
@@ -834,7 +834,6 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
 #endif
         info.max_size = max_input_size;
         info.is_const = true;
-        on_midconst_opr(opr, max_input_size);
     }
     return make_ret();
 }
@@ -442,50 +442,6 @@ void ParamRedistributePass::apply(OptState &state) const {
 
 /* ================ ParamFusePass ================ */
-class ParamFusePass::ConstVarPropogateWithSizeCheck final:
-        public ConstVarPropogateBase
-{
-    public:
-        //! rewrite a var; reader == nullptr means needed by endpoint
-        using VarRewriter = std::function<
-            void(VarNode *var, OperatorNodeBase *reader)>;
-
-        ConstVarPropogateWithSizeCheck(
-                const ParamFusePass &pf, OptState &opt_state,
-                const VarRewriter &rewriter):
-            ConstVarPropogateBase{ConstVarType::IMMUTABLE_AND_PARAM},
-            m_owner{pf}, m_opt_state{opt_state}, m_rewriter{rewriter}
-        {
-        }
-
-    private:
-        const ParamFusePass &m_owner;
-        OptState &m_opt_state;
-        VarRewriter m_rewriter;
-
-        void on_midconst_opr(
-                OperatorNodeBase *opr, size_t max_src_size) override {
-            for (auto var: opr->output()) {
-                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
-                    continue;
-
-                auto osize = var_mem_size(var);
-                if (osize >= max_src_size &&
-                        osize - max_src_size > m_owner.m_param_grow_limit) {
-                    return;
-                }
-
-                // const oprs should be evaluated when output is used by another
-                // non-const opr or output is needed by the user
-                if (m_opt_state.graph().endpoint_contain(var)) {
-                    m_rewriter(var, nullptr);
-                }
-            }
-        }
-};
-
 /*!
  * \brief get name for new param
  */
@@ -565,9 +521,15 @@ const char* ParamFusePass::name() const {
 void ParamFusePass::apply(OptState &state) const {
     auto rewriter = state.graph().make_rewriter();
     auto cg = state.graph().comp_graph();
+
+    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
+    state.graph().iter([&cvprop](OperatorNodeBase *opr) {
+        cvprop.add_opr(opr);
+    });
+
     ThinHashSet<VarNode*> processed_var;
     VarNamer var_namer;
 
     // reader: null if used as endvar
     auto replace_single_var = [&](VarNode *var, OperatorNodeBase *reader) {
         if (!processed_var.insert(var).second)
@@ -619,9 +581,8 @@ void ParamFusePass::apply(OptState &state) const {
         rewriter.replace_var(var, new_var.node(), log.c_str());
     };
 
-    ConstVarPropogateWithSizeCheck cvprop{*this, state, replace_single_var};
-    auto on_opr = [&](OperatorNodeBase *opr) {
-        auto add_ret = cvprop.add_opr(opr);
+    auto replace_opr = [&](OperatorNodeBase* opr) {
+        auto add_ret = cvprop.opr_rst(opr);
         if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
             for (auto i: opr->input()) {
                 if (cvprop.is_midconst(i)) {
@@ -631,9 +592,33 @@ void ParamFusePass::apply(OptState &state) const {
             }
         }
         rewriter.auto_replace_outputs(opr);
+
+        //! we should deal with midconst var after auto_replace_outputs, as
+        //! on_midconst_opr will replace the endpoint output which may cause
+        //! double replace.
+        if (add_ret.all_const_inp) {
+            for (auto var : opr->output()) {
+                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
+                    continue;
+
+                auto osize = ConstVarPropogate::var_mem_size(var);
+                if (osize >= cvprop.max_size(opr) &&
+                    osize - cvprop.max_size(opr) > m_param_grow_limit) {
+                    return;
+                }
+
+                // const oprs should be evaluated when output is used by another
+                // non-const opr or output is needed by the user
+                if (state.graph().endpoint_contain(var)) {
+                    replace_single_var(var, nullptr);
+                }
+            }
+        }
     };
-    state.graph().iter(on_opr);
+
+    state.graph().iter(replace_opr);
     rewriter.apply_inplace();
 }
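The guard kept inside the rewrite lambda above, which compares var_mem_size(var) against cvprop.max_size(opr) plus m_param_grow_limit, is what keeps constant folding from inflating parameter storage: a mid-const output is only materialized when it does not exceed the largest const input by more than the configured limit. Below is a small self-contained illustration of that comparison with made-up byte counts; the constant names are placeholders, not MegBrain identifiers.

    #include <cstddef>
    #include <cstdio>

    // made-up sizes, in bytes
    constexpr std::size_t max_const_input = 4 * 1024;         // largest const input of the opr
    constexpr std::size_t fused_output    = 4u * 1024 * 1024; // size the folded output would have
    constexpr std::size_t grow_limit      = 256 * 1024;       // plays the role of m_param_grow_limit

    int main() {
        // same shape as the predicate in ParamFusePass::apply: skip folding when
        // the output would outgrow the largest const input by more than the limit
        bool skip = fused_output >= max_const_input &&
                    fused_output - max_const_input > grow_limit;
        std::printf("%s\n", skip ? "keep the opr, do not fold" : "fold into a new param");
        return 0;
    }

With these numbers the guard rejects the fold; shrink fused_output below max_const_input + grow_limit and it goes through.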
@@ -490,28 +490,17 @@ namespace gopt {
      * Usually you would want to use ConstVarPropogate, and this base class
      * exists to avoid virtual dtor while allowing polymorphism.
      */
-    class ConstVarPropogateBase {
-        protected:
-            ~ConstVarPropogateBase() = default;
-
-            //! memory usage of a var
-            static size_t var_mem_size(VarNode *var) {
-                return var->dtype().size(var->shape().total_nr_elems());
-            }
-
-            //! called after a const but non-source opr is visited
-            virtual void on_midconst_opr(
-                    OperatorNodeBase *opr, size_t max_src_size) {
-                MGB_MARK_USED_VAR(opr);
-                MGB_MARK_USED_VAR(max_src_size);
-            }
-
+    class ConstVarPropogate{
         public:
-            explicit ConstVarPropogateBase(ConstVarType const_var_type):
+            explicit ConstVarPropogate(ConstVarType const_var_type):
                 m_const_var_type{const_var_type}
             {
             }
+
+            ConstVarPropogate() = default;
+            ~ConstVarPropogate() = default;
+
             //! note that both attrs would be false if opr is impure or it is
             //! not allowed to be replaced
             struct AddOprResult {
@@ -527,12 +516,19 @@ namespace gopt {
             AddOprResult add_opr(OperatorNodeBase *opr);
 
+            const AddOprResult& opr_rst(OperatorNodeBase *opr) const {
+                return m_oprinfo.at(opr).result;
+            }
+
             bool is_const(OperatorNodeBase *opr) const {
                 return m_oprinfo.at(opr).is_const;
             }
             bool is_const(VarNode *var) const {
                 return is_const(var->owner_opr());
             }
 
+            size_t max_size(OperatorNodeBase *opr) const {
+                return m_oprinfo.at(opr).max_size;
+            }
+
             //! whether a var is produced by non-source const opr
             bool is_midconst(OperatorNodeBase *opr) const {
@@ -543,6 +539,11 @@ namespace gopt {
                 return is_midconst(var->owner_opr());
             }
 
+            //! memory usage of a var
+            static size_t var_mem_size(VarNode *var) {
+                return var->dtype().size(var->shape().total_nr_elems());
+            }
+
         private:
             struct OprInfo {
                 bool processed = false, is_const = false;
@@ -556,11 +557,6 @@ namespace gopt {
     };
 
-    class ConstVarPropogate final: public ConstVarPropogateBase {
-        public:
-            using ConstVarPropogateBase::ConstVarPropogateBase;
-    };
-
 } // namespace gopt
 } // namespace mgb
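After this header change, ConstVarPropogate is a plain value type with no virtual hook: a pass feeds every operator to add_opr() once, in visiting order, and later only reads the cached answers through opr_rst(), is_midconst(), max_size() and var_mem_size(). The sketch below is a minimal standalone model of that two-phase pattern using a toy Node type; it is not MegBrain code and all names in it are illustrative.

    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    // toy stand-ins for OperatorNodeBase and the per-opr info table
    struct Node {
        bool source_const = false;       // e.g. an immutable param
        std::vector<const Node*> inputs;
    };

    class ConstPropagate {
        struct Info { bool is_const = false; };
        std::unordered_map<const Node*, Info> m_info;

    public:
        // phase 1: visit nodes in topological order and cache const-ness
        void add_node(const Node* n) {
            bool c = n->inputs.empty() ? n->source_const : true;
            for (auto i : n->inputs)
                c = c && m_info.at(i).is_const;
            m_info[n].is_const = c;
        }
        // phase 2: the rewrite step only queries the cached result
        bool is_const(const Node* n) const { return m_info.at(n).is_const; }
    };

    int main() {
        Node param{true, {}}, data{false, {}};
        Node mid{false, {&param}};          // const: every input is const
        Node mixed{false, {&param, &data}}; // not const: one input is runtime data

        ConstPropagate cp;
        for (const Node* n : {&param, &data, &mid, &mixed})
            cp.add_node(n);

        std::printf("mid const: %d, mixed const: %d\n",
                    cp.is_const(&mid), cp.is_const(&mixed));
        return 0;
    }

Caching the analysis up front is what lets ParamFusePass drop ConstVarPropogateWithSizeCheck and its virtual on_midconst_opr() hook: the size check now runs inline in the rewrite lambda, where replace_single_var is already in scope.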
@@ -112,6 +112,52 @@ void warp_perspective_mat_gen(HostTensorND& mat, size_t N, size_t INP_H,
 #endif
 } // namespace
 
+TEST(TestGoptInference, ParamFuseConstEndPoint) {
+    constexpr size_t SIZE = 23;
+    HostTensorGenerator<> gen;
+    auto host_x = gen({SIZE}), host_y = gen({1}), host_p = gen({1});
+
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
+         y = opr::SharedDeviceTensor::make(*graph, *host_y),
+         p = opr::Host2DeviceCopy::make(*graph, host_p),
+         q = p + x,
+         a = y + 3,
+         z0 = a + q,
+         z1 = a + 4;
+
+    HostTensorND host_z0, host_z1;
+    SymbolVar z0_1, z1_1;
+    unpack_vector(
+            gopt::GraphOptimizer{}.
+            add_pass<gopt::ParamFusePass>().
+            apply({{z1, z0}}).endpoint_vars(),
+            z1_1, z0_1);
+
+    auto func = graph->compile({make_callback_copy(z0_1, host_z0),
+                                make_callback_copy(z1_1, host_z1)});
+    func->to_json()->writeto_fpath(
+            output_file("TestGoptInference.ParamFuseEndPoint.json"));
+    func->execute();
+
+    int nr_opr = 0;
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
+    ASSERT_EQ(8, nr_opr);
+
+    auto px = host_x->ptr<float>(), pz0 = host_z0.ptr<float>();
+    auto yv = host_y->ptr<float>()[0], pv = host_p->ptr<float>()[0],
+         pz1 = host_z1.ptr<float>()[0];
+    for (size_t i = 0; i < SIZE; ++ i) {
+        MGB_ASSERT_FLOAT_EQ(px[i] + yv + 3 + pv, pz0[i]);
+    }
+    MGB_ASSERT_FLOAT_EQ(yv + 7, pz1);
+}
+
 TEST(TestGoptInference, ParamFuse) {
     constexpr size_t SIZE = 23;
     HostTensorGenerator<> gen;
@@ -144,7 +190,7 @@ TEST(TestGoptInference, ParamFuse) {
     func->execute();
 
     int nr_opr = 0;
-    func->iter_opr_seq([&](cg::OperatorNodeBase*op) {++ nr_opr; return true; });
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
     ASSERT_EQ(6, nr_opr);
 
     auto px = host_x->ptr<float>(), pz = host_z.ptr<float>(),