Browse Source

feat(gopt/inference): allow Float32 output dtype in EnableNCHW64Pass

GitOrigin-RevId: 1891efb76f
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
6d686ff26f
1 changed files with 21 additions and 9 deletions
  1. +21
    -9
      src/gopt/impl/tensor_reformat.cpp

+ 21
- 9
src/gopt/impl/tensor_reformat.cpp View File

@@ -4330,14 +4330,20 @@ EnableNCHW64Pass::make_nchw64_converter() {
bool check_dtype =
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
if (opr->input().size() >= 3)
check_dtype &=
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &=
new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8;
mgb_assert(opr->output().size() > 0);
bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (opr->input().size() >= 3) {
auto dtype_expect = dst_float ? DTypeEnum::Float32
: DTypeEnum::QuantizedS32;
check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect;
}
if (opr->input().size() >= 4) {
check_dtype &= new_inp[3]->dtype().enumv() ==
opr->output(0)->dtype().enumv();
}
if (!check_dtype)
return nullptr;

size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
bool check_channels = out_channels % 4 == 0 && in_channels % 4 == 0;
@@ -4370,12 +4376,18 @@ EnableNCHW64Pass::make_nchw64_converter() {
}
}
};

for (size_t i = 0; i < inps.size(); ++i) {
inps[i] = process(i);
// do not format bias and z when dst_float is true
bool skip = dst_float && i >= 2;
if (!skip) inps[i] = process(i);
}
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4);
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
auto ret = make_new_conv(
inps, &conv_bias,
dst_float ? Format::NCHW4_NCHW : Format::NCHW4);
if (!dst_float)
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
return ret;
};



Loading…
Cancel
Save