Browse Source

fix(mgb/gopt): fix opt pass elementwise operation shape issue at tranform to NCHW4

GitOrigin-RevId: c0c4e3f82e
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
51868533c8
2 changed files with 8 additions and 11 deletions
  1. +5
    -9
      src/gopt/impl/tensor_reformat.cpp
  2. +3
    -2
      src/gopt/test/inference.cpp

+ 5
- 9
src/gopt/impl/tensor_reformat.cpp View File

@@ -1670,10 +1670,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false;
for (size_t i = 0; i < opr->input().size(); i++) {
@@ -2827,18 +2823,18 @@ MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
public:
AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
TensorFormat inp_format() const {
return m_inp_format;
}
TensorFormat out_format() const {
return m_out_format;
}
private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
@@ -3262,7 +3258,7 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp);
auto y2 = opr::TypeCvt::make(y1, dtype::Float32());
auto y2 = opr::TypeCvt::make(y1, dtype::Float32());
return y2.node();
};



+ 3
- 2
src/gopt/test/inference.cpp View File

@@ -2910,7 +2910,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
conv1, w2, b2, param_conv_bias, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});

auto y = opr::TypeCvt::make(conv2, dtype::Float32());
auto conv2_fp32 = opr::TypeCvt::make(conv2, dtype::Float32());
auto y = conv2_fp32 + opr::TypeCvt::make(b2, dtype::Float32());

SymbolVar y_opt;
{
@@ -4076,7 +4077,7 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW32NCHW4) {

auto y = opr::ConvBias::make(x, w, b, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
param.stride_h = param.stride_w = 1;
param.stride_h = param.stride_w = 1;
y = opr::ConvBias::make(y, w1, b1, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::TypeCvt::make(y, dtype::Float32());


Loading…
Cancel
Save