GitOrigin-RevId: 89083be200
tags/v1.0.0-rc1
@@ -1811,7 +1811,15 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, | |||
} | |||
return new_var; | |||
} | |||
//! nchw_nchwxx_valid is used to indicate optimized nchw_nchw44 conv | |||
static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic, | |||
const size_t pack_c_size, const size_t fh, | |||
const size_t fw, const size_t stride_h, | |||
const size_t stride_w) { | |||
return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw && | |||
stride_h == stride_w && (stride_h == 1 || stride_h == 2) && | |||
(fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
} | |||
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
using RelayoutMode = RelayoutPlaceholder::LayoutType; | |||
using TestFilterResult = std::pair<TransType, RelayoutMode>; | |||
@@ -1848,15 +1856,19 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan, | |||
hybrid_nchw_nchwxx]( | |||
const megdnn::param::Convolution::Sparse conv_mode, | |||
const VarNode* filter) -> TestFilterResult { | |||
const VarNode* filter, const size_t stride_h, | |||
const size_t stride_w) -> TestFilterResult { | |||
TestFilterResult ret{TransType::TRANS_NONE, {}}; | |||
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | |||
size_t IC = filter->shape()[1]; | |||
size_t OC = filter->shape()[0]; | |||
size_t IC = filter->shape()[1]; | |||
size_t FH = filter->shape()[2]; | |||
size_t FW = filter->shape()[3]; | |||
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { | |||
ret.first = TransType::TRANS_PURE_NCHWXX; | |||
ret.second = weight_to_nchwxx_mode_dense; | |||
} else if (IC < pack_c_size && OC % pack_c_size == 0) { | |||
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, | |||
stride_w)) { | |||
ret.first = TransType::TRANS_HYBIRD_NCHWXX; | |||
ret.second = hybrid_nchw_nchwxx; | |||
} | |||
@@ -1883,7 +1895,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
mgb_assert(conv_opr.param().format == | |||
megdnn::param::Convolution::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1]); | |||
auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1], | |||
conv_opr.param().stride_h, | |||
conv_opr.param().stride_w); | |||
//! can not trans to nchwxx | |||
if (is_trans.first == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -1957,8 +1971,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
mgb_assert(conv_bias_opr.param().format == | |||
megdnn::param::ConvBias::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
auto is_trans = | |||
test_trans_nchwxx(conv_bias_opr.param().sparse, new_inp[1]); | |||
auto is_trans = test_trans_nchwxx( | |||
conv_bias_opr.param().sparse, new_inp[1], | |||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); | |||
//! can not trans to nchwxx | |||
if (is_trans.first == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -2199,17 +2214,21 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
constexpr size_t pack_c_size = 4_z; | |||
auto test_trans_nchw44_dot = | |||
[](const megdnn::param::Convolution::Sparse conv_mode, | |||
const VarNode* filter) -> TestTransResult { | |||
const VarNode* filter, const size_t stride_h, | |||
const size_t stride_w) -> TestTransResult { | |||
TestTransResult ret{TransType::TRANS_NONE, {}, {}}; | |||
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | |||
size_t IC = filter->shape()[1]; | |||
size_t OC = filter->shape()[0]; | |||
size_t IC = filter->shape()[1]; | |||
size_t FH = filter->shape()[2]; | |||
size_t FW = filter->shape()[3]; | |||
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { | |||
ret.trans_type = TransType::TRANS_PURE_NCHWXX; | |||
ret.relayout_mod = | |||
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | |||
} else if (IC < pack_c_size && OC % pack_c_size == 0) { | |||
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, | |||
stride_w)) { | |||
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX; | |||
ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | |||
@@ -2241,8 +2260,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
megdnn::param::Convolution::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to " | |||
"NCHW44_DOT"); | |||
auto is_trans = | |||
test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]); | |||
auto is_trans = test_trans_nchw44_dot( | |||
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, | |||
conv_opr.param().stride_w); | |||
//! can not trans to nchwxx | |||
if (is_trans.trans_type == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -2315,8 +2335,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
mgb_assert(conv_bias_opr.param().format == | |||
megdnn::param::ConvBias::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
auto is_trans = | |||
test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]); | |||
auto is_trans = test_trans_nchw44_dot( | |||
conv_bias_opr.param().sparse, new_inp[1], | |||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); | |||
//! can not trans to nchwxx | |||
if (is_trans.trans_type == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -2976,11 +2976,16 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||
auto host_x = gen({2, 3, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
//! Hybrid nchw88 mode | |||
//! Hybrid nchw44 mode | |||
opr::Convolution::Param param_conv; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
opr::ConvBias::Param param_conv_bias_stride4; | |||
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4; | |||
auto w1 = mkcvar("w1", {8, 3, 3, 3}), | |||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), | |||
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); | |||
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); | |||
//! channel wise | |||
opr::ConvBias::Param param_conv_bias; | |||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||
@@ -3015,22 +3020,35 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), | |||
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); | |||
SymbolVar y_opt; | |||
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt; | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.enable_nchw44(); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
unpack_vector(gopt::optimize_for_inference( | |||
{y, conv1, conv1_f4, conv1_s4, conv2}, options), | |||
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | |||
find_opr<opr::Convolution>(conv1_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(conv1_s4_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(conv1_f4_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(conv2_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt).param().format); | |||
graph->compile({{y_opt, {}}}) | |||
graph->compile({{y_opt, {}}, {conv2, {}}}) | |||
->to_json() | |||
->writeto_fpath( | |||
output_file("TestGoptInference.ConvertFormatNCHW44.json")); | |||
HostTensorND host_y_opt, host_y; | |||
HostTensorND host_conv1; | |||
auto func = graph->compile({make_callback_copy(y, host_y), | |||
make_callback_copy(y_opt, host_y_opt)}); | |||
make_callback_copy(y_opt, host_y_opt), | |||
make_callback_copy(conv1, host_conv1)}); | |||
func->execute(); | |||
//! meybe go to winograd in x86-32, so set error 1e-1 | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
@@ -3140,11 +3158,16 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
auto host_x = gen({2, 3, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
//! Hybrid nchw88 mode | |||
//! Hybrid nchw44 mode | |||
opr::Convolution::Param param_conv; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
opr::ConvBias::Param param_conv_bias_stride4; | |||
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4; | |||
auto w1 = mkcvar("w1", {8, 3, 3, 3}), | |||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), | |||
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); | |||
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); | |||
//! channel wise | |||
opr::ConvBias::Param param_conv_bias; | |||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||
@@ -3179,12 +3202,20 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), | |||
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); | |||
SymbolVar y_opt; | |||
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt; | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.enable_nchw44_dot(); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
unpack_vector(gopt::optimize_for_inference({y, conv1, conv1_f4, conv1_s4}, | |||
options), | |||
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, | |||
find_opr<opr::Convolution>(conv1_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(conv1_s4_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(conv1_f4_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, | |||
find_opr<opr::Convolution>(y_opt).param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt).param().format); | |||