GitOrigin-RevId: 982dee36e1
release-0.6
@@ -2049,16 +2049,17 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||||
return new_opr; | return new_opr; | ||||
} | } | ||||
}; | }; | ||||
auto replace_concat_opr = [=](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
//! When the inputs change and all inputs can convert to nchwxx, this opr will run | |||||
//! in nchwxx mode, else it will run in nchw mode, for example concat and | |||||
//! elemwise opr | |||||
auto replace_multi_inp_opr = [=](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
bool has_inp_changed = false; | bool has_inp_changed = false; | ||||
bool can_exec_ncwxx = true; | bool can_exec_ncwxx = true; | ||||
for (size_t i = 0; i < opr->input().size(); i++) { | for (size_t i = 0; i < opr->input().size(); i++) { | ||||
if (new_inp[i]->shape().ndim == 5) { | if (new_inp[i]->shape().ndim == 5) { | ||||
has_inp_changed = true; | has_inp_changed = true; | ||||
break; | |||||
} else if (new_inp[i]->shape().ndim == 4) { | } else if (new_inp[i]->shape().ndim == 4) { | ||||
if (new_inp[i]->shape()[1] % pack_c_size != 0) { | if (new_inp[i]->shape()[1] % pack_c_size != 0) { | ||||
can_exec_ncwxx = false; | can_exec_ncwxx = false; | ||||
@@ -2095,36 +2096,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||||
} | } | ||||
}; | }; | ||||
auto replace_elemwise_opr = [=](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
bool has_inp_changed = false; | |||||
for (size_t i = 0; i < opr->input().size(); i++) { | |||||
if (new_inp[i]->shape().ndim == 5) { | |||||
has_inp_changed = true; | |||||
break; | |||||
} | |||||
} | |||||
if (has_inp_changed) { | |||||
auto temp_inp = new_inp; | |||||
for (size_t i = 0; i < opr->input().size(); i++) { | |||||
if (new_inp[i]->shape().ndim == 4) { | |||||
auto new_var = RelayoutPlaceholder::make( | |||||
new_inp[i], src_to_nchwxx_mode); | |||||
temp_inp[i] = new_var.node(); | |||||
} else { | |||||
mgb_assert((new_inp[i]->shape().ndim == 5) || | |||||
new_inp[i]->shape().is_scalar()); | |||||
} | |||||
} | |||||
return serialization::copy_opr_shallow(*opr, temp_inp, | |||||
opr->config()); | |||||
} else { | |||||
return serialization::copy_opr_shallow(*opr, new_inp, | |||||
opr->config()); | |||||
} | |||||
}; | |||||
auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, | auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, | ||||
const VarNodeArray& new_inp) { | const VarNodeArray& new_inp) { | ||||
mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
@@ -2146,11 +2117,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||||
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | ||||
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | ||||
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | ||||
replace_func[opr::Concat::typeinfo()] = replace_concat_opr; | |||||
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | |||||
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | |||||
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | |||||
replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | |||||
replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr; | |||||
replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr; | |||||
replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr; | |||||
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_multi_inp_opr; | |||||
replace_func[opr::PowC::typeinfo()] = replace_multi_inp_opr; | |||||
//! not support yet | //! not support yet | ||||
replace_func[opr::ConvolutionBackwardData::typeinfo()] = | replace_func[opr::ConvolutionBackwardData::typeinfo()] = | ||||
relayout_inp_to_nchw; | relayout_inp_to_nchw; | ||||
@@ -2164,6 +2135,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||||
replace_func[opr::WarpPerspectiveForward::typeinfo()] = | replace_func[opr::WarpPerspectiveForward::typeinfo()] = | ||||
relayout_inp_to_nchw; | relayout_inp_to_nchw; | ||||
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw; | |||||
} | } | ||||
std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter( | std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter( | ||||
@@ -2948,6 +2948,90 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | ||||
} | } | ||||
TEST(TestGoptInference, ConvertFormatNCHW44MultiInput) { | |||||
HostTensorGenerator<> gen; | |||||
auto cn = CompNode::load("cpu0"); | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto mkvar = [&](const char* name, const TensorShape& shp) { | |||||
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); | |||||
}; | |||||
auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
.rename(name); | |||||
}; | |||||
auto host_x1 = gen({1, 8, 16, 16}, cn); | |||||
auto host_x2 = gen({1, 1, 16, 16}, cn); | |||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x1); | |||||
opr::Convolution::Param param_conv; | |||||
param_conv.pad_h = param_conv.pad_w = 1; | |||||
auto w1 = mkcvar("w1", {8, 8, 3, 3}), | |||||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||||
auto b = mkvar("b", {1, 1, 16, 16}), | |||||
y = opr::Elemwise::make({conv1 + b}, opr::Elemwise::Param::Mode::RELU); | |||||
SymbolVar y_opt; | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_nchw44(); | |||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||||
find_opr<opr::Convolution>(y_opt).param().format); | |||||
graph->compile({{y_opt, {}}}) | |||||
->to_json() | |||||
->writeto_fpath(output_file( | |||||
"TestGoptInference.ConvertFormatNCHW44MultiInput.json")); | |||||
HostTensorND host_y_opt, host_y; | |||||
auto func = graph->compile({make_callback_copy(y, host_y), | |||||
make_callback_copy(y_opt, host_y_opt)}); | |||||
func->execute(); | |||||
//! maybe go to winograd in x86-32, so set error 1e-1 | |||||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||||
} | |||||
TEST(TestGoptInference, ConvertFormatNCHW44Reshape) { | |||||
HostTensorGenerator<> gen; | |||||
auto cn = CompNode::load("cpu0"); | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
.rename(name); | |||||
}; | |||||
auto host_x1 = gen({1, 8, 16, 16}, cn); | |||||
auto x = opr::Host2DeviceCopy::make(*graph, host_x1); | |||||
opr::Convolution::Param param_conv; | |||||
param_conv.pad_h = param_conv.pad_w = 1; | |||||
auto w1 = mkcvar("w1", {8, 8, 3, 3}), | |||||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||||
auto y = opr::Reshape::make(conv1, {8, 16 * 16}); | |||||
SymbolVar y_opt; | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_nchw44(); | |||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||||
find_opr<opr::Convolution>(y_opt).param().format); | |||||
graph->compile({{y_opt, {}}}) | |||||
->to_json() | |||||
->writeto_fpath(output_file( | |||||
"TestGoptInference.ConvertFormatNCHW44Reshape.json")); | |||||
HostTensorND host_y_opt, host_y; | |||||
auto func = graph->compile({make_callback_copy(y, host_y), | |||||
make_callback_copy(y_opt, host_y_opt)}); | |||||
func->execute(); | |||||
//! maybe go to winograd in x86-32, so set error 1e-1 | |||||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||||
} | |||||
TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto cn = CompNode::load("cpu0"); | auto cn = CompNode::load("cpu0"); | ||||