GitOrigin-RevId: 1813753b14
release-0.5
@@ -1598,18 +1598,103 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
} | } | ||||
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); | return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); | ||||
}; | }; | ||||
auto replace_pooling_opr = [](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
using Param = opr::PoolingForward::Param; | |||||
using Format = Param::Format; | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); | |||||
mgb_assert(pooling.param().format == Format::NCHW, | |||||
"ConvertFormat Pass only support converting NCHW to NCHW4."); | |||||
if (new_inp[0]->shape().ndim == 5) { | |||||
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
auto new_param = pooling.param(); | |||||
new_param.format = Format::NCHW4; | |||||
auto new_pooling = | |||||
opr::PoolingForward::make(new_inp[0], new_param, opr->config()); | |||||
mgb_assert(new_pooling.shape().ndim == 5, | |||||
"out var of Pooling opr after transform must be 5 (got: " | |||||
"%zu).", | |||||
new_pooling.shape().ndim); | |||||
return new_pooling.node()->owner_opr(); | |||||
} | |||||
auto new_opr = | |||||
serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
return new_opr; | |||||
}; | |||||
auto replace_resize_opr = [](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
using Param = opr::ResizeForward::Param; | |||||
using Format = Param::Format; | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
auto& resize = opr->cast_final_safe<opr::ResizeForward>(); | |||||
mgb_assert(resize.param().format == Format::NCHW, | |||||
"ConvertFormat Pass only support converting NCHW to NCHW4."); | |||||
if (new_inp[0]->shape().ndim == 5) { | |||||
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
auto new_param = resize.param(); | |||||
new_param.format = Format::NCHW4; | |||||
auto new_resize = opr::ResizeForward::make( | |||||
new_inp[0], new_inp[1], new_param, opr->config()); | |||||
mgb_assert(new_resize.shape().ndim == 5, | |||||
"out var of Resize opr after transform must be 5 (got: " | |||||
"%zu).", | |||||
new_resize.shape().ndim); | |||||
return new_resize.node()->owner_opr(); | |||||
} | |||||
auto new_opr = | |||||
serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
return new_opr; | |||||
}; | |||||
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
using Param = opr::WarpPerspective::Param; | |||||
using Format = Param::Format; | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>(); | |||||
mgb_assert(warp.param().format == Format::NCHW, | |||||
"ConvertFormat Pass only support converting NCHW to NCHW4."); | |||||
if (new_inp[0]->shape().ndim == 5) { | |||||
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | |||||
auto new_param = warp.param(); | |||||
new_param.format = Format::NCHW4; | |||||
SymbolVar new_warp; | |||||
if (new_inp.size() == 3) { | |||||
new_warp = opr::WarpPerspectiveForward::make( | |||||
new_inp[0], new_inp[1], nullptr, new_inp[2], new_param, | |||||
opr->config()); | |||||
} else { | |||||
mgb_assert(new_inp.size() == 4); | |||||
new_warp = opr::WarpPerspectiveForward::make( | |||||
new_inp[0], new_inp[1], new_inp[2], new_inp[3], | |||||
new_param, opr->config()); | |||||
} | |||||
mgb_assert(new_warp.shape().ndim == 5, | |||||
"out var of WarpPerspective opr after transform must be " | |||||
"5 (got: " | |||||
"%zu).", | |||||
new_warp.shape().ndim); | |||||
return new_warp.node()->owner_opr(); | |||||
} | |||||
auto new_opr = | |||||
serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
return new_opr; | |||||
}; | |||||
auto&& replace_func = ret->m_opr_replace_func; | auto&& replace_func = ret->m_opr_replace_func; | ||||
//! supported nchw4 | //! supported nchw4 | ||||
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | ||||
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | ||||
replace_func[opr::BatchConvBias::typeinfo()] = | replace_func[opr::BatchConvBias::typeinfo()] = | ||||
replace_batch_conv_bias_opr; | replace_batch_conv_bias_opr; | ||||
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | |||||
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; | |||||
replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||||
replace_warp_perspective_opr; | |||||
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; | ||||
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; | ||||
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; | ||||
replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | ||||
//! not supported nchw4 | //! not supported nchw4 | ||||
replace_func[opr::PoolingForward::typeinfo()] = relayout_inp_to_nchw; | |||||
replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::ConvolutionBackwardData::typeinfo()] = | replace_func[opr::ConvolutionBackwardData::typeinfo()] = | ||||
relayout_inp_to_nchw; | relayout_inp_to_nchw; | ||||
@@ -1619,9 +1704,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ | |||||
replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; | ||||
replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw; | |||||
replace_func[opr::WarpPerspectiveForward::typeinfo()] = | |||||
relayout_inp_to_nchw; | |||||
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -2430,14 +2430,16 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | ||||
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | ||||
auto conv1 = opr::ConvBiasForward::make( | auto conv1 = opr::ConvBiasForward::make( | ||||
x, w1, b1, param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
x, w1, b1, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
// group | // group | ||||
// icpg != 1 && ocpg != 1 | // icpg != 1 && ocpg != 1 | ||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | ||||
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | ||||
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | ||||
auto conv2 = opr::ConvBiasForward::make(conv1, w2, b2, | |||||
param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto conv2 = opr::ConvBiasForward::make( | |||||
conv1, w2, b2, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
auto y = opr::TypeCvt::make(conv2, dtype::Float32()); | auto y = opr::TypeCvt::make(conv2, dtype::Float32()); | ||||
@@ -2455,8 +2457,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
graph->compile({{y_opt, {}}}) | graph->compile({{y_opt, {}}}) | ||||
->to_json() | ->to_json() | ||||
->writeto_fpath( | |||||
output_file("TestGoptInference.ConvertFormatNCHW4GPU.json")); | |||||
->writeto_fpath(output_file( | |||||
"TestGoptInference.ConvertFormatNCHW4GPU.json")); | |||||
HostTensorND host_y, host_y_opt; | HostTensorND host_y, host_y_opt; | ||||
auto func = graph->compile({make_callback_copy(y, host_y), | auto func = graph->compile({make_callback_copy(y, host_y), | ||||
@@ -2467,6 +2469,90 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||||
#endif | #endif | ||||
//! Builds conv1 -> Resize -> WarpPerspective -> Pooling -> conv2, adds conv1
//! and conv2 with QADD, runs optimize_for_inference with nchw4 enabled, and
//! checks that ConvBias, Resize, WarpPerspective and Pooling all end up with
//! Format::NCHW4 and that the optimized graph holds exactly two Dimshuffle
//! oprs. The graph is only optimized here; it is never compiled or executed.
TEST(TestGoptInference, ConvertFormatNCHW4NonConvOpr) { | |||||
auto cn = CompNode::load("xpu0"); | |||||
HostTensorGenerator<dtype::Int8> gen; | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
// graph input: Host2DeviceCopy of generated data, cast to `dtype`
auto mkvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), | |||||
dtype); | |||||
}; | |||||
// constant: SharedDeviceTensor of generated data, cast to `dtype`
auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
const DType& dtype) { | |||||
return opr::TypeCvt::make( | |||||
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
.rename(name), | |||||
dtype); | |||||
}; | |||||
// constant without a TypeCvt.
// NOTE(review): despite the "f32" in the name, this draws from `gen`, which
// is HostTensorGenerator<dtype::Int8> -- confirm the resulting dtype is the
// one intended for the WarpPerspective matrix.
auto mkcvarf32 = [&](const char* name, const TensorShape& shp) { | |||||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||||
.rename(name); | |||||
}; | |||||
auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f)); | |||||
opr::ConvBias::Param param_conv_bias; | |||||
param_conv_bias.format = opr::ConvBias::Param::Format::NCHW; | |||||
param_conv_bias.stride_h = param_conv_bias.stride_w = 1; | |||||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||||
param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
// dense | |||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||||
auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
auto conv1 = opr::ConvBiasForward::make( | |||||
x, w1, b1, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
// test Resize | |||||
// resize target is derived from x's shape: a sub-range of GetVarShape(x),
// multiplied by 2 below. Presumably this selects shape[2:], i.e. (H, W) --
// confirm against AxisIndexer::make_interval's argument order.
auto shape_of = opr::GetVarShape::make(x); | |||||
auto subtensor = opr::Subtensor::make( | |||||
shape_of, {opr::Subtensor::AxisIndexer::make_interval( | |||||
0, x.make_scalar(2), None, x.make_scalar(1))}); | |||||
opr::Resize::Param param_resize; | |||||
param_resize.format = opr::Resize::Param::Format::NCHW; | |||||
auto resize = opr::ResizeForward::make(conv1, subtensor * 2, param_resize); | |||||
// test WarpPerspective | |||||
// 3-input form: nullptr is passed in the third make() slot; no param is
// given explicitly, so make()'s default param is used
auto mat = mkcvarf32("mat", {2, 3, 3}), | |||||
warp = opr::WarpPerspectiveForward::make( | |||||
resize, mat, nullptr, cg::var_from_tensor_shape(x, {32, 32})); | |||||
opr::Pooling::Param pool_param; | |||||
pool_param.format = opr::Pooling::Param::Format::NCHW; | |||||
// test Pooling | |||||
auto pool = opr::Pooling::make(warp, pool_param); | |||||
// group | |||||
// icpg != 1 && ocpg != 1 | |||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||||
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||||
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||||
auto conv2 = opr::ConvBiasForward::make( | |||||
pool, w2, b2, param_conv_bias, {}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||||
// quantized add of the two conv outputs, then cast to float32 for the output
auto add = opr::ElemwiseMultiType::make( | |||||
{conv1, conv2}, {opr::ElemwiseMultiType::Param::Mode::QADD}, | |||||
OperatorNodeConfig{dtype::QuantizedS8{1.2f}}); | |||||
auto y = opr::TypeCvt::make(add, dtype::Float32()); | |||||
SymbolVar y_opt; | |||||
// run the inference optimizer with only the nchw4 option enabled
{ | |||||
auto options = gopt::OptimizeForInferenceOptions{}; | |||||
options.enable_nchw4(); | |||||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||||
} | |||||
// exactly two Dimshuffle oprs are expected in the optimized graph
// (presumably the layout changes at the graph boundaries -- confirm)
auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt); | |||||
ASSERT_EQ(2u, nr_dimshuffle); | |||||
// each opr type looked up below must have been switched to NCHW4
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | |||||
find_opr<opr::ConvBias>(y_opt).param().format); | |||||
ASSERT_EQ(opr::ResizeForward::Param::Format::NCHW4, | |||||
find_opr<opr::ResizeForward>(y_opt).param().format); | |||||
ASSERT_EQ(opr::WarpPerspectiveForward::Param::Format::NCHW4, | |||||
find_opr<opr::WarpPerspectiveForward>(y_opt).param().format); | |||||
ASSERT_EQ(opr::PoolingForward::Param::Format::NCHW4, | |||||
find_opr<opr::PoolingForward>(y_opt).param().format); | |||||
} | |||||
TEST(TestGoptInference, ConvertFormatNCHW4) { | TEST(TestGoptInference, ConvertFormatNCHW4) { | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
auto cn = CompNode::load("cpu0"); | auto cn = CompNode::load("cpu0"); | ||||
@@ -2481,7 +2567,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { | |||||
}; | }; | ||||
auto x = mkvar("x", {2, 4, 16, 16}); | auto x = mkvar("x", {2, 4, 16, 16}); | ||||
// ConvBias | |||||
// ConvBias test dense | |||||
opr::ConvBias::Param param_conv_bias; | opr::ConvBias::Param param_conv_bias; | ||||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | ||||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | ||||