Browse Source

fix(mgb/gopt): fix remove redundant typecvt pass

GitOrigin-RevId: 6a7957e362
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
ccc95ad2a9
2 changed files with 12 additions and 2 deletions
  1. +6
    -2
      src/gopt/impl/misc.cpp
  2. +6
    -0
      src/gopt/test/misc.cpp

+ 6
- 2
src/gopt/impl/misc.cpp View File

@@ -635,15 +635,19 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {

auto on_opr = [&](OperatorNodeBase* opr) {
if (auto tc0 = try_cast_as_op<opr::TypeCvt>(opr)) {
if (auto tc1 = try_cast_as_op<opr::TypeCvt>(tc0->input(0))) {
auto inp0 = rewriter.get_var(tc0->input(0));
if (auto tc1 = try_cast_as_op<opr::TypeCvt>(inp0)) {
if (should_remove(tc0->param(), tc1->param())) {
auto inp1 = tc1->input(0);
mgb_assert(!rewriter.has_manual_replace(inp1));
// TypeCvt returns the input var if its dtype is already
// dest_type
auto fold = opr::TypeCvt::make(tc1->input(0), tc0->param());
auto fold = opr::TypeCvt::make(inp1, tc0->param());
rewriter.replace_var(
tc0->output(0), fold.node(),
mgb_cstr_log("cvt_b(cvt_a(x)) -> cvt_b(x)"));
}
return;
}
}
rewriter.auto_replace_outputs(opr);


+ 6
- 0
src/gopt/test/misc.cpp View File

@@ -395,6 +395,12 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) {
check(x_fp16, x_fp16_fp32_fp16);
#endif

auto x_i32 = opr::TypeCvt::make(x, dtype::Int32());
auto x_i32_i16 = opr::TypeCvt::make(x_i32, dtype::Int16());
auto x_i32_i16_i8 = opr::TypeCvt::make(x_i32_i16, dtype::Int8());
auto x_i8 = opr::TypeCvt::make(x, dtype::Int8());
check(x_i8, x_i32_i16_i8);

auto x_q8 = opr::TypeCvt::make(x, dtype::QuantizedS8(0.1f));
auto x_q8_fp32 = opr::TypeCvt::make(x_q8, dtype::Float32());
auto x_q8_fp32_q8 = opr::TypeCvt::make(x_q8_fp32, dtype::QuantizedS8(0.1f));


Loading…
Cancel
Save