@@ -1589,6 +1589,67 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
        return new_opr;
    };
    auto replace_concat_opr = [&relayout_inp_to_chw](
                                      OperatorNodeBase* opr,
                                      const VarNodeArray& new_inp) {
        //! map nchw axis to CD4 axis (n, h, c/4, w, 4)
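        //! e.g. a concat along NCHW axis 1 (channels) becomes a concat along
        //! axis 2 of the (n, h, c/4, w, 4) layout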
        auto axis_nchw_to_cd4_map = [=](int32_t org_axis) -> int32_t {
            mgb_assert(org_axis >= 0 && org_axis <= 3);
            int32_t ret = 0;
            if (0 == org_axis) {
                ret = 0;
            } else if (1 == org_axis) {
                ret = 2;
            } else if (2 == org_axis) {
                ret = 1;
            } else if (3 == org_axis) {
                mgb_throw(InternalError,
                          "Do not support axis=3 for concat bypass for CD4!");
            } else {
                mgb_throw(InternalError,
                          "Unsupported axis for concat pass; input may not be "
                          "in NCHW format!");
            }
            return ret;
        };
        mgb_assert(opr->input().size() == new_inp.size());
        auto nchw_axis = opr->cast_final_safe<opr::Concat>().param().axis;
        if (nchw_axis < 0 || nchw_axis > 3) {
            mgb_log_warn("concat pass fallback to relayout chw");
            return relayout_inp_to_chw(opr, new_inp);
        }
        bool can_exec_cd4 = true;
        //! only consider OpenCL CD4; if other backends have relayout
        //! performance issues, other bypass formats may be added here
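        //! bypass requires: original input in default NCHW layout with
        //! channels divisible by 4, converted input already a 5-dim
        //! IMAGE2D_PACK4 tensor, and the concat axis not being W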
        for (size_t i = 0; i < opr->input().size(); i++) {
            if (opr->input(i)->format().type() != TensorFormat::Type::DEFAULT ||
                opr->input(i)->shape()[1] % 4 != 0 ||
                new_inp[i]->shape().ndim != 5 ||
                new_inp[i]->format().type() !=
                        TensorFormat::Type::IMAGE2D_PACK4 ||
                nchw_axis == 3) {
                can_exec_cd4 = false;
                break;
            }
        }
        if (!can_exec_cd4) {
            mgb_log_warn("concat pass fallback to relayout chw");
            return relayout_inp_to_chw(opr, new_inp);
        }
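        //! all inputs are already in CD4 (IMAGE2D_PACK4) layout, so concat
        //! the converted inputs directly along the remapped axis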
        megdnn::param::Axis param;
        //! currently only NCHW-to-CD4 bypass is supported
        mgb_log_warn("concat pass bypass to CD4");
        param.axis = axis_nchw_to_cd4_map(nchw_axis);
        return opr::Concat::make(VarNodeArrayView(new_inp), param,
                                 opr->config())
                .node()
                ->owner_opr();
    };
    auto replace_elemwise_opr = [&relayout_inp_to_chw](
                                        OperatorNodeBase* opr,
                                        const VarNodeArray& new_inp) {
@@ -1654,7 +1715,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
    replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
    replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
-   replace_func[opr::Concat::typeinfo()] = relayout_inp_to_chw;
+   replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_chw;
    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_chw;
@@ -1591,6 +1591,77 @@ TEST(TestGoptInference, ConvertFormatPadIC) {
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, concatbypass) {
    // hwcd4 is only supported in naive handle
    NaiveMegDNNHandleScope naive_megdnn_handle;
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                .rename(name);
    };
    auto host_inp1 = gen({1, 6, 16, 16}, cn),
         host_inp2 = gen({1, 6, 32, 32}, cn);
    auto inp1 = opr::Host2DeviceCopy::make(*graph, host_inp1),
         inp2 = opr::Host2DeviceCopy::make(*graph, host_inp2);
    auto shape_tmp = mkcvar("tmp", {32, 32});
    auto shape_of = opr::GetVarShape::make(shape_tmp);
    opr::Resize::Param param_resize;
    param_resize.format = opr::Resize::Param::Format::NCHW;
    auto resize = opr::ResizeForward::make(inp1, shape_of, param_resize);
    //! this concat should fall back to chw (its inputs have 6 channels,
    //! which is not divisible by 4)
    auto concat = opr::Concat::make({inp2, resize}, 1);
    opr::Convolution::Param param;
    param.pad_h = param.pad_w = 1;
    param.sparse = opr::Convolution::Param::Sparse::DENSE;
    auto w1 = mkcvar("w1", {12, 12, 3, 3});
    auto w2 = mkcvar("w2", {12, 24, 3, 3});
    auto y = opr::Convolution::make(concat, w1, param);
    //! this concat should bypass CD4
    y = opr::Concat::make({y, y}, 0);
    y = opr::Convolution::make(y, w1, param);
    //! this concat should bypass CD4
    y = opr::Concat::make({y, y}, 1);
    y = opr::Convolution::make(y, w2, param);
    //! this concat should bypass CD4
    y = opr::Concat::make({y, y}, 2);
    y = opr::Convolution::make(y, w1, param);
    SymbolVar y_opt;
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_nhwcd4();
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    HostTensorND host_y_opt, host_y;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_opt, host_y_opt)});
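    //! count convolution inputs produced by a relayout_format opr; only the
    //! conv after the fall-back concat is expected to need one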
    size_t relayout_format_nr = 0;
    auto cb = [&](cg::OperatorNodeBase* opr) {
        if (opr->try_cast_final<opr::Convolution>()) {
            auto conv_inputs = opr->input();
            for (auto& input : conv_inputs) {
                if (std::string::npos !=
                    std::string(input->cname()).find("relayout_format")) {
                    relayout_format_nr++;
                }
            }
        }
        return true;
    };
    func->iter_opr_seq(cb);
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
    ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
              find_opr<opr::Convolution>(y_opt).param().format);
    ASSERT_EQ(1, relayout_format_nr);
}
TEST(TestGoptInference, ConvertBatchNormPass) {
    auto cn = CompNode::load("cpu0");