|
|
@@ -635,15 +635,19 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { |
|
|
|
|
|
|
|
auto on_opr = [&](OperatorNodeBase* opr) { |
|
|
|
if (auto tc0 = try_cast_as_op<opr::TypeCvt>(opr)) { |
|
|
|
if (auto tc1 = try_cast_as_op<opr::TypeCvt>(tc0->input(0))) { |
|
|
|
auto inp0 = rewriter.get_var(tc0->input(0)); |
|
|
|
if (auto tc1 = try_cast_as_op<opr::TypeCvt>(inp0)) { |
|
|
|
if (should_remove(tc0->param(), tc1->param())) { |
|
|
|
auto inp1 = tc1->input(0); |
|
|
|
mgb_assert(!rewriter.has_manual_replace(inp1)); |
|
|
|
// TypeCvt returns the input var if its dtype is already |
|
|
|
// dest_type |
|
|
|
auto fold = opr::TypeCvt::make(tc1->input(0), tc0->param()); |
|
|
|
auto fold = opr::TypeCvt::make(inp1, tc0->param()); |
|
|
|
rewriter.replace_var( |
|
|
|
tc0->output(0), fold.node(), |
|
|
|
mgb_cstr_log("cvt_b(cvt_a(x)) -> cvt_b(x)")); |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
rewriter.auto_replace_outputs(opr); |
|
|
|