|
|
@@ -1415,6 +1415,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var, |
|
|
|
return new_var; |
|
|
|
} |
|
|
|
|
|
|
|
//! FIXME: No float oprs support NCHW4 yet; support should be added in the future. |
|
|
|
std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
MIDOUT_B("EnableNCHW4Pass::make") |
|
|
|
auto ret = std::make_unique<EnableNCHW4Pass>(); |
|
|
@@ -1467,6 +1468,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
auto replace_conv_opr = [trans_nchw4, conv_format]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
} |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); |
|
|
|
if (conv_opr.param().format != |
|
|
@@ -1503,6 +1508,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
src_to_nchw4_mode]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
} |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& batch_conv_bias_opr = |
|
|
|
opr->cast_final_safe<opr::BatchConvBiasForward>(); |
|
|
@@ -1580,6 +1589,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
src_to_nchw4_mode]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
} |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); |
|
|
|
if (conv_bias_opr.param().format != |
|
|
@@ -1647,6 +1660,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
}; |
|
|
|
auto replace_elemwise_opr = [=](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
} |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
bool has_inp_changed = false; |
|
|
|
for (size_t i = 0; i < opr->input().size(); i++) { |
|
|
@@ -1691,6 +1708,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
}; |
|
|
|
auto replace_pooling_opr = [](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
} |
|
|
|
using Param = opr::PoolingForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
@@ -1716,6 +1737,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
}; |
|
|
|
auto replace_resize_opr = [](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
} |
|
|
|
using Param = opr::ResizeForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
@@ -1738,6 +1763,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
}; |
|
|
|
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, |
|
|
|
opr->config()); |
|
|
|
} |
|
|
|
using Param = opr::WarpPerspective::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
@@ -3127,4 +3156,4 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const { |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
|
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |