|
|
@@ -1171,16 +1171,22 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) { |
|
|
|
opr::Elemwise::Param::Mode::RELU); |
|
|
|
param.pad_h = param.pad_w = 1; |
|
|
|
auto w2 = mkcvar("w2", {4, 4, 3, 3}), |
|
|
|
y = opr::Convolution::make(elem, w2, param); |
|
|
|
y = opr::Convolution::make(elem, w2, param), |
|
|
|
z = opr::AxisAddRemove::make(y, {opr::AxisAddRemove::AxisDesc::make_add(0)}); |
|
|
|
|
|
|
|
SymbolVar y_opt; |
|
|
|
SymbolVar y_opt, z_opt; |
|
|
|
auto options = gopt::OptimizeForInferenceOptions{}; |
|
|
|
options.enable_nhwcd4(); |
|
|
|
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); |
|
|
|
unpack_vector(gopt::optimize_for_inference({z}, options), z_opt); |
|
|
|
|
|
|
|
ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, |
|
|
|
find_opr<opr::Convolution>(y_opt).param().format); |
|
|
|
|
|
|
|
ASSERT_EQ(TensorFormat::Type::DEFAULT, |
|
|
|
find_opr<opr::AxisAddRemove>(z_opt).input(0)->format().type()); |
|
|
|
ASSERT_EQ(4, find_opr<opr::AxisAddRemove>(z_opt).input(0)->shape().ndim); |
|
|
|
|
|
|
|
graph->compile({{y_opt, {}}}) |
|
|
|
->to_json() |
|
|
|
->writeto_fpath( |
|
|
|