|
|
@@ -1670,10 +1670,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
}; |
|
|
|
auto replace_elemwise_opr = [=](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
} |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
bool has_inp_changed = false; |
|
|
|
for (size_t i = 0; i < opr->input().size(); i++) { |
|
|
@@ -2827,18 +2823,18 @@ MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr, |
|
|
|
public: |
|
|
|
AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, |
|
|
|
TensorFormat out_format); |
|
|
|
|
|
|
|
|
|
|
|
static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, |
|
|
|
TensorFormat out_format); |
|
|
|
|
|
|
|
|
|
|
|
TensorFormat inp_format() const { |
|
|
|
return m_inp_format; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
TensorFormat out_format() const { |
|
|
|
return m_out_format; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private: |
|
|
|
void init_output_static_infer_desc() override; |
|
|
|
void scn_do_execute() override; |
|
|
@@ -3262,7 +3258,7 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { |
|
|
|
auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0); |
|
|
|
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); |
|
|
|
auto y1 = opr::Reshape::make(y0, tshp); |
|
|
|
auto y2 = opr::TypeCvt::make(y1, dtype::Float32()); |
|
|
|
auto y2 = opr::TypeCvt::make(y1, dtype::Float32()); |
|
|
|
return y2.node(); |
|
|
|
}; |
|
|
|
|
|
|
|