@@ -28,6 +28,7 @@
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/tensor_format.h"

#if MGB_ENABLE_TENSOR_RT
@@ -59,19 +60,19 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
public:
    //! relayout type of this opr
    enum class LayoutType {
        NCHW4_TO_NCHW32,  //!< from nchw4 layout to nchw32 layout
        NCHW32_TO_NCHW4,  //!< from nchw32 layout to nchw4 layout
        NCHW4_TO_CHWN4,   //!< from nchw4 layout to chwn4 layout
        CHWN4_TO_NCHW4,   //!< from chwn4 layout to nchw4 layout
        NCHW_TO_NCHW4,    //!< from nchw layout to nchw4 layout
        NCHW4_TO_NCHW,    //!< from nchw4 layout to nchw layout
        NCHW_TO_NCHW88,   //!< from nchw layout to nchw88 layout
        NCHW88_TO_NCHW,   //!< from nchw88 layout to nchw layout

        WEIGHT_NCHW_TO_NCHW4_DENSE,  //!< weight from nchw layout to nchw4
                                     //!< layout
        WEIGHT_NCHW_TO_NCHW4_GROUP,  //!< group weight from nchw layout to
                                     //!< nchw4 layout

        WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                      //!< layout
@@ -92,6 +93,10 @@ public:
        //!< the weight layout of the input is nchw and the output is nchw44;
        //!< special for weights shaped in nchw like {64, 2, 3, 3}, which
        //!< become {16, 3, 3, 2, 4}
        WEIGHT_HYBIRD_NCHW_NCHW44,
        WEIGHT_NCHW_TO_NCHW44_DOT_DENSE,  //!< dense weight from nchw layout
                                          //!< to nchw44_dot layout
        WEIGHT_NCHW_TO_NCHW44_DOT_GROUP,  //!< group weight from nchw layout
                                          //!< to nchw44_dot layout
    };

    RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
@@ -268,7 +273,9 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
        dst[3] = inp_shape[1];
        dst[4] = 8;
    } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                       WEIGHT_NCHW_TO_NCHW44_DENSE ||
               layout_type() == RelayoutPlaceholder::LayoutType::
                       WEIGHT_NCHW_TO_NCHW44_DOT_DENSE) {
        mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
                   inp_shape[1] % 4 == 0);
        dst.ndim = 6;
@@ -279,7 +286,9 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
        dst[4] = 4;
        dst[5] = 4;
    } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                       WEIGHT_NCHW_TO_NCHW44_GROUP ||
               layout_type() == RelayoutPlaceholder::LayoutType::
                       WEIGHT_NCHW_TO_NCHW44_DOT_GROUP) {
        mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
                   inp_shape[2] % 4 == 0);
        dst.ndim = 7;
@@ -646,6 +655,42 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)},
                     0),
             tshp1 = opr::Concat::make(
                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)},
                     0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 1, 3});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
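The lambda above encodes the dense nchw -> nchw44_dot weight transform. A shape-flow sketch (our annotation, assuming a dense NCHW filter {OC, IC, FH, FW} with OC % 4 == 0 and IC % 4 == 0; the oc4/ic4 nicknames are ours, not names from the source):

    // reshape:    {OC, IC, FH, FW} -> {OC/4, 4(oc4), IC/4, 4(ic4), FH, FW}
    // dimshuffle: pattern {0, 2, 4, 5, 1, 3}
    //             -> {OC/4, IC/4, FH, FW, 4(oc4), 4(ic4)}
    // reshape:    -> contiguous nchw44_dot weight {OC/4, IC/4, FH, FW, 4, 4}

Note that the dot layout keeps the four output channels outside the four input channels in the innermost dimensions.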
    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP] =
            [](VarNode* inp) -> VarNode* {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 4, cv(4), sub(2) / 4, cv(4), sub(3),
                      sub(4)},
                     0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 4, sub(2) / 4, sub(3), sub(4), cv(4),
                      cv(4)},
                     0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
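The group variant is the same trick with a leading group dimension (again our annotation, assuming a group filter {G, OCPG, ICPG, FH, FW} with OCPG % 4 == 0 and ICPG % 4 == 0):

    // reshape:    {G, OCPG, ICPG, FH, FW}
    //             -> {G, OCPG/4, 4(oc4), ICPG/4, 4(ic4), FH, FW}
    // dimshuffle: pattern {0, 1, 3, 5, 6, 2, 4}
    //             -> {G, OCPG/4, ICPG/4, FH, FW, 4(oc4), 4(ic4)}
    // reshape:    -> contiguous weight {G, OCPG/4, ICPG/4, FH, FW, 4, 4}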

    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
@@ -1601,12 +1646,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
    return new_var;
}

void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
    using RelayoutMode = RelayoutPlaceholder::LayoutType;
    using TestFilterResult = std::pair<TransType, RelayoutMode>;
    RelayoutMode weight_to_nchwxx_mode_dense =
@@ -1954,8 +1994,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
        return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
    };

    auto&& replace_func = m_opr_replace_func;
    //! supported nchwxx
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
@@ -1978,6 +2017,246 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
            relayout_inp_to_nchw;
    replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
}

std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
        size_t pack_c_size) {
    auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    std::string convter_pass_name = "conv_format_nchw88";
    if (pack_c_size == 4) {
        convter_pass_name = "conv_format_nchw44";
    }
    ret->fill_opr_convert_fun(pack_c_size);
    ret->set_name(convter_pass_name);
    return ret;
}
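With the table-building moved into fill_opr_convert_fun, the factory reduces to naming and wiring. A minimal usage sketch (the caller-side lines are ours; only the factory itself comes from this patch):

    // auto nchw88 = EnableNchwxxPass::make_nchwxx_converter(8);  // named "conv_format_nchw88"
    // auto nchw44 = EnableNchwxxPass::make_nchwxx_converter(4);  // named "conv_format_nchw44"

The point of the split is reuse: a derived pass can call fill_opr_convert_fun to get the full replace table and then override individual entries, which is exactly what EnableNchw44DotPass does below.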

/* ================ EnableNchw44DotPass =============== */
VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,
                                                    VarNode* orig_var) const {
    if (!orig_var->shape().eq_shape(new_var->shape())) {
        return RelayoutPlaceholder::make(
                       new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
                .node();
    }
    return new_var;
}

std::unique_ptr<EnableNchw44DotPass>
EnableNchw44DotPass::make_nchw44_dot_converter() {
    auto ret = std::make_unique<EnableNchw44DotPass>();
    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
    //! first: whether the conv can be transformed to nchwxx; second: the
    //! filter transform mode
    using RelayoutMode = RelayoutPlaceholder::LayoutType;
    using TestTransResult = std::pair<TransType, RelayoutMode>;
    megdnn::param::ConvolutionV0::Format conv_dot_format =
            megdnn::param::ConvBias::Format::NCHW44_DOT;
    constexpr size_t pack_c_size = 4_z;
    auto test_trans_nchw44_dot =
            [](const megdnn::param::Convolution::Sparse conv_mode,
               const VarNode* filter) -> TestTransResult {
        TestTransResult ret{TransType::TRANS_NONE, {}};
        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
            size_t IC = filter->shape()[1];
            size_t OC = filter->shape()[0];
            if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
                ret.first = TransType::TRANS_PURE_NCHWXX;
                ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
            } else if (IC < pack_c_size && OC % pack_c_size == 0) {
                ret.first = TransType::TRANS_HYBIRD_NCHWXX;
                ret.second = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
            }
        } else {
            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
            size_t group = filter->shape()[0];
            size_t ocpg = filter->shape()[1];
            size_t icpg = filter->shape()[2];
            if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
                ret.first = TransType::TRANS_NONE;
            } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
                ret.first = TransType::TRANS_PURE_NCHWXX;
                ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
            }
        }
        return ret;
    };
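To make the predicate concrete (the example shapes are ours): a dense {64, 32, 3, 3} filter satisfies OC % 4 == 0 and IC % 4 == 0, so it is classified TRANS_PURE_NCHWXX with WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; a dense {8, 3, 3, 3} filter has IC = 3 < 4 and takes TRANS_HYBIRD_NCHWXX with WEIGHT_HYBIRD_NCHW_NCHW44; a channel-wise group filter (icpg == ocpg == 1, group % 4 == 0), and any shape matching neither test, stays TRANS_NONE.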

    auto replace_conv_opr = [test_trans_nchw44_dot, conv_dot_format](
                                    OperatorNodeBase* opr,
                                    const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
        mgb_assert(conv_opr.param().format ==
                           megdnn::param::Convolution::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to "
                   "NCHW44_DOT");
        auto is_trans =
                test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]);
        //! cannot be transformed to nchwxx
        if (is_trans.first == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not in NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, insert a RelayoutPlaceholder back to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src = RelayoutPlaceholder::make(
                        new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
                temp_inp[0] = new_src.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
            //! transform the filter to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not in NCHW mode");
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_filter = new_filter.node();
            //! transform the src to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(
                        new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
                conv_src = new_src.node();
            }
            auto new_param = conv_opr.param();
            new_param.format = conv_dot_format;
            mgb_assert(conv_src->shape().ndim == 5 &&
                               conv_filter->shape().ndim >= 6,
                       "The conv src is not transformed to nchwxx");
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst is not transformed to nchwxx");
            return new_opr;
        } else {
            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_filter = new_filter.node();
            mgb_assert(conv_src->shape().ndim == 4 &&
                               conv_filter->shape().ndim == 5,
                       "The src or filter shape does not match the hybrid "
                       "mode");
            auto new_param = conv_opr.param();
            new_param.format = conv_dot_format;
            auto new_conv_opr = opr::Convolution::make(
                    conv_src, conv_filter, new_param,
                    conv_opr.execution_policy(), conv_opr.config());
            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
            mgb_assert(new_conv_opr.shape().ndim == 5,
                       "The conv dst is not transformed to nchwxx");
            return new_opr;
        }
    };
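Reading the three branches together: TRANS_NONE keeps the conv in NCHW, inserting an NCHW4_TO_NCHW relayout only if an upstream replacement already produced a 5-dim src; TRANS_PURE_NCHWXX relayouts both filter and src and rebuilds the conv with format NCHW44_DOT; TRANS_HYBIRD_NCHWXX relayouts only the filter and feeds the original 4-dim NCHW src, which is presumably what lets input-layer convs with IC < 4 still reach the dot format.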

    auto replace_conv_bias_opr = [test_trans_nchw44_dot, conv_dot_format](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
        mgb_assert(conv_bias_opr.param().format ==
                           megdnn::param::ConvBias::Format::NCHW,
                   "ConvertFormat Pass only supports converting NCHW to "
                   "NCHWXX");
        auto is_trans =
                test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]);
        //! cannot be transformed to nchwxx
        if (is_trans.first == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not in NCHW mode");
            VarNodeArray temp_inp = new_inp;
            //! if src is nchwxx, insert a RelayoutPlaceholder back to nchw
            if (temp_inp[0]->shape().ndim == 5) {
                auto new_src = RelayoutPlaceholder::make(
                        new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
                temp_inp[0] = new_src.node();
            }
            //! the bias is nchwxx
            if (temp_inp[2]->shape().ndim == 5) {
                auto new_bias = RelayoutPlaceholder::make(
                        new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
                temp_inp[2] = new_bias.node();
            }
            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
                                                           opr->config());
            return new_opr;
        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            //! transform the filter to nchwxx mode
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
                               new_inp[1]->shape().ndim == 5,
                       "The origin filter is not in NCHW mode");
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_bias_filter = new_filter.node();
            //! transform the src to nchwxx mode
            if (new_inp[0]->shape().ndim != 5) {
                mgb_assert(new_inp[0]->shape().ndim == 4);
                auto new_src = RelayoutPlaceholder::make(
                        new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
                conv_bias_src = new_src.node();
            }
            //! transform the bias to nchwxx mode; the bias may be a scalar
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(
                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
                conv_bias_bias = new_bias.node();
            }

            auto new_param = conv_bias_opr.param();
            new_param.format = conv_dot_format;
            mgb_assert(conv_bias_src->shape().ndim == 5 &&
                               conv_bias_filter->shape().ndim >= 6,
                       "The conv_bias src is not transformed to nchwxx");
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv_bias dst is not transformed to nchwxx");
            return new_opr;
        } else {
            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
                    *conv_bias_bias = new_inp[2];
            auto new_filter =
                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
            conv_bias_filter = new_filter.node();
            //! transform the bias to nchwxx mode; the bias may be a scalar
            if (new_inp[2]->shape().ndim == 4) {
                auto new_bias = RelayoutPlaceholder::make(
                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
                conv_bias_bias = new_bias.node();
            }
            mgb_assert(conv_bias_src->shape().ndim == 4 &&
                       conv_bias_filter->shape().ndim == 5);
            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
                       conv_bias_bias->shape().is_scalar());
            auto new_param = conv_bias_opr.param();
            new_param.format = conv_dot_format;
            auto new_conv_bias_opr = opr::ConvBias::make(
                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                       "The conv dst is not transformed to nchwxx");
            return new_opr;
        }
    };

    ret->fill_opr_convert_fun(4);
    auto&& replace_func = ret->m_opr_replace_func;
    //! supported nchwxx
    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
    return ret;
}
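The closing lines show the reuse path enabled by the refactor: fill_opr_convert_fun(4) first installs the full generic nchw44 replace table, and only the Convolution and ConvBias entries are then overwritten with the NCHW44_DOT-aware lambdas above, so every other operator keeps its plain nchw44 handling.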