GitOrigin-RevId: 982dee36e1
release-0.6
@@ -2049,16 +2049,17 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||
return new_opr; | |||
} | |||
}; | |||
auto replace_concat_opr = [=](OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
//! When input change and all input can convert to nchwxx, this opr will run | |||
//! in nchwxx mode, else it will run in nchw mode, for example concat and | |||
//! elemwise opr | |||
auto replace_multi_inp_opr = [=](OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
bool has_inp_changed = false; | |||
bool can_exec_ncwxx = true; | |||
for (size_t i = 0; i < opr->input().size(); i++) { | |||
if (new_inp[i]->shape().ndim == 5) { | |||
has_inp_changed = true; | |||
break; | |||
} else if (new_inp[i]->shape().ndim == 4) { | |||
if (new_inp[i]->shape()[1] % pack_c_size != 0) { | |||
can_exec_ncwxx = false; | |||
@@ -2095,36 +2096,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||
} | |||
}; | |||
auto replace_elemwise_opr = [=](OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
bool has_inp_changed = false; | |||
for (size_t i = 0; i < opr->input().size(); i++) { | |||
if (new_inp[i]->shape().ndim == 5) { | |||
has_inp_changed = true; | |||
break; | |||
} | |||
} | |||
if (has_inp_changed) { | |||
auto temp_inp = new_inp; | |||
for (size_t i = 0; i < opr->input().size(); i++) { | |||
if (new_inp[i]->shape().ndim == 4) { | |||
auto new_var = RelayoutPlaceholder::make( | |||
new_inp[i], src_to_nchwxx_mode); | |||
temp_inp[i] = new_var.node(); | |||
} else { | |||
mgb_assert((new_inp[i]->shape().ndim == 5) || | |||
new_inp[i]->shape().is_scalar()); | |||
} | |||
} | |||
return serialization::copy_opr_shallow(*opr, temp_inp, | |||
opr->config()); | |||
} else { | |||
return serialization::copy_opr_shallow(*opr, new_inp, | |||
opr->config()); | |||
} | |||
}; | |||
auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
@@ -2146,11 +2117,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | |||
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | |||
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | |||
replace_func[opr::Concat::typeinfo()] = replace_concat_opr; | |||
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | |||
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | |||
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | |||
replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | |||
replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr; | |||
replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr; | |||
replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr; | |||
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_multi_inp_opr; | |||
replace_func[opr::PowC::typeinfo()] = replace_multi_inp_opr; | |||
//! not support yet | |||
replace_func[opr::ConvolutionBackwardData::typeinfo()] = | |||
relayout_inp_to_nchw; | |||
@@ -2164,6 +2135,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){ | |||
replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||
relayout_inp_to_nchw; | |||
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | |||
replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw; | |||
} | |||
std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter( | |||
@@ -2948,6 +2948,90 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
} | |||
TEST(TestGoptInference, ConvertFormatNCHW44MultiInput) { | |||
HostTensorGenerator<> gen; | |||
auto cn = CompNode::load("cpu0"); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); | |||
}; | |||
auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
.rename(name); | |||
}; | |||
auto host_x1 = gen({1, 8, 16, 16}, cn); | |||
auto host_x2 = gen({1, 1, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x1); | |||
opr::Convolution::Param param_conv; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
auto w1 = mkcvar("w1", {8, 8, 3, 3}), | |||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||
auto b = mkvar("b", {1, 1, 16, 16}), | |||
y = opr::Elemwise::make({conv1 + b}, opr::Elemwise::Param::Mode::RELU); | |||
SymbolVar y_opt; | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.enable_nchw44(); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::Convolution>(y_opt).param().format); | |||
graph->compile({{y_opt, {}}}) | |||
->to_json() | |||
->writeto_fpath(output_file( | |||
"TestGoptInference.ConvertFormatNCHW44MultiInput.json")); | |||
HostTensorND host_y_opt, host_y; | |||
auto func = graph->compile({make_callback_copy(y, host_y), | |||
make_callback_copy(y_opt, host_y_opt)}); | |||
func->execute(); | |||
//! meybe go to winograd in x86-32, so set error 1e-1 | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
} | |||
TEST(TestGoptInference, ConvertFormatNCHW44Reshape) { | |||
HostTensorGenerator<> gen; | |||
auto cn = CompNode::load("cpu0"); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
.rename(name); | |||
}; | |||
auto host_x1 = gen({1, 8, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x1); | |||
opr::Convolution::Param param_conv; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
auto w1 = mkcvar("w1", {8, 8, 3, 3}), | |||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||
auto y = opr::Reshape::make(conv1, {8, 16 * 16}); | |||
SymbolVar y_opt; | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.enable_nchw44(); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::Convolution>(y_opt).param().format); | |||
graph->compile({{y_opt, {}}}) | |||
->to_json() | |||
->writeto_fpath(output_file( | |||
"TestGoptInference.ConvertFormatNCHW44Reshape.json")); | |||
HostTensorND host_y_opt, host_y; | |||
auto func = graph->compile({make_callback_copy(y, host_y), | |||
make_callback_copy(y_opt, host_y_opt)}); | |||
func->execute(); | |||
//! meybe go to winograd in x86-32, so set error 1e-1 | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
} | |||
TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
HostTensorGenerator<> gen; | |||
auto cn = CompNode::load("cpu0"); | |||