@@ -1589,6 +1589,67 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
        return new_opr;
    };
    auto replace_concat_opr = [&relayout_inp_to_chw](
                                      OperatorNodeBase* opr,
                                      const VarNodeArray& new_inp) {
        //! map NCHW axis to CD4 axis (N, H, C/4, W, 4)
        auto axis_nchw_to_cd4_map = [=](int32_t org_axis) -> int32_t {
            mgb_assert(org_axis >= 0 && org_axis <= 3);
            int32_t ret = 0;
            if (0 == org_axis) {
                ret = 0;
            } else if (1 == org_axis) {
                ret = 2;
            } else if (2 == org_axis) {
                ret = 1;
            } else if (3 == org_axis) {
                mgb_throw(InternalError,
                          "axis=3 is not supported for concat bypass to CD4!");
            } else {
                mgb_throw(InternalError,
                          "unsupported axis for concat pass, input may not be "
                          "in NCHW format!");
            }
            return ret;
        };
        mgb_assert(opr->input().size() == new_inp.size());
        auto nchw_axis = opr->cast_final_safe<opr::Concat>().param().axis;
        if (nchw_axis < 0 || nchw_axis > 3) {
            mgb_log_warn("concat pass falls back to relayout chw");
            return relayout_inp_to_chw(opr, new_inp);
        }
        bool can_exec_cd4 = true;
        //! only consider OpenCL CD4 for now; if other backends hit relayout
        //! performance issues, other bypass formats may be added
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (opr->input(i)->format().type() != TensorFormat::Type::DEFAULT ||
                opr->input(i)->shape()[1] % 4 != 0 ||
                new_inp[i]->shape().ndim != 5 ||
                new_inp[i]->format().type() !=
                        TensorFormat::Type::IMAGE2D_PACK4 ||
                nchw_axis == 3) {
                can_exec_cd4 = false;
                break;
            }
        }
        if (!can_exec_cd4) {
            mgb_log_warn("concat pass falls back to relayout chw");
            return relayout_inp_to_chw(opr, new_inp);
        }
        megdnn::param::Axis param;
        //! currently only NCHW bypass to CD4 is supported
        mgb_log_warn("concat pass bypasses to CD4");
        param.axis = axis_nchw_to_cd4_map(nchw_axis);
        return opr::Concat::make(VarNodeArrayView(new_inp), param,
                                 opr->config())
                .node()
                ->owner_opr();
    };
    auto replace_elemwise_opr = [&relayout_inp_to_chw](
                                        OperatorNodeBase* opr,
                                        const VarNodeArray& new_inp) {
@@ -1654,7 +1715,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
    replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
-   replace_func[opr::Concat::typeinfo()] = relayout_inp_to_chw;
+   replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
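For reference, the axis remapping that `replace_concat_opr` relies on can be restated outside the pass. Below is a minimal standalone sketch assuming only the CD4 layout named in the comment above, (N, H, C/4, W, 4); the free function `nchw_axis_to_cd4` is hypothetical and not part of this patch.

```cpp
#include <cstdint>
#include <stdexcept>

// CD4 stores a tensor as (N, H, C/4, W, 4): batch keeps axis 0, the NCHW
// channel axis (1) maps to the C/4 axis (2), the NCHW height axis (2) maps
// to axis 1, and a concat along width (NCHW axis 3) is not bypassed.
int32_t nchw_axis_to_cd4(int32_t nchw_axis) {
    switch (nchw_axis) {
        case 0: return 0;  // N -> N
        case 1: return 2;  // C -> C/4
        case 2: return 1;  // H -> H
        default:
            throw std::runtime_error(
                    "axis not supported for concat bypass to CD4");
    }
}
```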
@@ -1591,6 +1591,77 @@ TEST(TestGoptInference, ConvertFormatPadIC) {
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}

TEST(TestGoptInference, concatbypass) {
    // hwcd4 is only supported in naive handle
    NaiveMegDNNHandleScope naive_megdnn_handle;
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                .rename(name);
    };
    auto host_inp1 = gen({1, 6, 16, 16}, cn),
         host_inp2 = gen({1, 6, 32, 32}, cn);
    auto inp1 = opr::Host2DeviceCopy::make(*graph, host_inp1),
         inp2 = opr::Host2DeviceCopy::make(*graph, host_inp2);
    auto shape_tmp = mkcvar("tmp", {32, 32});
    auto shape_of = opr::GetVarShape::make(shape_tmp);
    opr::Resize::Param param_resize;
    param_resize.format = opr::Resize::Param::Format::NCHW;
    auto resize = opr::ResizeForward::make(inp1, shape_of, param_resize);
    //! this concat should fall back to chw
    auto concat = opr::Concat::make({inp2, resize}, 1);
    opr::Convolution::Param param;
    param.pad_h = param.pad_w = 1;
    param.sparse = opr::Convolution::Param::Sparse::DENSE;
    auto w1 = mkcvar("w1", {12, 12, 3, 3});
    auto w2 = mkcvar("w2", {12, 24, 3, 3});
    auto y = opr::Convolution::make(concat, w1, param);
    //! this concat should bypass to CD4
    y = opr::Concat::make({y, y}, 0);
    y = opr::Convolution::make(y, w1, param);
    //! this concat should bypass to CD4
    y = opr::Concat::make({y, y}, 1);
    y = opr::Convolution::make(y, w2, param);
    //! this concat should bypass to CD4
    y = opr::Concat::make({y, y}, 2);
    y = opr::Convolution::make(y, w1, param);
    SymbolVar y_opt;
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_nhwcd4();
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    HostTensorND host_y_opt, host_y;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_opt, host_y_opt)});
    size_t relayout_format_nr = 0;
    auto cb = [&](cg::OperatorNodeBase* opr) {
        if (opr->try_cast_final<opr::Convolution>()) {
            auto conv_inputs = opr->input();
            for (auto& input : conv_inputs) {
                if (std::string::npos !=
                    std::string(input->cname()).find("relayout_format")) {
                    relayout_format_nr++;
                }
            }
        }
        return true;
    };
    func->iter_opr_seq(cb);
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
    ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
              find_opr<opr::Convolution>(y_opt).param().format);
    ASSERT_EQ(1, relayout_format_nr);
}

TEST(TestGoptInference, ConvertBatchNormPass) {
    auto cn = CompNode::load("cpu0");
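The bypass condition checked per input in `replace_concat_opr` can be read as a small predicate: the original var must be in the default format with a channel count divisible by 4, the rewritten var must already be a 5-dim IMAGE2D_PACK4 (CD4) tensor, and the concat axis must not be the width axis. A minimal sketch follows; the helper `can_concat_bypass_cd4` is hypothetical and introduced purely for illustration (the pass inspects `opr->input(i)` and `new_inp[i]` directly).

```cpp
#include <cstddef>
#include <cstdint>

// Mirrors the per-input check in replace_concat_opr: stay in CD4 only if the
// original var is in the default format with channels % 4 == 0, the rewritten
// var is a 5-dim IMAGE2D_PACK4 tensor, and the axis is N, C or H (not W).
bool can_concat_bypass_cd4(bool orig_is_default_format, size_t channels,
                           size_t new_ndim, bool new_is_image2d_pack4,
                           int32_t nchw_axis) {
    return orig_is_default_format && channels % 4 == 0 && new_ndim == 5 &&
           new_is_image2d_pack4 && nchw_axis >= 0 && nchw_axis <= 2;
}
```

This is also why the first concat in the `concatbypass` test (6-channel inputs) falls back to a chw relayout while the later concats along N, C and H stay in CD4, matching the single `relayout_format` input counted by the test.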