|
|
@@ -625,8 +625,14 @@ void ConvertF32ToF16Pass::apply(OptState& state) const { |
|
|
|
auto rewriter = state.graph().make_rewriter(); |
|
|
|
VarNodeArray new_inp_cache; |
|
|
|
|
|
|
|
auto on_opr = [this, &rewriter, &new_inp_cache, |
|
|
|
&state](OperatorNodeBase* opr) { |
|
|
|
// record original output dtype |
|
|
|
const SymbolVarArray& vars = state.graph().endpoint_vars(); |
|
|
|
std::vector<DType> dtypes; |
|
|
|
for (size_t i = 0; i < vars.size(); i++) { |
|
|
|
dtypes.push_back(vars[i].node()->dtype()); |
|
|
|
} |
|
|
|
|
|
|
|
auto on_opr = [this, &rewriter, &new_inp_cache](OperatorNodeBase* opr) { |
|
|
|
auto it = m_opr_replace_func.find(opr->dyn_typeinfo()); |
|
|
|
if (it != m_opr_replace_func.end()) { |
|
|
|
auto&& new_inp = new_inp_cache; |
|
|
@@ -642,40 +648,32 @@ void ConvertF32ToF16Pass::apply(OptState& state) const { |
|
|
|
"bad opr replace: src=%s{%s} dst=%s{%s}", opr->cname(), |
|
|
|
opr->dyn_typeinfo()->name, new_opr->cname(), |
|
|
|
new_opr->dyn_typeinfo()->name); |
|
|
|
//! change the output type if it's the endpoint |
|
|
|
for (size_t i = 0; i < origin_out.size(); i++) { |
|
|
|
if (state.graph().endpoint_contain(origin_out[i]) && |
|
|
|
origin_out[i]->dtype().enumv() != |
|
|
|
cur_out[i]->dtype().enumv()) { |
|
|
|
rewriter.replace_var( |
|
|
|
origin_out[i], |
|
|
|
opr::TypeCvt::make(cur_out[i], |
|
|
|
origin_out[i]->dtype()) |
|
|
|
.node(), |
|
|
|
nullptr); |
|
|
|
} else { |
|
|
|
rewriter.replace_var(origin_out[i], cur_out[i], nullptr); |
|
|
|
} |
|
|
|
rewriter.replace_var(origin_out[i], cur_out[i], nullptr); |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto new_opr = rewriter.auto_replace_outputs(opr); |
|
|
|
auto&& out = opr->output(); |
|
|
|
auto&& new_out = new_opr->output(); |
|
|
|
for (size_t i = 0; i < out.size(); i++) { |
|
|
|
if (state.graph().endpoint_contain(out[i]) && |
|
|
|
new_out[i]->dtype().enumv() != out[i]->dtype().enumv()) { |
|
|
|
rewriter.replace_var( |
|
|
|
new_out[i], |
|
|
|
opr::TypeCvt::make(new_out[i], |
|
|
|
out[i]->dtype()) |
|
|
|
.node(), |
|
|
|
nullptr); |
|
|
|
} |
|
|
|
} |
|
|
|
rewriter.auto_replace_outputs(opr); |
|
|
|
} |
|
|
|
}; |
|
|
|
state.graph().iter(on_opr); |
|
|
|
rewriter.apply_inplace(); |
|
|
|
|
|
|
|
// recover output dtype |
|
|
|
rewriter = state.graph().make_rewriter(); |
|
|
|
const SymbolVarArray& endpoints = state.graph().endpoint_vars(); |
|
|
|
auto replace_output = [&]() { |
|
|
|
for (size_t i = 0; i < endpoints.size(); i++) { |
|
|
|
VarNode* var = endpoints[i].node(); |
|
|
|
if (var->dtype().enumv() != dtypes[i].enumv()) { |
|
|
|
auto new_var = opr::TypeCvt::make(var, dtypes[i]).node(); |
|
|
|
rewriter.replace_var(var, new_var, nullptr); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
mgb_assert(endpoints.size() > 0); |
|
|
|
auto opr = endpoints[0].node()->owner_opr(); |
|
|
|
state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE); |
|
|
|
rewriter.apply_inplace(); |
|
|
|
} |
|
|
|
|
|
|
|
std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make( |
|
|
|