diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 6cce7359..aebedd24 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -122,6 +122,14 @@ public:
         NCHW_TO_NCHW64,    //!< from nchw layout to nchw64 layout
+        NCHW_TO_NHWC,      //!< from nchw layout to nhwc layout
+        NCHW4_TO_NHWC,     //!< from nchw4 layout to nhwc layout
+        NCHW32_TO_NHWC,    //!< from nchw32 layout to nhwc layout
+        NCHW64_TO_NHWC,    //!< from nchw64 layout to nhwc layout
+        NHWC_TO_NCHW,      //!< from nhwc layout to nchw layout
+        NHWC_TO_NCHW4,     //!< from nhwc layout to nchw4 layout
+        NHWC_TO_NCHW32,    //!< from nhwc layout to nchw32 layout
+        NHWC_TO_NCHW64,    //!< from nhwc layout to nchw64 layout
+    reformat[LayoutType::NCHW_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        megdnn::param::RelayoutFormat param;
+        param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWC;
+        auto reformat = opr::RelayoutFormat::make(inp, param);
+        return reformat.node();
+    };
+    reformat[LayoutType::NCHW4_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 =
+                opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NCHW32_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NCHW64_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto y = opr::Dimshuffle::make(x, {0, 3, 1, 2});
+        return y.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 4, cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 32, cv(32)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW64] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
@@ -4095,20 +4260,37 @@ void PaddingChannelPass::apply(OptState& opt) const {
         size_t new_in_channels = new_inp[0]->shape()[1];
         // pad input channels
         if (padding_oprs.count(opr->input(0)->owner_opr())) {
-            if (new_in_channels % 64 == 0) {
-                size_t pad_channels = new_in_channels - in_channels;
-                inps[1] = pad_in_channels(new_inp[1], pad_channels);
+            if (new_in_channels <= 32) {
+                if (new_in_channels % 8 == 0) {
+                    size_t pad_channels = new_in_channels - in_channels;
+                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
+                } else {
+                    size_t pad_channels_0 = 8 - (new_in_channels % 8);
+                    size_t pad_channels_1 = 8 - (in_channels % 8);
+                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
+                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                }
             } else {
-                size_t pad_channels_0 = 64 - (new_in_channels % 64);
-                size_t pad_channels_1 = 64 - (in_channels % 64);
-                inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
-                inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                if (new_in_channels % 64 == 0) {
+                    size_t pad_channels = new_in_channels - in_channels;
+                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
+                } else {
+                    size_t pad_channels_0 = 64 - (new_in_channels % 64);
+                    size_t pad_channels_1 = 64 - (in_channels % 64);
+                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
+                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                }
             }
         } else {
             size_t pad_channels = 0;
             mgb_assert(new_in_channels == in_channels);
-            if (in_channels % 64)
-                pad_channels = 64 - (in_channels % 64);
+            if (in_channels <= 32) {
+                if (in_channels % 8)
+                    pad_channels = 8 - (in_channels % 8);
+            } else {
+                if (in_channels % 64)
+                    pad_channels = 64 - (in_channels % 64);
+            }
             if (pad_channels > 0) {
                 inps[0] = pad_in_channels(new_inp[0], pad_channels);
                 inps[1] = pad_in_channels(new_inp[1], pad_channels);
@@ -4117,8 +4299,13 @@ void PaddingChannelPass::apply(OptState& opt) const {
         out_channels = inps[1]->shape()[0];
         in_channels = inps[1]->shape()[1];
         size_t pad_channels = 0;
-        if (out_channels % 64)
-            pad_channels = 64 - (out_channels % 64);
+        if (out_channels <= 32) {
+            if (out_channels % 8)
+                pad_channels = 8 - (out_channels % 8);
+        } else {
+            if (out_channels % 64)
+                pad_channels = 64 - (out_channels % 64);
+        }
         if (pad_channels > 0) {
             inps[1] = pad_out_channels(inps[1], pad_channels);
             inps[2] = pad_in_channels(inps[2], pad_channels);
@@ -4402,20 +4589,16 @@ EnableNCHW64Pass::make_nchw64_converter() {
             return new_conv.node();
         }
     };
-    auto try_transform_to_nchw =
-            [&format_map](
-                    OperatorNodeBase* opr,
-                    const VarNodeArray& new_inp) -> VarNode* {
-        mgb_assert(opr->input().size()==new_inp.size());
-        bool check_dtype =
-                new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
-                new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
+    auto try_transform_to_nchw =
+            [&format_map](OperatorNodeBase* opr,
+                          const VarNodeArray& new_inp) -> VarNode* {
+        mgb_assert(opr->input().size() == new_inp.size());
+        bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
+                           new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
         if (opr->input().size() >= 3)
-            check_dtype &=
-                    new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
+            check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
         if (opr->input().size() >= 4)
-            check_dtype &=
-                    new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
+            check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
         if (!check_dtype)
             return nullptr;
         auto inps = new_inp;
@@ -4451,7 +4634,6 @@ EnableNCHW64Pass::make_nchw64_converter() {
        return ret->output()[0];
    };
 
-
    auto try_transform_to_nchw4 =
            [make_new_conv, &format_map](
                    OperatorNodeBase* opr,
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 633ee347..97b45ecb 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -4735,6 +4735,101 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) {
     MGB_ASSERT_TENSOR_EQ(t1, t2);
 }
 
+TEST(TestGoptInference, PaddingChannelsB4) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    cn.activate();
+    REQUIRE_CUDA_COMPUTE_CAPABILITY(7, 5);
+
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp,
+                     const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
+                dtype);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name),
+                dtype);
+    };
+
+    auto x = mkvar("x", {16, 3, 14, 14}, dtype::QuantizedS8(2.5f)),
+         w = mkcvar("w", {16, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    param.stride_h = param.stride_w = 1;
+    param.pad_h = param.pad_w = 1;
+
+    auto y = opr::ConvBias::make(x, w, b, param, {},
+                                 OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+    y = opr::TypeCvt::make(y, dtype::Quantized4Asymm{20.f, 8});
+    opr::Pooling::Param pool;
+    pool.format = opr::Pooling::Param::Format::NCHW;
+    y = opr::Pooling::make(y, pool);
+    auto w1 = mkcvar("w1", {48, 16, 3, 3}, dtype::QuantizedS4(1.234f)),
+         b1 = mkcvar("b1", {1, 48, 1, 1}, dtype::QuantizedS32(20.f * 1.234f));
+    auto y1 = opr::ConvBias::make(
+            y, w1, b1, param, {},
+            OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
+    auto w2 = mkcvar("w2", {48, 48, 3, 3}, dtype::QuantizedS4(1.234f)),
+         b2 = mkcvar("b2", {1, 48, 1, 1}, dtype::QuantizedS32(20.f * 1.234f));
+    auto y2 = opr::ConvBias::make(
+            y1, w2, b2, param, {},
+            OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
+    auto w3 = mkcvar("w3", {16, 48, 3, 3}, dtype::QuantizedS4(1.234f)),
+         b3 = mkcvar("b3", {1, 16, 1, 1}, dtype::QuantizedS32(20.f * 1.234f));
+    auto y3 = opr::ConvBias::make(
+            y2, w3, b3, param, {},
+            OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
+    using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode;
+    auto y4 = opr::ElemwiseMultiType::make(
+            {y, y3}, {ElemMultiMode::QFUSE_ADD_RELU},
+            OperatorNodeConfig{dtype::Quantized4Asymm{20.f, 7}});
+    y4 = opr::TypeCvt::make(y4, dtype::Float32());
+    SymbolVar y4_pad;
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::PaddingChannelPass>()
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{y4}})
+                          .endpoint_vars(),
+                  y4_pad);
+    ASSERT_EQ(y4_pad.node()->shape()[1], y4.node()->shape()[1]);
+    SmallVector<cg::OperatorNodeBase*> oprs;
+    auto cb1 = [&oprs](cg::OperatorNodeBase* opr) {
+        if (opr->same_type<opr::ConvBias>()) {
+            oprs.push_back(opr);
+        }
+    };
+    cg::DepOprIter{cb1}.add(y4_pad.node()->owner_opr());
+    ASSERT_EQ(oprs.size(), 4);
+    ASSERT_EQ(oprs[0]->output(0)->shape()[1], 16);
+    ASSERT_EQ(oprs[1]->output(0)->shape()[1], 64);
+    ASSERT_EQ(oprs[2]->output(0)->shape()[1], 64);
+    ASSERT_EQ(oprs[3]->output(0)->shape()[1], 16);
+    size_t nr_concat = find_opr_num<opr::Concat>(y4_pad);
+    ASSERT_EQ(nr_concat, 1);
+    cg::OperatorNodeBase* concat = nullptr;
+    auto cb2 = [&concat](cg::OperatorNodeBase* opr) {
+        if (opr->same_type<opr::Concat>()) {
+            concat = opr;
+        }
+    };
+    cg::DepOprIter{cb2}.add(y4_pad.node()->owner_opr());
+    ASSERT_EQ(oprs[0]->input(0)->owner_opr(), concat);
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y4, t1)});
+    func1->execute();
+    auto func2 = graph->compile({make_callback_copy(y4_pad, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
+
 #endif
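Note on the new `NCHWx_TO_NHWC` reformat lambdas: they express the relayout purely with `Dimshuffle` + `Reshape`, i.e. as an index permutation. The following is a minimal standalone sketch (plain C++, no MegEngine; tensor names, shapes, and loops are illustrative only) checking that the `{0, 2, 3, 1, 4}` shuffle on an `(N, C/4, H, W, 4)` NCHW4 tensor followed by a reshape to `(N, H, W, C)` is exactly the plain NCHW -> NHWC transpose.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const size_t N = 2, C = 8, H = 3, W = 5;  // C must be a multiple of 4 for NCHW4
    const size_t CB = C / 4;                  // number of 4-channel blocks

    // logical NCHW tensor; each element stores its own NCHW linear offset
    std::vector<int> nchw(N * C * H * W);
    for (size_t i = 0; i < nchw.size(); ++i)
        nchw[i] = static_cast<int>(i);

    // pack into NCHW4 layout: shape (N, C/4, H, W, 4)
    std::vector<int> nchw4(nchw.size());
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w)
                    nchw4[(((n * CB + c / 4) * H + h) * W + w) * 4 + c % 4] =
                            nchw[((n * C + c) * H + h) * W + w];

    // the pass's rewrite: Dimshuffle {0, 2, 3, 1, 4} turns (N, C/4, H, W, 4)
    // into (N, H, W, C/4, 4); the Reshape to (N, H, W, C) is then a pure
    // reinterpretation because the last two axes are already contiguous.
    std::vector<int> nhwc(nchw.size());
    for (size_t n = 0; n < N; ++n)
        for (size_t h = 0; h < H; ++h)
            for (size_t w = 0; w < W; ++w)
                for (size_t cb = 0; cb < CB; ++cb)
                    for (size_t c4 = 0; c4 < 4; ++c4)
                        nhwc[((n * H + h) * W + w) * C + cb * 4 + c4] =
                                nchw4[(((n * CB + cb) * H + h) * W + w) * 4 + c4];

    // reference: direct NCHW -> NHWC transpose of the logical tensor
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w)
                    assert(nhwc[((n * H + h) * W + w) * C + c] ==
                           nchw[((n * C + c) * H + h) * W + w]);
    return 0;
}
```

The NCHW32/NCHW64 variants and the reverse NHWC -> NCHWx lambdas follow the same pattern with the block size and the Reshape/Dimshuffle order swapped.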
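Note on the `PaddingChannelPass` hunks: the alignment target now depends on the channel count, with small counts (up to 32) rounded up to a multiple of 8 and larger counts rounded up to a multiple of 64. A small sketch of that rounding rule follows; `padded_channels` is a hypothetical helper written for illustration, not a function of the pass, and the example values mirror the 48 -> 64 ConvBias shapes asserted in the new test.

```cpp
#include <cassert>
#include <cstddef>

// round a channel count up to the alignment the updated pass targets:
// multiples of 8 for counts <= 32, multiples of 64 otherwise
static size_t padded_channels(size_t ch) {
    const size_t align = ch <= 32 ? 8 : 64;
    return ch % align ? ch + (align - ch % align) : ch;
}

int main() {
    assert(padded_channels(3) == 8);    // small counts no longer blow up to 64
    assert(padded_channels(16) == 16);  // already aligned, left untouched
    assert(padded_channels(48) == 64);  // > 32, rounded up to a multiple of 64
    assert(padded_channels(100) == 128);
    return 0;
}
```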