diff --git a/src/gopt/impl/folding_conv_dimshuffle.cpp b/src/gopt/impl/folding_conv_dimshuffle.cpp index bb8a32ca..b8e3cb57 100644 --- a/src/gopt/impl/folding_conv_dimshuffle.cpp +++ b/src/gopt/impl/folding_conv_dimshuffle.cpp @@ -100,6 +100,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { auto conv_bias = try_cast_as_op(shuffle->input(0)->owner_opr()); if (conv_bias == nullptr) return false; + bool is_group = + conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP; + if (is_group) + return false; inp_dtype = conv_bias->input(0)->dtype(); bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && @@ -180,6 +184,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { auto conv_bias = try_cast_as_op(reshape2->input(0)->owner_opr()); if (conv_bias == nullptr) return false; + bool is_group = + conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP; + if (is_group) + return false; auto inp_dtype = conv_bias->input(0)->dtype(); bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && @@ -267,6 +275,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { auto conv_bias = try_cast_as_op(shuffle->input(0)->owner_opr()); if (conv_bias == nullptr) return false; + bool is_group = + conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP; + if (is_group) + return false; auto inp_dtype = conv_bias->input(0)->dtype(); bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && @@ -345,6 +357,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { auto conv_bias = try_cast_as_op(reshape2->input(0)->owner_opr()); if (conv_bias == nullptr) return false; + bool is_group = + conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP; + if (is_group) + return false; auto inp_dtype = conv_bias->input(0)->dtype(); bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index ec833cfd..60dc5547 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -198,6 +198,46 @@ VarNode* EnableTensorCorePass::on_graph_endpoint_var( return new_var; } +VarNode* EnableTensorCorePass::trans_to_nchw32(VarNode* new_inp) { + const TensorShape& shape = new_inp->shape(); + VarNode* node = new_inp; + //! nchw4 + if (shape.ndim == 5 && shape[4] == 4) { + node = RelayoutPlaceholder::make( + new_inp, + ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32}) + .node(); + } else if (shape.ndim == 4) { + node = RelayoutPlaceholder::make( + new_inp, + ReformatKey{TensorFormats::NCHW, TensorFormats::NCHWc32}) + .node(); + } else { + mgb_assert(shape.ndim == 5 && shape[4] == 32); + } + return node; +} + +VarNode* EnableTensorCorePass::trans_from_nchw32(VarNode* new_inp, VarNode* orig_inp) { + const TensorShape& shape = orig_inp->shape(); + VarNode* node = new_inp; + //! nchw4 + if (shape.ndim == 5 && shape[4] == 4) { + node = RelayoutPlaceholder::make( + new_inp, + ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHWc4}) + .node(); + } else if (shape.ndim == 4) { + node = RelayoutPlaceholder::make( + new_inp, + ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHW}) + .node(); + } else { + mgb_assert(shape.ndim == 5 && shape[4] == 32); + } + return node; +} + std::unique_ptr EnableTensorCorePass:: make_tensorcore_converter() { MIDOUT_B("EnableTensorCorePass::make") @@ -231,6 +271,7 @@ std::unique_ptr EnableTensorCorePass:: "EnableTensorCorePass assumes that filter tensor of " "conv_bias operator can not be changed by other operators"); VarNode* orig_filter = opr->input(1); + auto is_nchw = [](TensorShape shape) -> bool { return shape.ndim == 4; }; auto is_nchw4 = [](TensorShape shape) -> bool { return shape.ndim == 5 && shape[4] == 4; }; @@ -259,10 +300,25 @@ std::unique_ptr EnableTensorCorePass:: // nchw32 layout need that input width and height are larger than 3 size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3]; if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && iw >= 3) { - auto symvar = RelayoutPlaceholder::make( - new_inp[0], - ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32}); - src = symvar.node(); + src = trans_to_nchw32(new_inp[0]); + can_replace_nchw32 = true; + } else { + src = new_inp[0]; + } + } else if (is_nchw(new_inp[0]->shape())) { + size_t group = 1, ocpg, icpg; + if (conv_bias.param().sparse == Sparse::DENSE) { + ocpg = orig_filter->shape()[0]; + icpg = orig_filter->shape()[1]; + } else { + mgb_assert(conv_bias.param().sparse == Sparse::GROUP); + group = orig_filter->shape()[0]; + icpg = orig_filter->shape()[2]; + ocpg = orig_filter->shape()[1]; + } + size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3]; + if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && iw >= 3) { + src = trans_to_nchw32(new_inp[0]); can_replace_nchw32 = true; } else { src = new_inp[0]; @@ -287,18 +343,12 @@ std::unique_ptr EnableTensorCorePass:: can_replace_nchw32 = true; src = new_inp[0]; } else { - auto symvar = RelayoutPlaceholder::make( - new_inp[0], - ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHWc4}); - src = symvar.node(); + src = trans_from_nchw32(new_inp[0], opr->input(0)); } } // process filter tensor if (can_replace_nchw32) { - auto symvar = RelayoutPlaceholder::make( - new_inp[1], - ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32}); - weight = symvar.node(); + weight = trans_to_nchw32(new_inp[1]); } else { weight = new_inp[1]; } @@ -317,31 +367,20 @@ std::unique_ptr EnableTensorCorePass:: return new_opr; } } - auto process_inp = [&](VarNode* inp) -> VarNode* { + auto process_inp = [&](VarNode* inp, VarNode* orig) -> VarNode* { if (can_replace_nchw32) { - if (is_nchw4(inp->shape())) { - auto symvar = RelayoutPlaceholder::make( - inp, - ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32}); - return symvar.node(); - } else { - mgb_assert(is_nchw32(inp->shape())); - return inp; - } + return trans_to_nchw32(inp); } else { - if (is_nchw4(inp->shape())) { + if (is_nchw4(inp->shape()) || is_nchw(inp->shape())) { return inp; } else { mgb_assert(is_nchw32(inp->shape())); - auto symvar = RelayoutPlaceholder::make( - inp, - ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHWc4}); - return symvar.node(); + return trans_from_nchw32(inp, orig); } } }; // process bias tensor - bias = process_inp(new_inp[2]); + bias = process_inp(new_inp[2], opr->input(2)); if (new_inp.size() == 3) { if (can_replace_nchw32) { auto param = conv_bias.param(); @@ -358,7 +397,7 @@ std::unique_ptr EnableTensorCorePass:: } } // process z_inp tensor - z_inp = process_inp(new_inp[3]); + z_inp = process_inp(new_inp[3], opr->input(3)); if (can_replace_nchw32) { auto param = conv_bias.param(); param.format = Format::NCHW32; @@ -383,26 +422,25 @@ std::unique_ptr EnableTensorCorePass:: nr_shape_changed++; } } + + auto is_scalar = [](VarNode* inp) -> bool { + return inp->shape().ndim == 1 && inp->shape()[0] == 1; + }; + if (nr_shape_changed) { auto inps = new_inp; if (nr_shape_changed >= nr_inps / 2) { // NCHW32 > NCHW4 -> use NCHW32 for (size_t i = 0; i < nr_inps; ++i) { if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - auto symvar = RelayoutPlaceholder::make( - new_inp[i], - ReformatKey{ - TensorFormats::NCHWc4, TensorFormats::NCHWc32}); - inps[i] = symvar.node(); + if (!is_scalar(new_inp[i])) { + inps[i] = trans_to_nchw32(new_inp[i]); + } } } } else { // NCHW32 < NCHW4 -> use NCHW4 for (size_t i = 0; i < nr_inps; ++i) { if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - auto symvar = RelayoutPlaceholder::make( - new_inp[i], - ReformatKey{ - TensorFormats::NCHWc32, TensorFormats::NCHWc4}); - inps[i] = symvar.node(); + inps[i] = trans_from_nchw32(new_inp[i], opr->input(i)); } } } @@ -410,7 +448,27 @@ std::unique_ptr EnableTensorCorePass:: } return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); }; - // for oprs only supports NCHW4 layout + + auto replace_inps_to_nchw = [](OperatorNodeBase* opr, const VarNodeArray new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + VarNodeArray inps = new_inp; + for (size_t i = 0; i < opr->input().size(); ++i) { + if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + mgb_assert( + opr->input(i)->shape().ndim == 5 && + opr->input(i)->shape()[4] == 4); + mgb_assert( + new_inp[i]->shape().ndim == 5 && new_inp[i]->shape()[4] == 32); + auto symvar = RelayoutPlaceholder::make( + new_inp[i], + ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHW}); + inps[i] = symvar.node(); + } + } + auto new_opr = serialization::copy_opr_shallow(*opr, inps, opr->config()); + return new_opr; + }; + auto replace_inps_to_nchw4 = [](OperatorNodeBase* opr, const VarNodeArray new_inp) { mgb_assert(opr->input().size() == new_inp.size()); VarNodeArray inps = new_inp; @@ -446,49 +504,62 @@ std::unique_ptr EnableTensorCorePass:: "pass"); return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); }; - auto replace_warp_affine_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - using Param = opr::WarpAffineForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& warp = opr->cast_final_safe(); - if (warp.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_warp_perspective_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - using Param = opr::WarpPerspectiveForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& warp = opr->cast_final_safe(); - if (warp.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - using Param = opr::ResizeForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& resize = opr->cast_final_safe(); - if (resize.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_pooling_opr = [replace_non_nchw4_opr]( + auto replace_warp_affine_opr = + [replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr]( + OperatorNodeBase* opr, const VarNodeArray new_inp) { + using Param = opr::WarpAffineForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& warp = opr->cast_final_safe(); + if (warp.param().format == Format::NCHW) { + return replace_inps_to_nchw(opr, new_inp); + } + if (warp.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_warp_perspective_opr = + [replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr]( + OperatorNodeBase* opr, const VarNodeArray new_inp) { + using Param = opr::WarpPerspectiveForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& warp = opr->cast_final_safe(); + if (warp.param().format == Format::NCHW) { + return replace_inps_to_nchw(opr, new_inp); + } + if (warp.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_resize_opr = + [replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr]( + OperatorNodeBase* opr, const VarNodeArray new_inp) { + using Param = opr::ResizeForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + + auto& resize = opr->cast_final_safe(); + if (resize.param().format == Format::NCHW) { + return replace_inps_to_nchw(opr, new_inp); + } + if (resize.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_pooling_opr = [replace_inps_to_nchw, replace_non_nchw4_opr]( OperatorNodeBase* opr, const VarNodeArray new_inp) { using Param = opr::PoolingForward::Param; using Format = Param::Format; mgb_assert(opr->input().size() == new_inp.size()); auto& pooling = opr->cast_final_safe(); + if (pooling.param().format == Format::NCHW) { + return replace_inps_to_nchw(opr, new_inp); + } if (pooling.param().format != Format::NCHW4) { return replace_non_nchw4_opr(opr, new_inp); } @@ -847,18 +918,23 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { ReformatKey weight_to_nchw4_mode_group{ TensorFormats::GKCRS, TensorFormats::GKCRSc4}; - struct ConvMode { + struct TransResult { ReformatKey weight; ReformatKey src; + bool can_trans; }; auto trans_nchw4 = [weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group, src_to_nchw4_mode]( const megdnn::param::Convolution::Sparse conv_mode, - const VarNode* filter) -> ConvMode { + const VarNode* filter) -> TransResult { if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { mgb_assert(filter->shape().ndim == 4, "The origin filter is not NCHW mode"); size_t IC = filter->shape()[1]; + size_t OC = filter->shape()[0]; + if (OC % 4 != 0) { + return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, false}; + } if (IC < 4) { ReformatKey weight{ TensorFormats::KCRS, TensorFormats::KCRSc4, @@ -866,9 +942,11 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { ReformatKey src{ TensorFormats::NCHW, TensorFormats::NCHWc4, ReformatKey::Attribute::IC_SMALL}; - return {weight, src}; + return {weight, src, true}; + } else if (IC % 4 == 0) { + return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, true}; } else { - return {weight_to_nchw4_mode_dense, src_to_nchw4_mode}; + return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, false}; } } else { mgb_throw_if( @@ -876,14 +954,31 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { MegBrainError, "mode error"); mgb_assert(filter->shape().ndim == 5, "The origin filter if not NCHW mode"); size_t IC = filter->shape()[2]; - mgb_assert( - IC % 4 == 0, - "The input channel should be divisible by 4 for group " - "conv"); - return {weight_to_nchw4_mode_group, src_to_nchw4_mode}; + size_t OC = filter->shape()[1]; + if (IC % 4 == 0 && OC % 4 == 0) { + return {weight_to_nchw4_mode_group, src_to_nchw4_mode, true}; + } else { + return {weight_to_nchw4_mode_group, src_to_nchw4_mode, false}; + } } }; - auto replace_conv_opr = [trans_nchw4, conv_format]( + + auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + VarNodeArray temp_inp = new_inp; + for (size_t i = 0; i < opr->input().size(); i++) { + if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + mgb_assert(opr->input(i)->shape().ndim == 4); + mgb_assert(new_inp[i]->shape().ndim == 5); + auto new_var = RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode); + temp_inp[i] = new_var.node(); + } + } + return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); + }; + + auto replace_conv_opr = [relayout_inp_to_nchw, trans_nchw4, conv_format]( OperatorNodeBase* opr, const VarNodeArray& new_inp) { if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { @@ -896,6 +991,11 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { } auto conv_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]); VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; + + if (!conv_mode.can_trans) { + return relayout_inp_to_nchw(opr, new_inp); + } + // src: NCHW --> NCWH4 if (new_inp[0]->shape().ndim != 5) { mgb_assert(new_inp[0]->shape().ndim == 4); @@ -919,7 +1019,7 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { return new_opr; }; - auto replace_deconv_opr = [trans_nchw4, conv_format]( + auto replace_deconv_opr = [relayout_inp_to_nchw, trans_nchw4, conv_format]( OperatorNodeBase* opr, const VarNodeArray& new_inp) { if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) { @@ -933,6 +1033,11 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { } VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0]; auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter); + + if (!deconv_mode.can_trans) { + return relayout_inp_to_nchw(opr, new_inp); + } + // src: NCHW --> NCWH4 if (deconv_src->shape().ndim != 5) { mgb_assert(deconv_src->shape().ndim == 4); @@ -1027,7 +1132,7 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { dst.shape().ndim == 5, "The conv_bias dst dim is not trans to nchw4"); return new_opr; }; - auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format, + auto replace_conv_bias_opr = [relayout_inp_to_nchw, trans_nchw4, conv_bias_format, conv_bias_format_nchw4_nchw, src_to_nchw4_mode]( OperatorNodeBase* opr, const VarNodeArray& new_inp) { @@ -1043,6 +1148,11 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { // what should be converted: src, weight VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1]; auto conv_mode = trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]); + + if (!conv_mode.can_trans) { + return relayout_inp_to_nchw(opr, new_inp); + } + // src: NCHW --> NCHW4 if (new_inp[0]->shape().ndim != 5) { mgb_assert(new_inp[0]->shape().ndim == 4); @@ -1134,20 +1244,7 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); } }; - auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - VarNodeArray temp_inp = new_inp; - for (size_t i = 0; i < opr->input().size(); i++) { - if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - mgb_assert(opr->input(i)->shape().ndim == 4); - mgb_assert(new_inp[i]->shape().ndim == 5); - auto new_var = RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode); - temp_inp[i] = new_var.node(); - } - } - return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); - }; + auto replace_pooling_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) { if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); @@ -1257,6 +1354,7 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::AdaptivePoolingForward::typeinfo()] = relayout_inp_to_nchw; return ret; MIDOUT_E } diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index 6c97db87..f6933ccd 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -235,6 +235,9 @@ public: class EnableTensorCorePass final : public TensorReformatPass { VarNode* on_graph_endpoint_var(VarNode* new_var, VarNode* orig_var) const override; + static VarNode* trans_to_nchw32(VarNode* new_inp); + static VarNode* trans_from_nchw32(VarNode* new_inp, VarNode* orig_inp); + public: const char* name() const override { return mgb_cstr_log("enable_tensorcore"); } //! make enable tensorcore opt pass diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 06eafe01..e6c76a77 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -2356,8 +2356,18 @@ TEST(TestEnableTensorCore, SmallInputShape) { dtype); }; - auto x = mkvar("x", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)), - w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), + auto x0 = mkvar("x0", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)), + w0 = mkcvar("w0", {2, 32, 8, 3, 3, 4}, dtype::QuantizedS8(2.5f)), + b0 = mkcvar("b0", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)), + z0 = mkcvar("z0", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)); + opr::ConvBias::Param param0; + param0.format = opr::ConvBias::Param::Format::NCHW4; + param0.sparse = opr::ConvBias::Param::Sparse::GROUP; + param0.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; + param0.stride_h = param0.stride_w = 1; + param0.pad_h = param0.pad_w = 1; + + auto w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)), z = mkcvar("b1", {32, 16, 2, 4, 4}, dtype::QuantizedS8(2.5f)); opr::ConvBias::Param param; @@ -2367,7 +2377,9 @@ TEST(TestEnableTensorCore, SmallInputShape) { param.pad_h = param.pad_w = 1; auto y = opr::ConvBias::make( - x, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + x0, w0, b0, z0, param0, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + y = opr::ConvBias::make( + y, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); y = opr::ConvBias::make( y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); y = opr::TypeCvt::make(y, dtype::Float32()); @@ -2431,10 +2443,47 @@ TEST(TestEnableTensorCore, Nchw4Nchw) { } }; + auto mk_flt_shape = [](opr::ConvBias::Param::Format format, size_t OC, size_t IC, + size_t FH, size_t FW, size_t g = 1) -> TensorShape { + mgb_assert(OC % (g * 4) == 0 && IC % (g * 4) == 0); + if (g == 1) { + if (format == opr::ConvBias::Param::Format::NCHW4) { + return {OC, IC / 4, FH, FW, 4}; + } else { + mgb_assert(format == opr::ConvBias::Param::Format::NCHW); + return {OC, IC, FH, FW}; + } + } else { + if (format == opr::ConvBias::Param::Format::NCHW4) { + return {g, OC / g, IC / 4 / g, FH, FW, 4}; + } else { + mgb_assert(format == opr::ConvBias::Param::Format::NCHW); + return {g, OC / g, IC / g, FH, FW}; + } + } + }; + for (auto format : {opr::ConvBias::Param::Format::NCHW, opr::ConvBias::Param::Format::NCHW4}) { - auto x = mkvar("x", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f)), - w = mkcvar("w1", mkshape(format, 64, 64, 3, 3), dtype::QuantizedS8(2.5f)), + auto x0 = mkvar( + "x0", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f)), + w0 = + mkcvar("w0", mk_flt_shape(format, 64, 64, 3, 3, 2), + dtype::QuantizedS8(2.5f)), + b0 = mkcvar( + "b0", mkshape(format, 1, 64, 1, 1), dtype::QuantizedS32(6.25f)), + z0 = mkcvar( + "z0", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f)); + opr::ConvBias::Param param0; + param0.format = format; + param0.sparse = opr::ConvBias::Param::Sparse::GROUP; + param0.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; + param0.stride_h = param0.stride_w = 1; + param0.pad_h = param0.pad_w = 1; + + auto w = + mkcvar("w1", mk_flt_shape(format, 64, 64, 3, 3), + dtype::QuantizedS8(2.5f)), b = mkcvar("b", mkshape(format, 1, 64, 1, 1), dtype::QuantizedS32(6.25f)), z = mkcvar("b1", mkshape(format, 32, 64, 8, 8), dtype::QuantizedS8(2.5f)); opr::ConvBias::Param param; @@ -2444,7 +2493,10 @@ TEST(TestEnableTensorCore, Nchw4Nchw) { param.pad_h = param.pad_w = 1; auto y = opr::ConvBias::make( - x, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + x0, w0, b0, z0, param0, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + y = opr::ConvBias::make( + y, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); y = opr::ConvBias::make( y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); y = opr::TypeCvt::make(y, dtype::Float32()); @@ -2470,7 +2522,7 @@ TEST(TestEnableTensorCore, Nchw4Nchw) { ASSERT_EQ(2u, nr_dimshuffle); #endif } else { - ASSERT_EQ(2u, nr_dimshuffle); + ASSERT_EQ(3u, nr_dimshuffle); } std::string json_name; if (format == opr::ConvBias::Param::Format::NCHW4) {