|
|
@@ -206,7 +206,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { |
|
|
|
NCHW_TO_NCHW4_IC_SMALL_CONV) { |
|
|
|
if (layout_type() == |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) { |
|
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0); |
|
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0, |
|
|
|
"src shape %s", inp_shape.to_string().c_str()); |
|
|
|
} else { |
|
|
|
mgb_assert(layout_type() == |
|
|
|
RelayoutPlaceholder::LayoutType:: |
|
|
@@ -411,7 +412,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { |
|
|
|
dst[4] = 32; |
|
|
|
} else if (layout_type() == |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64) { |
|
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0); |
|
|
|
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s", inp_shape.to_string().c_str()); |
|
|
|
dst.ndim = 5; |
|
|
|
dst[0] = inp_shape[0]; |
|
|
|
dst[1] = inp_shape[1] / 64; |
|
|
@@ -4191,12 +4192,12 @@ void PaddingChannelPass::apply(OptState& opt) const { |
|
|
|
VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, |
|
|
|
VarNode* orig_var) const { |
|
|
|
if (!orig_var->shape().eq_shape(new_var->shape())) { |
|
|
|
auto iter = m_opr_format_map.find(orig_var->owner_opr()); |
|
|
|
auto iter = m_opr_format_map.find(new_var->owner_opr()); |
|
|
|
mgb_assert(iter != m_opr_format_map.end(), |
|
|
|
"cannot find opr(type:%s,name:%s) information, related " |
|
|
|
"output var node(name:%s)", |
|
|
|
orig_var->owner_opr()->dyn_typeinfo()->name, |
|
|
|
orig_var->owner_opr()->cname(), orig_var->cname()); |
|
|
|
new_var->owner_opr()->dyn_typeinfo()->name, |
|
|
|
new_var->owner_opr()->cname(), new_var->cname()); |
|
|
|
const auto& fmt = iter->second; |
|
|
|
using LayoutType = RelayoutPlaceholder::LayoutType; |
|
|
|
LayoutType type; |
|
|
@@ -4253,20 +4254,70 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
return new_conv.node(); |
|
|
|
} |
|
|
|
}; |
|
|
|
auto try_transform_to_nchw = |
|
|
|
[&format_map]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) -> VarNode* { |
|
|
|
mgb_assert(opr->input().size()==new_inp.size()); |
|
|
|
bool check_dtype = |
|
|
|
new_inp[0]->dtype().enumv() == DTypeEnum::Float32 && |
|
|
|
new_inp[1]->dtype().enumv() == DTypeEnum::Float32; |
|
|
|
if (opr->input().size() >= 3) |
|
|
|
check_dtype &= |
|
|
|
new_inp[2]->dtype().enumv() == DTypeEnum::Float32; |
|
|
|
if (opr->input().size() >= 4) |
|
|
|
check_dtype &= |
|
|
|
new_inp[3]->dtype().enumv() == DTypeEnum::Float32; |
|
|
|
if (!check_dtype) |
|
|
|
return nullptr; |
|
|
|
auto inps = new_inp; |
|
|
|
auto process = [&](size_t i) -> VarNode* { |
|
|
|
auto iter = format_map.find(new_inp[i]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
return inps[i]; |
|
|
|
} else { |
|
|
|
const auto& fmt = iter->second; |
|
|
|
if (fmt == Format::NCHW32) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW); |
|
|
|
return ovar.node(); |
|
|
|
} else if (fmt == Format::NCHW4) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW); |
|
|
|
return ovar.node(); |
|
|
|
} else { |
|
|
|
mgb_assert(fmt == Format::NCHW64); |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
|
RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW); |
|
|
|
return ovar.node(); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
for (size_t i = 0; i < inps.size(); ++i) { |
|
|
|
inps[i] = process(i); |
|
|
|
} |
|
|
|
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
return ret->output()[0]; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
auto try_transform_to_nchw4 = |
|
|
|
[make_new_conv, &format_map]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) -> VarNode* { |
|
|
|
mgb_assert(opr->input().size()==new_inp.size()); |
|
|
|
bool check_dtype = |
|
|
|
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
if (opr->input().size() >= 3) |
|
|
|
check_dtype &= |
|
|
|
opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
if (opr->input().size() >= 4) |
|
|
|
check_dtype &= |
|
|
|
opr->input(3)->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
if (!check_dtype) |
|
|
|
return nullptr; |
|
|
|
size_t out_channels = opr->input(1)->shape()[0]; |
|
|
@@ -4277,7 +4328,7 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
opr->cname(), out_channels, in_channels); |
|
|
|
auto inps = new_inp; |
|
|
|
auto process = [&](size_t i) -> VarNode* { |
|
|
|
auto iter = format_map.find(opr->input(i)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[i]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
@@ -4304,24 +4355,26 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
for (size_t i = 0; i < inps.size(); ++i) { |
|
|
|
inps[i] = process(i); |
|
|
|
} |
|
|
|
format_map.insert(std::make_pair(opr, Format::NCHW4)); |
|
|
|
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
|
|
return make_new_conv(inps, &conv_bias, Format::NCHW4); |
|
|
|
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4); |
|
|
|
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
|
|
|
|
auto try_transform_to_nchw32 = |
|
|
|
[make_new_conv, &format_map]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) -> VarNode* { |
|
|
|
mgb_assert(opr->input().size()==new_inp.size()); |
|
|
|
bool check_dtype = |
|
|
|
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
if (opr->input().size() >= 3) |
|
|
|
check_dtype &= |
|
|
|
opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
if (opr->input().size() >= 4) |
|
|
|
check_dtype &= |
|
|
|
opr->input(3)->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8; |
|
|
|
if (!check_dtype) |
|
|
|
return nullptr; |
|
|
|
size_t out_channels = opr->input(1)->shape()[0]; |
|
|
@@ -4331,7 +4384,7 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
return nullptr; |
|
|
|
auto inps = new_inp; |
|
|
|
auto process = [&](size_t i) -> VarNode* { |
|
|
|
auto iter = format_map.find(opr->input(i)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[i]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
@@ -4358,9 +4411,10 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
for (size_t i = 0; i < inps.size(); ++i) { |
|
|
|
inps[i] = process(i); |
|
|
|
} |
|
|
|
format_map.insert(std::make_pair(opr, Format::NCHW32)); |
|
|
|
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
|
|
return make_new_conv(inps, &conv_bias, Format::NCHW32); |
|
|
|
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW32); |
|
|
|
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW32)); |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
|
|
|
|
auto try_transform_to_nchw64 = |
|
|
@@ -4368,17 +4422,18 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) -> VarNode* { |
|
|
|
// fint4XWint4 and fuint4XWint4 |
|
|
|
mgb_assert(opr->input().size()==new_inp.size()); |
|
|
|
bool check_dtype = |
|
|
|
(opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
opr->input(0)->dtype().enumv() == |
|
|
|
(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
new_inp[0]->dtype().enumv() == |
|
|
|
DTypeEnum::Quantized4Asymm) && |
|
|
|
opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS4; |
|
|
|
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4; |
|
|
|
if (opr->input().size() >= 3) |
|
|
|
check_dtype &= |
|
|
|
opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; |
|
|
|
if (opr->input().size() >= 4) |
|
|
|
check_dtype &= opr->input(3)->dtype().enumv() == |
|
|
|
opr->input(0)->dtype().enumv(); |
|
|
|
check_dtype &= new_inp[3]->dtype().enumv() == |
|
|
|
new_inp[0]->dtype().enumv(); |
|
|
|
if (!check_dtype) |
|
|
|
return nullptr; |
|
|
|
size_t out_channels = opr->input(1)->shape()[0]; |
|
|
@@ -4388,7 +4443,7 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
return nullptr; |
|
|
|
auto inps = new_inp; |
|
|
|
auto process = [&](size_t i) -> VarNode* { |
|
|
|
auto iter = format_map.find(opr->input(i)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[i]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
auto ovar = RelayoutPlaceholder::make( |
|
|
|
inps[i], |
|
|
@@ -4415,15 +4470,16 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
for (size_t i = 0; i < inps.size(); ++i) { |
|
|
|
inps[i] = process(i); |
|
|
|
} |
|
|
|
format_map.insert(std::make_pair(opr, Format::NCHW64)); |
|
|
|
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
|
|
return make_new_conv(inps, &conv_bias, Format::NCHW64); |
|
|
|
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW64); |
|
|
|
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW64)); |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
|
|
|
|
// replace rule for conv bias opr |
|
|
|
auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4, |
|
|
|
try_transform_to_nchw32, |
|
|
|
try_transform_to_nchw64]( |
|
|
|
try_transform_to_nchw64, try_transform_to_nchw]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
using Param = megdnn::param::ConvBias; |
|
|
@@ -4435,16 +4491,18 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
VarNode* new_var = nullptr; |
|
|
|
if ((new_var = try_transform_to_nchw32(opr, new_inp)) || |
|
|
|
(new_var = try_transform_to_nchw4(opr, new_inp)) || |
|
|
|
(new_var = try_transform_to_nchw64(opr, new_inp))) { |
|
|
|
(new_var = try_transform_to_nchw64(opr, new_inp))|| |
|
|
|
(new_var = try_transform_to_nchw(opr, new_inp))) { |
|
|
|
return new_var->owner_opr(); |
|
|
|
} else { |
|
|
|
mgb_assert( |
|
|
|
opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && |
|
|
|
opr->input(0)->dtype().enumv() != |
|
|
|
new_inp[0]->dtype().enumv() != DTypeEnum::QuantizedS8 && |
|
|
|
new_inp[0]->dtype().enumv() != |
|
|
|
DTypeEnum::QuantizedS4 && |
|
|
|
opr->input(0)->dtype().enumv() != |
|
|
|
DTypeEnum::Quantized4Asymm, |
|
|
|
"invalid data type(%s)", opr->input(0)->dtype().name()); |
|
|
|
new_inp[0]->dtype().enumv() != |
|
|
|
DTypeEnum::Quantized4Asymm && |
|
|
|
new_inp[0]->dtype().enumv() != DTypeEnum::Float32, |
|
|
|
"invalid data type(%s)", new_inp[0]->dtype().name()); |
|
|
|
bool shape_changed = false; |
|
|
|
for (const auto& i : new_inp) { |
|
|
|
if (format_map.count(i->owner_opr()) > 0) { |
|
|
@@ -4471,9 +4529,9 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
"only have 2 input vars(got:%zu)", |
|
|
|
new_inp.size()); |
|
|
|
auto& deconv = opr->cast_final_safe<opr::ConvolutionBackwardData>(); |
|
|
|
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
Format cur; |
|
|
|
auto iter = format_map.find(opr->input(1)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[1]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
cur = Format::NCHW; |
|
|
|
} else { |
|
|
@@ -4506,13 +4564,15 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
default: |
|
|
|
mgb_assert(cur == Format::NCHW4); |
|
|
|
} |
|
|
|
format_map.insert(std::make_pair(opr, Format::NCHW4)); |
|
|
|
|
|
|
|
auto param = deconv.param(); |
|
|
|
param.format = Format::NCHW4; |
|
|
|
auto new_deconv = opr::ConvolutionBackwardData::make( |
|
|
|
inps[0], inps[1], param, deconv.execution_policy(), |
|
|
|
deconv.config()); |
|
|
|
return new_deconv.node()->owner_opr(); |
|
|
|
auto ret = new_deconv.node()->owner_opr(); |
|
|
|
format_map.insert(std::make_pair(ret, Format::NCHW4)); |
|
|
|
return ret; |
|
|
|
} else { |
|
|
|
bool shape_changed = false; |
|
|
|
for (const auto& i : new_inp) { |
|
|
@@ -4538,7 +4598,7 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
bool same_format = true; |
|
|
|
bool first_touch = false; |
|
|
|
Format format(Format::NCHW); |
|
|
|
for (const auto& i : opr->input()) { |
|
|
|
for (const auto& i : new_inp) { |
|
|
|
Format cur; |
|
|
|
auto iter = format_map.find(i->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
@@ -4557,10 +4617,11 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
} |
|
|
|
} |
|
|
|
if (same_format) { |
|
|
|
auto ret = serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
if (format != Format::NCHW) |
|
|
|
format_map.insert(std::make_pair(opr, format)); |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
format_map.insert(std::make_pair(ret, format)); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
Format max_format(Format::NCHW); |
|
|
@@ -4592,7 +4653,7 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
}; |
|
|
|
auto inps = new_inp; |
|
|
|
for (size_t i = 0; i < opr->input().size(); ++i) { |
|
|
|
auto iter = format_map.find(opr->input(i)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[i]->owner_opr()); |
|
|
|
Format cur; |
|
|
|
if (iter != format_map.end()) { |
|
|
|
cur = iter->second; |
|
|
@@ -4603,9 +4664,10 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
if (max_format != Format::NCHW) |
|
|
|
format_map.insert(std::make_pair(opr, max_format)); |
|
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
format_map.insert(std::make_pair(ret, max_format)); |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
// elemwise like |
|
|
|
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; |
|
|
@@ -4619,10 +4681,10 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>(); |
|
|
|
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) { |
|
|
|
Format cur; |
|
|
|
auto iter = format_map.find(opr->input(0)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[0]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
cur = Format::NCHW; |
|
|
|
} else { |
|
|
@@ -4651,7 +4713,6 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
default: |
|
|
|
mgb_assert(cur == Format::NCHW64); |
|
|
|
} |
|
|
|
format_map.insert(std::make_pair(opr, Format::NCHW64)); |
|
|
|
auto param = warp.param(); |
|
|
|
param.format = Format::NCHW64; |
|
|
|
SymbolVar new_warp; |
|
|
@@ -4665,10 +4726,12 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
inps[0], inps[1], inps[2], inps[3], param, |
|
|
|
warp.config()); |
|
|
|
} |
|
|
|
return new_warp.node()->owner_opr(); |
|
|
|
} else if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
auto ret = new_warp.node()->owner_opr(); |
|
|
|
format_map.insert(std::make_pair(ret, Format::NCHW64)); |
|
|
|
return ret; |
|
|
|
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
Format cur; |
|
|
|
auto iter = format_map.find(opr->input(0)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[0]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
cur = Format::NCHW; |
|
|
|
} else { |
|
|
@@ -4697,7 +4760,7 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
default: |
|
|
|
mgb_assert(cur == Format::NCHW4); |
|
|
|
} |
|
|
|
format_map.insert(std::make_pair(opr, Format::NCHW4)); |
|
|
|
|
|
|
|
auto param = warp.param(); |
|
|
|
param.format = Format::NCHW4; |
|
|
|
SymbolVar new_warp; |
|
|
@@ -4711,7 +4774,9 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
inps[0], inps[1], inps[2], inps[3], param, |
|
|
|
warp.config()); |
|
|
|
} |
|
|
|
return new_warp.node()->owner_opr(); |
|
|
|
auto ret = new_warp.node()->owner_opr(); |
|
|
|
format_map.insert(std::make_pair(ret, Format::NCHW4)); |
|
|
|
return ret; |
|
|
|
} else { |
|
|
|
bool shape_changed = false; |
|
|
|
for (const auto& i : new_inp) { |
|
|
@@ -4733,10 +4798,10 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); |
|
|
|
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) { |
|
|
|
Format cur; |
|
|
|
auto iter = format_map.find(opr->input(0)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[0]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
cur = Format::NCHW; |
|
|
|
} else { |
|
|
@@ -4765,21 +4830,23 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
default: |
|
|
|
mgb_assert(cur == Format::NCHW64); |
|
|
|
} |
|
|
|
format_map.insert(std::make_pair(opr, Format::NCHW64)); |
|
|
|
|
|
|
|
auto param = pooling.param(); |
|
|
|
param.format = Format::NCHW64; |
|
|
|
auto new_pool = |
|
|
|
opr::PoolingForward::make(inps[0], param, pooling.config()); |
|
|
|
return new_pool.node()->owner_opr(); |
|
|
|
} else if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
auto ret = new_pool.node()->owner_opr(); |
|
|
|
format_map.insert(std::make_pair(ret, Format::NCHW64)); |
|
|
|
return ret; |
|
|
|
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
Format cur; |
|
|
|
auto iter = format_map.find(opr->input(0)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[0]->owner_opr()); |
|
|
|
if (iter == format_map.end()) { |
|
|
|
cur = Format::NCHW; |
|
|
|
} else { |
|
|
|
cur = iter->second; |
|
|
|
} |
|
|
|
size_t in_channels = opr->input(0)->shape()[1]; |
|
|
|
size_t in_channels = new_inp[0]->shape()[1]; |
|
|
|
bool use_nchw32 = false; |
|
|
|
auto inps = new_inp; |
|
|
|
using LayoutType = RelayoutPlaceholder::LayoutType; |
|
|
@@ -4805,12 +4872,14 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
mgb_assert(cur == Format::NCHW4); |
|
|
|
} |
|
|
|
Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4; |
|
|
|
format_map.insert(std::make_pair(opr, out_format)); |
|
|
|
|
|
|
|
auto param = pooling.param(); |
|
|
|
param.format = out_format; |
|
|
|
auto new_pool = |
|
|
|
opr::PoolingForward::make(inps[0], param, pooling.config()); |
|
|
|
return new_pool.node()->owner_opr(); |
|
|
|
auto ret = new_pool.node()->owner_opr(); |
|
|
|
format_map.insert(std::make_pair(ret, out_format)); |
|
|
|
return ret; |
|
|
|
} else { |
|
|
|
bool shape_changed = false; |
|
|
|
for (const auto& i : new_inp) { |
|
|
@@ -4838,9 +4907,9 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto inps = new_inp; |
|
|
|
for (size_t i = 0; i < opr->input().size(); ++i) { |
|
|
|
auto iter = format_map.find(opr->input(i)->owner_opr()); |
|
|
|
auto iter = format_map.find(new_inp[i]->owner_opr()); |
|
|
|
auto fmt = iter != format_map.end()?iter->second:Format::NCHW; |
|
|
|
if (iter != format_map.end()) { |
|
|
|
auto fmt = iter->second; |
|
|
|
switch (fmt) { |
|
|
|
case Format::NCHW4: |
|
|
|
inps[i] = RelayoutPlaceholder::make( |
|
|
@@ -4867,7 +4936,8 @@ EnableNCHW64Pass::make_nchw64_converter() { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
|
|
|
|
replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw; |
|
|
|