GitOrigin-RevId: 89083be200
tags/v1.0.0-rc1
@@ -1811,7 +1811,15 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, | |||||
} | } | ||||
return new_var; | return new_var; | ||||
} | } | ||||
//! nchw_nchwxx_valid is used to indicate optimized nchw_nchw44 conv
//!
//! Returns true only when every one of the following holds:
//!   - ic is smaller than pack_c_size and oc is a multiple of pack_c_size
//!   - the filter is square (fh == fw) with side 2, 3, 5 or 7
//!   - the stride is the same in both directions and is 1 or 2
//!
//! \param oc output channel count of the filter
//! \param ic input channel count of the filter
//! \param pack_c_size channel pack size of the target layout
//! \param fh filter height
//! \param fw filter width
//! \param stride_h vertical stride of the convolution
//! \param stride_w horizontal stride of the convolution
static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic,
                                     const size_t pack_c_size, const size_t fh,
                                     const size_t fw, const size_t stride_h,
                                     const size_t stride_w) {
    //! keep `ic < pack_c_size` first: for pack_c_size == 0 it is false and
    //! short-circuits before the modulo is evaluated
    const bool channel_ok = ic < pack_c_size && oc % pack_c_size == 0;
    const bool filter_ok =
            fh == fw && (fh == 2 || fh == 3 || fh == 5 || fh == 7);
    const bool stride_ok =
            stride_h == stride_w && (stride_h == 1 || stride_h == 2);
    return channel_ok && filter_ok && stride_ok;
}
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | ||||
using RelayoutMode = RelayoutPlaceholder::LayoutType; | using RelayoutMode = RelayoutPlaceholder::LayoutType; | ||||
using TestFilterResult = std::pair<TransType, RelayoutMode>; | using TestFilterResult = std::pair<TransType, RelayoutMode>; | ||||
@@ -1848,15 +1856,19 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan, | weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan, | ||||
hybrid_nchw_nchwxx]( | hybrid_nchw_nchwxx]( | ||||
const megdnn::param::Convolution::Sparse conv_mode, | const megdnn::param::Convolution::Sparse conv_mode, | ||||
const VarNode* filter) -> TestFilterResult { | |||||
const VarNode* filter, const size_t stride_h, | |||||
const size_t stride_w) -> TestFilterResult { | |||||
TestFilterResult ret{TransType::TRANS_NONE, {}}; | TestFilterResult ret{TransType::TRANS_NONE, {}}; | ||||
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | ||||
size_t IC = filter->shape()[1]; | |||||
size_t OC = filter->shape()[0]; | size_t OC = filter->shape()[0]; | ||||
size_t IC = filter->shape()[1]; | |||||
size_t FH = filter->shape()[2]; | |||||
size_t FW = filter->shape()[3]; | |||||
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { | if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { | ||||
ret.first = TransType::TRANS_PURE_NCHWXX; | ret.first = TransType::TRANS_PURE_NCHWXX; | ||||
ret.second = weight_to_nchwxx_mode_dense; | ret.second = weight_to_nchwxx_mode_dense; | ||||
} else if (IC < pack_c_size && OC % pack_c_size == 0) { | |||||
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, | |||||
stride_w)) { | |||||
ret.first = TransType::TRANS_HYBIRD_NCHWXX; | ret.first = TransType::TRANS_HYBIRD_NCHWXX; | ||||
ret.second = hybrid_nchw_nchwxx; | ret.second = hybrid_nchw_nchwxx; | ||||
} | } | ||||
@@ -1883,7 +1895,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
mgb_assert(conv_opr.param().format == | mgb_assert(conv_opr.param().format == | ||||
megdnn::param::Convolution::Format::NCHW, | megdnn::param::Convolution::Format::NCHW, | ||||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | "ConvertFormat Pass only support converting NCHW to NCHWXX"); | ||||
auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1]); | |||||
auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1], | |||||
conv_opr.param().stride_h, | |||||
conv_opr.param().stride_w); | |||||
//! can not trans to nchwxx | //! can not trans to nchwxx | ||||
if (is_trans.first == TransType::TRANS_NONE) { | if (is_trans.first == TransType::TRANS_NONE) { | ||||
mgb_assert(new_inp[1]->shape().ndim == 4 || | mgb_assert(new_inp[1]->shape().ndim == 4 || | ||||
@@ -1957,8 +1971,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
mgb_assert(conv_bias_opr.param().format == | mgb_assert(conv_bias_opr.param().format == | ||||
megdnn::param::ConvBias::Format::NCHW, | megdnn::param::ConvBias::Format::NCHW, | ||||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | "ConvertFormat Pass only support converting NCHW to NCHWXX"); | ||||
auto is_trans = | |||||
test_trans_nchwxx(conv_bias_opr.param().sparse, new_inp[1]); | |||||
auto is_trans = test_trans_nchwxx( | |||||
conv_bias_opr.param().sparse, new_inp[1], | |||||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); | |||||
//! can not trans to nchwxx | //! can not trans to nchwxx | ||||
if (is_trans.first == TransType::TRANS_NONE) { | if (is_trans.first == TransType::TRANS_NONE) { | ||||
mgb_assert(new_inp[1]->shape().ndim == 4 || | mgb_assert(new_inp[1]->shape().ndim == 4 || | ||||
@@ -2199,17 +2214,21 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||||
constexpr size_t pack_c_size = 4_z; | constexpr size_t pack_c_size = 4_z; | ||||
auto test_trans_nchw44_dot = | auto test_trans_nchw44_dot = | ||||
[](const megdnn::param::Convolution::Sparse conv_mode, | [](const megdnn::param::Convolution::Sparse conv_mode, | ||||
const VarNode* filter) -> TestTransResult { | |||||
const VarNode* filter, const size_t stride_h, | |||||
const size_t stride_w) -> TestTransResult { | |||||
TestTransResult ret{TransType::TRANS_NONE, {}, {}}; | TestTransResult ret{TransType::TRANS_NONE, {}, {}}; | ||||
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | ||||
size_t IC = filter->shape()[1]; | |||||
size_t OC = filter->shape()[0]; | size_t OC = filter->shape()[0]; | ||||
size_t IC = filter->shape()[1]; | |||||
size_t FH = filter->shape()[2]; | |||||
size_t FW = filter->shape()[3]; | |||||
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { | if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { | ||||
ret.trans_type = TransType::TRANS_PURE_NCHWXX; | ret.trans_type = TransType::TRANS_PURE_NCHWXX; | ||||
ret.relayout_mod = | ret.relayout_mod = | ||||
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; | RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; | ||||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | ||||
} else if (IC < pack_c_size && OC % pack_c_size == 0) { | |||||
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, | |||||
stride_w)) { | |||||
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX; | ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX; | ||||
ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; | ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; | ||||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | ||||
@@ -2241,8 +2260,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||||
megdnn::param::Convolution::Format::NCHW, | megdnn::param::Convolution::Format::NCHW, | ||||
"ConvertFormat Pass only support converting NCHW to " | "ConvertFormat Pass only support converting NCHW to " | ||||
"NCHW44_DOT"); | "NCHW44_DOT"); | ||||
auto is_trans = | |||||
test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]); | |||||
auto is_trans = test_trans_nchw44_dot( | |||||
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, | |||||
conv_opr.param().stride_w); | |||||
//! can not trans to nchwxx | //! can not trans to nchwxx | ||||
if (is_trans.trans_type == TransType::TRANS_NONE) { | if (is_trans.trans_type == TransType::TRANS_NONE) { | ||||
mgb_assert(new_inp[1]->shape().ndim == 4 || | mgb_assert(new_inp[1]->shape().ndim == 4 || | ||||
@@ -2315,8 +2335,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||||
mgb_assert(conv_bias_opr.param().format == | mgb_assert(conv_bias_opr.param().format == | ||||
megdnn::param::ConvBias::Format::NCHW, | megdnn::param::ConvBias::Format::NCHW, | ||||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | "ConvertFormat Pass only support converting NCHW to NCHWXX"); | ||||
auto is_trans = | |||||
test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]); | |||||
auto is_trans = test_trans_nchw44_dot( | |||||
conv_bias_opr.param().sparse, new_inp[1], | |||||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); | |||||
//! can not trans to nchwxx | //! can not trans to nchwxx | ||||
if (is_trans.trans_type == TransType::TRANS_NONE) { | if (is_trans.trans_type == TransType::TRANS_NONE) { | ||||
mgb_assert(new_inp[1]->shape().ndim == 4 || | mgb_assert(new_inp[1]->shape().ndim == 4 || | ||||
@@ -2976,11 +2976,16 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||||
auto host_x = gen({2, 3, 16, 16}, cn); | auto host_x = gen({2, 3, 16, 16}, cn); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | auto x = opr::Host2DeviceCopy::make(*graph, host_x); | ||||
//! Hybrid nchw88 mode | |||||
//! Hybrid nchw44 mode | |||||
opr::Convolution::Param param_conv; | opr::Convolution::Param param_conv; | ||||
param_conv.pad_h = param_conv.pad_w = 1; | param_conv.pad_h = param_conv.pad_w = 1; | ||||
opr::ConvBias::Param param_conv_bias_stride4; | |||||
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4; | |||||
auto w1 = mkcvar("w1", {8, 3, 3, 3}), | auto w1 = mkcvar("w1", {8, 3, 3, 3}), | ||||
conv1 = opr::Convolution::make(x, w1, param_conv); | conv1 = opr::Convolution::make(x, w1, param_conv); | ||||
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), | |||||
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); | |||||
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); | |||||
//! channel wise | //! channel wise | ||||
opr::ConvBias::Param param_conv_bias; | opr::ConvBias::Param param_conv_bias; | ||||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | ||||
@@ -3015,22 +3020,35 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||||
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), | auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), | ||||
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); | y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); | ||||
SymbolVar y_opt; | |||||
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt; | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | auto options = gopt::OptimizeForInferenceOptions{}; | ||||
options.enable_nchw44(); | options.enable_nchw44(); | ||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
unpack_vector(gopt::optimize_for_inference( | |||||
{y, conv1, conv1_f4, conv1_s4, conv2}, options), | |||||
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | ||||
find_opr<opr::Convolution>(conv1_opt).param().format); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||||
find_opr<opr::ConvBias>(conv1_s4_opt).param().format); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||||
find_opr<opr::Convolution>(conv1_f4_opt).param().format); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | |||||
find_opr<opr::ConvBias>(conv2_opt).param().format); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | |||||
find_opr<opr::ConvBias>(y_opt).param().format); | find_opr<opr::ConvBias>(y_opt).param().format); | ||||
graph->compile({{y_opt, {}}}) | |||||
graph->compile({{y_opt, {}}, {conv2, {}}}) | |||||
->to_json() | ->to_json() | ||||
->writeto_fpath( | ->writeto_fpath( | ||||
output_file("TestGoptInference.ConvertFormatNCHW44.json")); | output_file("TestGoptInference.ConvertFormatNCHW44.json")); | ||||
HostTensorND host_y_opt, host_y; | HostTensorND host_y_opt, host_y; | ||||
HostTensorND host_conv1; | |||||
auto func = graph->compile({make_callback_copy(y, host_y), | auto func = graph->compile({make_callback_copy(y, host_y), | ||||
make_callback_copy(y_opt, host_y_opt)}); | |||||
make_callback_copy(y_opt, host_y_opt), | |||||
make_callback_copy(conv1, host_conv1)}); | |||||
func->execute(); | func->execute(); | ||||
//! maybe go to winograd in x86-32, so set error 1e-1 | //! maybe go to winograd in x86-32, so set error 1e-1 | ||||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | ||||
@@ -3140,11 +3158,16 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||||
auto host_x = gen({2, 3, 16, 16}, cn); | auto host_x = gen({2, 3, 16, 16}, cn); | ||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | auto x = opr::Host2DeviceCopy::make(*graph, host_x); | ||||
//! Hybrid nchw88 mode | |||||
//! Hybrid nchw44 mode | |||||
opr::Convolution::Param param_conv; | opr::Convolution::Param param_conv; | ||||
param_conv.pad_h = param_conv.pad_w = 1; | param_conv.pad_h = param_conv.pad_w = 1; | ||||
opr::ConvBias::Param param_conv_bias_stride4; | |||||
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4; | |||||
auto w1 = mkcvar("w1", {8, 3, 3, 3}), | auto w1 = mkcvar("w1", {8, 3, 3, 3}), | ||||
conv1 = opr::Convolution::make(x, w1, param_conv); | conv1 = opr::Convolution::make(x, w1, param_conv); | ||||
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), | |||||
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); | |||||
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); | |||||
//! channel wise | //! channel wise | ||||
opr::ConvBias::Param param_conv_bias; | opr::ConvBias::Param param_conv_bias; | ||||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | ||||
@@ -3179,12 +3202,20 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||||
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), | auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), | ||||
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); | y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); | ||||
SymbolVar y_opt; | |||||
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt; | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | auto options = gopt::OptimizeForInferenceOptions{}; | ||||
options.enable_nchw44_dot(); | options.enable_nchw44_dot(); | ||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
unpack_vector(gopt::optimize_for_inference({y, conv1, conv1_f4, conv1_s4}, | |||||
options), | |||||
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, | ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, | ||||
find_opr<opr::Convolution>(conv1_opt).param().format); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||||
find_opr<opr::ConvBias>(conv1_s4_opt).param().format); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||||
find_opr<opr::Convolution>(conv1_f4_opt).param().format); | |||||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, | |||||
find_opr<opr::Convolution>(y_opt).param().format); | find_opr<opr::Convolution>(y_opt).param().format); | ||||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | ||||
find_opr<opr::ConvBias>(y_opt).param().format); | find_opr<opr::ConvBias>(y_opt).param().format); | ||||