|
|
@@ -4330,14 +4330,20 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
bool check_dtype = |
|
|
|
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
if (opr->input().size() >= 3) |
|
|
|
check_dtype &= |
|
|
|
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
if (opr->input().size() >= 4) |
|
|
|
check_dtype &= |
|
|
|
new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
mgb_assert(opr->output().size() > 0); |
|
|
|
bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32; |
|
|
|
if (opr->input().size() >= 3) { |
|
|
|
auto dtype_expect = dst_float ? DTypeEnum::Float32 |
|
|
|
: DTypeEnum::QuantizedS32; |
|
|
|
check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect; |
|
|
|
} |
|
|
|
if (opr->input().size() >= 4) { |
|
|
|
check_dtype &= new_inp[3]->dtype().enumv() == |
|
|
|
opr->output(0)->dtype().enumv(); |
|
|
|
} |
|
|
|
if (!check_dtype) |
|
|
|
return nullptr; |
|
|
|
|
|
|
|
size_t out_channels = opr->input(1)->shape()[0]; |
|
|
|
size_t in_channels = opr->input(1)->shape()[1]; |
|
|
|
bool check_channels = out_channels % 4 == 0 && in_channels % 4 == 0; |
|
|
@@ -4370,12 +4376,18 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
for (size_t i = 0; i < inps.size(); ++i) { |
|
|
|
inps[i] = process(i); |
|
|
|
// do not format bias and z when dst_float is true |
|
|
|
bool skip = dst_float && i >= 2; |
|
|
|
if (!skip) inps[i] = process(i); |
|
|
|
} |
|
|
|
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
|
|
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4); |
|
|
|
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); |
|
|
|
auto ret = make_new_conv( |
|
|
|
inps, &conv_bias, |
|
|
|
dst_float ? Format::NCHW4_NCHW : Format::NCHW4); |
|
|
|
if (!dst_float) |
|
|
|
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
|
|
|
|