From ae6ff2c5a66023287f158690e654be82bf32c548 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 2 Apr 2021 17:54:40 +0800 Subject: [PATCH] feat(mgb/gopt): add opt pass for nchw64 layout transform GitOrigin-RevId: adede7cef6218b1a4871278bf45973ca599ad594 --- dnn/src/fallback/conv_bias/opr_impl.cpp | 3 +- dnn/src/fallback/convolution/opr_impl.cpp | 16 +- dnn/src/fallback/relayout/opr_impl.cpp | 7 + sdk/load-and-run/src/mgblar.cpp | 6 + src/core/impl/graph/operator_node.cpp | 1 - src/core/impl/tensor.cpp | 4 +- src/core/include/megbrain/graph/cg.h | 3 + src/gopt/impl/framework.cpp | 11 + src/gopt/impl/tensor_reformat.cpp | 1105 +++++++++++++++++++++++++--- src/gopt/include/megbrain/gopt/inference.h | 21 + src/gopt/test/inference.cpp | 329 ++++++++- src/opr/test/dnn/convolution.cpp | 4 +- 12 files changed, 1392 insertions(+), 118 deletions(-) diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index b60a2ac0..823157ff 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -342,7 +342,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW44 || param().format == Param::Format::NCHW44_DOT || - param().format == Param::Format::NCHW || + param().format == Param::Format::NCHW || + param().format == Param::Format::NCHW32 || param().format == Param::Format::NCHW64) { spatial_pos = 2; } else if (param().format == Param::Format::NHWC) { diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 8e1d1482..eb05cb14 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -481,7 +481,9 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, _megdnn_tensor_out grad, _megdnn_workspace workspace) { if (param().format == param::Convolution::Format::NHWCD4 || - param().format == param::Convolution::Format::NCHW4) { + param().format == param::Convolution::Format::NCHW4 || + (param().format == param::Convolution::Format::NCHW && + grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace); } @@ -493,7 +495,9 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) { if (param().format == param::Convolution::Format::NHWCD4 || - param().format == param::Convolution::Format::NCHW4) { + param().format == param::Convolution::Format::NCHW4 || + (param().format == param::Convolution::Format::NCHW && + grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes( filter, diff, grad); } @@ -506,7 +510,9 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) { if (param().format == param::Convolution::Format::NHWCD4 || - param().format == param::Convolution::Format::NCHW4) { + param().format == param::Convolution::Format::NCHW4 || + (param().format == param::Convolution::Format::NCHW && + grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_all_algorithms( filter, diff, grad); } @@ -523,7 +529,9 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { if (param().format == param::Convolution::Format::NHWCD4 || - 
param().format == param::Convolution::Format::NCHW4) { + param().format == param::Convolution::Format::NCHW4 || + (param().format == param::Convolution::Format::NCHW && + grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic( filter, diff, grad, workspace_limit_in_bytes, positive_attr, negative_attr); diff --git a/dnn/src/fallback/relayout/opr_impl.cpp b/dnn/src/fallback/relayout/opr_impl.cpp index 56d20dae..011a4e2b 100644 --- a/dnn/src/fallback/relayout/opr_impl.cpp +++ b/dnn/src/fallback/relayout/opr_impl.cpp @@ -229,6 +229,13 @@ void RelayoutForwardImpl::exec( return; } + // FIXME: optimize for lowbit cases + if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 || + src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { + NaiveRelayoutForwardImpl::do_exec(src, dst); + return; + } + relayout::TransposeParam trans_param; bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index 908bb756..46b6391a 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -230,6 +230,11 @@ R"__usage__( --enable-fuse-preprocess Fusion astype\pad_channel\dimshuffle and etc opr from h2d op )__usage__" +R"__usage__( + --enable-nchw64 + Execute operators with kernels implemented in MegDNN with NCHW64 tensor format. Can only be used + on Nvidia GPUs, which natively support fast int4 tensorcore inference. +)__usage__" ; struct DataParser { @@ -1150,6 +1155,7 @@ Args Args::from_argv(int argc, char **argv) { cb(nchw88); cb(nchw32); cb(nhwcd4); + cb(nchw64); #undef cb if (!strcmp(argv[i], "--enable-nchw44-dot")) { mgb_log_warn("enable-nchw44-dot optimization"); diff --git a/src/core/impl/graph/operator_node.cpp b/src/core/impl/graph/operator_node.cpp index 1197fd73..202e5876 100644 --- a/src/core/impl/graph/operator_node.cpp +++ b/src/core/impl/graph/operator_node.cpp @@ -351,7 +351,6 @@ void OperatorNodeBase::init_output_format() { } for (auto i : output()) { if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { - mgb_assert(format.is_default()); i->format(TensorFormat(i->dtype())); } else { if (!format.is_default()) diff --git a/src/core/impl/tensor.cpp b/src/core/impl/tensor.cpp index a4f03517..ab5eafa4 100644 --- a/src/core/impl/tensor.cpp +++ b/src/core/impl/tensor.cpp @@ -643,8 +643,8 @@ TensorND::copy_from_fixlayout( (m_layout.format.is_lowbit_aligned() && m_layout.is_contiguous()), src_contig = src.layout().is_physical_contiguous() || - (m_layout.format.is_lowbit_aligned() && - m_layout.is_contiguous()); + (src.layout().format.is_lowbit_aligned() && + src.layout().is_contiguous()); if (self_contig && src_contig) { if ((m_layout.format.is_default() && src.layout().format.is_default()) || diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 72821a45..dd731284 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -112,6 +112,8 @@ struct GraphCommonOptimizeOptions { ///< tensorcore CHWN4, ///< compute using CHWN4 tensor format, transformed mainly ///< used for cuda + NCHW64, ///< compute using NCHW64 tensor format, used for fast int4 + ///< support on Nvidia GPU }; LayoutTransform layout_transform = LayoutTransform::DEFAULT; @@ -154,6 +156,7 @@ struct GraphCommonOptimizeOptions { SET(nchw44_dot, NCHW44_DOT); SET(nchw32, NCHW32); SET(chwn4, CHWN4); + SET(nchw64, NCHW64); 
#undef SET }; diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index 41654169..2eebbe63 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -774,6 +774,17 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( add_pass(); add_pass(); }); + cb(nchw64, { + add_pass(); + add_pass(); + add_pass(); + add_pass(EnableNCHW64Pass::make_nchw64_converter()); + add_pass(); + add_pass(); + add_pass(FuseNCHW4Int8Preprocess::make()); + add_pass(); + add_pass(); + }); cb(fuse_conv_bias_nonlinearity, { add_pass(); }); cb(fuse_conv_bias_with_z, { diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index 8a198ab9..a0ece8fc 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -68,74 +68,79 @@ using namespace gopt; * oprs should not get involved in any actual computing. */ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, - cg::SingleCNOperatorNodeBase) // { + cg::SingleCNOperatorNodeBase) // { public: -//! relayout type of this opr -enum class LayoutType { - NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout - NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout - NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout - CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout - NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout - NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose - ///< channel size less than 4 - NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout - NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout - NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout - - WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 - //!< layout - WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to - //!< nchw4 layout - WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout - //!< to nchw4 layout whose - //! channel size less than 4 - - WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 - //!< layout - WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to - //!< nchw88 layout - WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout - //!< to nchw88 layout - //!< the weight layout of input is nchw output is nchw88, special for - //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8} - WEIGHT_HYBIRD_NCHW_NCHW88, - - WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44 - //!< layout - WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to - //!< nchw44 layout - WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout - //!< to nchw44 layout - //!< the weight layout of input is nchw output is nchw44, special for - //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4} - WEIGHT_HYBIRD_NCHW_NCHW44, - WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< weight from NCHW44 layout to - //!< NCHW44_DOT layout dense - WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to - //!< NCHW44_DOT layout group -}; + //! 
relayout type of this opr + enum class LayoutType { + NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout + NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout + NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout + CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout + NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout + NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose + ///< channel size less than 4 + NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout + NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout + NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout + + WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4 + //!< layout + WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to + //!< nchw4 layout + WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout + //!< to nchw4 layout whose + //! channel size less than 4 + + WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 + //!< layout + WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to + //!< nchw88 layout + WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout + //!< to nchw88 layout + //!< the weight layout of input is nchw output is nchw88, special for + //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8} + WEIGHT_HYBIRD_NCHW_NCHW88, + + WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44 + //!< layout + WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to + //!< nchw44 layout + WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout + //!< to nchw44 layout + //!< the weight layout of input is nchw output is nchw44, special for + //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4} + WEIGHT_HYBIRD_NCHW_NCHW44, + WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< weight from NCHW44 layout to + //!< NCHW44_DOT layout dense + WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to + //!< NCHW44_DOT layout group + NCHW32_TO_NCHW, //! 
VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = + opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0); + auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); + auto y1 = opr::Reshape::make(y0, tshp0); + return y1.node(); + }; + reformat[LayoutType::NCHW32_TO_NCHW64] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1) / 2, cv(2), sub(2), sub(3), sub(4)}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) / 2, sub(2), sub(3), sub(4) * 2}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::NCHW64_TO_NCHW] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = + opr::Concat::make({sub(0), sub(1) * 64, sub(2), sub(3)}, 0); + auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); + auto y1 = opr::Reshape::make(y0, tshp0); + return y1.node(); + }; + reformat[LayoutType::NCHW64_TO_NCHW4] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1), sub(2), sub(3), sub(4) / 4, cv(4)}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) * 16, sub(2), sub(3), cv(4)}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::NCHW64_TO_NCHW32] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1), sub(2), sub(3), sub(4) / 32, cv(32)}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) * 2, sub(2), sub(3), cv(32)}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::NCHW_TO_NCHW64] = [](VarNode* inp) -> VarNode* { + megdnn::param::RelayoutFormat param; + param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64; + auto reformat = opr::RelayoutFormat::make(inp, param); + return reformat.node(); + }; + reformat[LayoutType::NCHW_TO_NCHW32] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); + 
return y1.node(); + }; + reformat[LayoutType::NCHW4_TO_NCHW64] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1) / 16, cv(16), sub(2), sub(3), sub(4)}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) / 16, sub(2), sub(3), sub(4) * 16}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; auto rewriter = opt.graph().make_rewriter(); auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) { @@ -3152,7 +3336,8 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() { void ShuffleShuffleRemovePass::Impl::do_replace() { auto rewriter = m_opt_state.graph().make_rewriter(); auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()}; - ThinHashMap var2endpoint; + ThinHashSet writers; + ThinHashSet root; ThinHashSet trt_opr_inps; SmallVector topo_order; @@ -3171,28 +3356,27 @@ void ShuffleShuffleRemovePass::Impl::do_replace() { for (auto&& opr : reverse_adaptor(topo_order)) { if (opr->same_type() || opr->same_type()) { - auto find = var2endpoint.find(opr->output(0)); - if (find != var2endpoint.end()) { - if (uniq_reader_check(opr->output(0))) { - var2endpoint[opr->input(0)] = find->second; - } else { - var2endpoint[opr->input(0)] = opr->output(0); + writers.insert(opr->input(0)->owner_opr()); + if (writers.count(opr) > 0) { + if (!uniq_reader_check(opr->output(0))) { + root.insert(opr); } } else { - var2endpoint[opr->input(0)] = opr->output(0); + root.insert(opr); } } } auto on_opr = [this, &rewriter, &uniq_reader_check, &trt_opr_inps, - &var2endpoint](OperatorNodeBase* opr) { + &root](OperatorNodeBase* opr) { MGB_MARK_USED_VAR(trt_opr_inps); bool cond_opr = opr->same_type() || opr->same_type(); if (cond_opr) { - bool cond_endpoint = var2endpoint[opr->input(0)] == opr->output(0); - if (!cond_endpoint) + bool cond_endpoint = root.count(opr) > 0; + if (!cond_endpoint) { return; + } auto cur = opr; auto var = opr->output(0), inp_var = opr->input(0); bool force_folding_typecvt = false; @@ -3624,13 +3808,13 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { MIDOUT_E } -/* ==================== PaddingChaannelPass ================= */ +/* ==================== PaddingChannelPass ================= */ const char* PaddingChannelPass::name() const { return mgb_cstr_log("padding output channel to multiple of 4/32"); } void PaddingChannelPass::apply(OptState& opt) const { - MIDOUT_B("PaddingChannelPassPass::apply"); + MIDOUT_B("PaddingChannelPass::apply"); // do not check shape opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ VarReplaceCheckFlag::CHECK_SHAPE); @@ -3643,14 +3827,18 @@ void PaddingChannelPass::apply(OptState& opt) const { auto rewriter = opt.graph().make_rewriter(); auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { mgb_assert(inp->shape().ndim == 4); - mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS8 || + mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || + inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || + inp->dtype().enumv() == DTypeEnum::QuantizedS8 || inp->dtype().enumv() == DTypeEnum::QuantizedS32); TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]}; std::shared_ptr host_val = std::make_shared( 
inp->comp_node(), shape, inp->dtype()); auto ptr = host_val->raw_ptr(); - std::memset(ptr, 0, shape.total_nr_elems() * inp->dtype().size()); + size_t size_bytes = + TensorLayout{shape, inp->dtype()}.span().dist_byte(); + std::memset(ptr, 0, size_bytes); auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); auto out = opr::Concat::make({inp, padding}, 1); @@ -3659,14 +3847,18 @@ void PaddingChannelPass::apply(OptState& opt) const { auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { mgb_assert(inp->shape().ndim == 4); - mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS8 || + mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || + inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || + inp->dtype().enumv() == DTypeEnum::QuantizedS8 || inp->dtype().enumv() == DTypeEnum::QuantizedS32); TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]}; std::shared_ptr host_val = std::make_shared( inp->comp_node(), shape, inp->dtype()); auto ptr = host_val->raw_ptr(); - std::memset(ptr, 0, shape.total_nr_elems() * inp->dtype().size()); + size_t size_bytes = + TensorLayout{shape, inp->dtype()}.span().dist_byte(); + std::memset(ptr, 0, size_bytes); auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); auto out = opr::Concat::make({inp, padding}, 0); @@ -3722,17 +3914,12 @@ void PaddingChannelPass::apply(OptState& opt) const { out_channels = inps[1]->shape()[0]; in_channels = inps[1]->shape()[1]; size_t pad_channels = 0; - if (in_channels <= 16) { + if (out_channels <= 16) { if (out_channels % 4) pad_channels = 4 - (out_channels % 4); } else { - if (out_channels <= 16) { - if (out_channels % 4) - pad_channels = 4 - (out_channels % 4); - } else { - if (out_channels % 32) - pad_channels = 32 - (out_channels % 32); - } + if (out_channels % 32) + pad_channels = 32 - (out_channels % 32); } if (pad_channels > 0) { inps[1] = pad_out_channels(inps[1], pad_channels); @@ -3751,6 +3938,40 @@ void PaddingChannelPass::apply(OptState& opt) const { mgb_assert(new_inp.size() == 3); mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); auto inps = new_inp; + size_t out_channels = opr->input(1)->shape()[0]; + size_t in_channels = opr->input(1)->shape()[1]; + size_t new_in_channels = new_inp[0]->shape()[1]; + // pad input channels + if (padding_oprs.count(opr->input(0)->owner_opr())) { + if (new_in_channels % 64 == 0) { + size_t pad_channels = new_in_channels - in_channels; + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } else { + size_t pad_channels_0 = 64 - (new_in_channels % 64); + size_t pad_channels_1 = 64 - (in_channels % 64); + inps[0] = pad_in_channels(new_inp[0], pad_channels_0); + inps[1] = pad_in_channels(new_inp[1], pad_channels_1); + } + } else { + size_t pad_channels = 0; + mgb_assert(new_in_channels == in_channels); + if (in_channels % 64) + pad_channels = 64 - (in_channels % 64); + if (pad_channels > 0) { + inps[0] = pad_in_channels(new_inp[0], pad_channels); + inps[1] = pad_in_channels(new_inp[1], pad_channels); + } + } + out_channels = inps[1]->shape()[0]; + in_channels = inps[1]->shape()[1]; + size_t pad_channels = 0; + if (out_channels % 64) + pad_channels = 64 - (out_channels % 64); + if (pad_channels > 0) { + inps[1] = pad_out_channels(inps[1], pad_channels); + inps[2] = pad_in_channels(inps[2], pad_channels); + padding_oprs.insert(opr); + } return serialization::copy_opr_shallow(*opr, inps, opr->config()); }; @@ -3863,20 +4084,22 @@ void PaddingChannelPass::apply(OptState& opt) 
const { bool padding_all_inps = true; bool same_padding = true; size_t channels_after_padding = 0; + size_t i = 0; for (auto&& cur_inp : opr->input()) { bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; if (padding_cur_inp) { if (!have_padding_inp) have_padding_inp = true; if (channels_after_padding == 0) { - channels_after_padding = cur_inp->shape()[1]; + channels_after_padding = new_inp[i]->shape()[1]; } else { same_padding = - channels_after_padding == cur_inp->shape()[1]; + channels_after_padding == new_inp[i]->shape()[1]; } } if (padding_all_inps && (!padding_cur_inp || !same_padding)) padding_all_inps = false; + ++i; } if (have_padding_inp && !padding_all_inps) { auto inps = new_inp; @@ -3905,14 +4128,11 @@ void PaddingChannelPass::apply(OptState& opt) const { OperatorNodeBase* opr, const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); - bool have_padding_inp = false; auto inps = new_inp; for (size_t i = 0; i < new_inp.size(); ++i) { auto cur_inp = opr->input(i); bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; if (padding_cur_inp) { - if (!have_padding_inp) - have_padding_inp = true; size_t orig_channels = cur_inp->shape()[1]; inps[i] = extract_subtensor(inps[i], orig_channels); } @@ -3964,4 +4184,697 @@ void PaddingChannelPass::apply(OptState& opt) const { MIDOUT_E } + +/* ================ EnableNCHW64Pass =============== */ +VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var, + VarNode* orig_var) const { + if (!orig_var->shape().eq_shape(new_var->shape())) { + auto iter = m_opr_format_map.find(orig_var->owner_opr()); + mgb_assert(iter != m_opr_format_map.end(), + "cannot find opr(type:%s,name:%s) information, related " + "output var node(name:%s)", + orig_var->owner_opr()->dyn_typeinfo()->name, + orig_var->owner_opr()->cname(), orig_var->cname()); + const auto& fmt = iter->second; + using LayoutType = RelayoutPlaceholder::LayoutType; + LayoutType type; + switch (fmt) { + case Format::NCHW4: + type = LayoutType::NCHW4_TO_NCHW; + break; + case Format::NCHW32: + type = LayoutType::NCHW32_TO_NCHW; + break; + case Format::NCHW64: + type = LayoutType::NCHW64_TO_NCHW; + break; + default: + mgb_throw(AssertionError, + "format(%d) is not supported, related var " + "node(name:%s)", + static_cast(fmt), orig_var->cname()); + }; + return RelayoutPlaceholder::make(new_var, type).node(); + } + return new_var; +} + +std::unique_ptr +EnableNCHW64Pass::make_nchw64_converter() { + MIDOUT_B("EnableNCHW64Pass::make") + auto ret = std::make_unique(); + ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ + VarReplaceCheckFlag::CHECK_SHAPE); + auto& replace_func = ret->m_opr_replace_func; + auto& format_map = ret->m_opr_format_map; + auto make_new_conv = [](const VarNodeArray& inps, + const opr::ConvBiasForward* orig_conv, + Format format) { + auto param = orig_conv->param(); + // change format + param.format = format; + if (inps.size() == 2) { + auto new_conv = opr::ConvBiasForward::make( + inps[0], inps[1], param, orig_conv->execution_policy(), + orig_conv->config()); + return new_conv.node(); + } else if (inps.size() == 3) { + auto new_conv = opr::ConvBiasForward::make( + inps[0], inps[1], inps[2], param, + orig_conv->execution_policy(), orig_conv->config()); + return new_conv.node(); + } else { + mgb_assert(inps.size() == 4); + auto new_conv = opr::ConvBiasForward::make( + inps[0], inps[1], inps[2], inps[3], param, + orig_conv->execution_policy(), orig_conv->config()); + return new_conv.node(); + } + }; + + 
auto try_transform_to_nchw4 = + [make_new_conv, &format_map]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) -> VarNode* { + bool check_dtype = + opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 && + opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8; + if (opr->input().size() >= 3) + check_dtype &= + opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32; + if (opr->input().size() >= 4) + check_dtype &= + opr->input(3)->dtype().enumv() == DTypeEnum::QuantizedS8; + if (!check_dtype) + return nullptr; + size_t out_channels = opr->input(1)->shape()[0]; + size_t in_channels = opr->input(1)->shape()[1]; + bool check_channels = out_channels % 4 == 0 && in_channels % 4 == 0; + mgb_assert(check_channels, + "invalid quantize conv bias opr(name:%s,oc:%zu,ic:%zu)", + opr->cname(), out_channels, in_channels); + auto inps = new_inp; + auto process = [&](size_t i) -> VarNode* { + auto iter = format_map.find(opr->input(i)->owner_opr()); + if (iter == format_map.end()) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4); + return ovar.node(); + } else { + const auto& fmt = iter->second; + if (fmt == Format::NCHW4) { + return inps[i]; + } else if (fmt == Format::NCHW32) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); + return ovar.node(); + } else { + mgb_assert(fmt == Format::NCHW64); + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW4); + return ovar.node(); + } + } + }; + for (size_t i = 0; i < inps.size(); ++i) { + inps[i] = process(i); + } + format_map.insert(std::make_pair(opr, Format::NCHW4)); + auto& conv_bias = opr->cast_final_safe(); + return make_new_conv(inps, &conv_bias, Format::NCHW4); + }; + + auto try_transform_to_nchw32 = + [make_new_conv, &format_map]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) -> VarNode* { + bool check_dtype = + opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 && + opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8; + if (opr->input().size() >= 3) + check_dtype &= + opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32; + if (opr->input().size() >= 4) + check_dtype &= + opr->input(3)->dtype().enumv() == DTypeEnum::QuantizedS8; + if (!check_dtype) + return nullptr; + size_t out_channels = opr->input(1)->shape()[0]; + size_t in_channels = opr->input(1)->shape()[1]; + bool check_channels = out_channels % 32 == 0 && in_channels % 32 == 0; + if (!check_channels) + return nullptr; + auto inps = new_inp; + auto process = [&](size_t i) -> VarNode* { + auto iter = format_map.find(opr->input(i)->owner_opr()); + if (iter == format_map.end()) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW32); + return ovar.node(); + } else { + const auto& fmt = iter->second; + if (fmt == Format::NCHW32) { + return inps[i]; + } else if (fmt == Format::NCHW4) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); + return ovar.node(); + } else { + mgb_assert(fmt == Format::NCHW64); + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW32); + return ovar.node(); + } + } + }; + for (size_t i = 0; i < inps.size(); ++i) { + inps[i] = process(i); + } + format_map.insert(std::make_pair(opr, Format::NCHW32)); + auto& conv_bias = opr->cast_final_safe(); + return make_new_conv(inps, &conv_bias, Format::NCHW32); + }; + + auto 
try_transform_to_nchw64 = + [make_new_conv, &format_map]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) -> VarNode* { + // fint4XWint4 and fuint4XWint4 + bool check_dtype = + (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || + opr->input(0)->dtype().enumv() == + DTypeEnum::Quantized4Asymm) && + opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS4; + if (opr->input().size() >= 3) + check_dtype &= + opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32; + if (opr->input().size() >= 4) + check_dtype &= opr->input(3)->dtype().enumv() == + opr->input(0)->dtype().enumv(); + if (!check_dtype) + return nullptr; + size_t out_channels = opr->input(1)->shape()[0]; + size_t in_channels = opr->input(1)->shape()[1]; + bool check_channels = out_channels % 64 == 0 && in_channels % 64 == 0; + if (!check_channels) + return nullptr; + auto inps = new_inp; + auto process = [&](size_t i) -> VarNode* { + auto iter = format_map.find(opr->input(i)->owner_opr()); + if (iter == format_map.end()) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64); + return ovar.node(); + } else { + const auto& fmt = iter->second; + if (fmt == Format::NCHW64) { + return inps[i]; + } else if (fmt == Format::NCHW4) { + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64); + return ovar.node(); + } else { + mgb_assert(fmt == Format::NCHW32); + auto ovar = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW64); + return ovar.node(); + } + } + }; + for (size_t i = 0; i < inps.size(); ++i) { + inps[i] = process(i); + } + format_map.insert(std::make_pair(opr, Format::NCHW64)); + auto& conv_bias = opr->cast_final_safe(); + return make_new_conv(inps, &conv_bias, Format::NCHW64); + }; + + // replace rule for conv bias opr + auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4, + try_transform_to_nchw32, + try_transform_to_nchw64]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + using Param = megdnn::param::ConvBias; + using Sparse = Param::Sparse; + mgb_assert(opr->input().size() == new_inp.size()); + auto& conv_bias = opr->cast_final_safe(); + mgb_assert(conv_bias.param().sparse == Sparse::DENSE, + "only support dense conv now"); + VarNode* new_var = nullptr; + if ((new_var = try_transform_to_nchw32(opr, new_inp)) || + (new_var = try_transform_to_nchw4(opr, new_inp)) || + (new_var = try_transform_to_nchw64(opr, new_inp))) { + return new_var->owner_opr(); + } else { + mgb_assert( + opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && + opr->input(0)->dtype().enumv() != + DTypeEnum::QuantizedS4 && + opr->input(0)->dtype().enumv() != + DTypeEnum::Quantized4Asymm, + "invalid data type(%s)", opr->input(0)->dtype().name()); + bool shape_changed = false; + for (const auto& i : new_inp) { + if (format_map.count(i->owner_opr()) > 0) { + shape_changed = true; + break; + } + } + mgb_assert(!shape_changed, + "EnableNCHW64Pass won't change format of output tensor " + "of non quantized conv bias operator(name:%s)", + opr->cname()); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + }; + replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr; + replace_func[opr::ConvolutionBackwardData:: + typeinfo()] = [&format_map](OperatorNodeBase* opr, + const VarNodeArray& + new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + mgb_assert(new_inp.size() == 2, + "deconv (conv bwd data) operator for inference can " + 
"only have 2 input vars(got:%zu)", + new_inp.size()); + auto& deconv = opr->cast_final_safe(); + if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { + Format cur; + auto iter = format_map.find(opr->input(1)->owner_opr()); + if (iter == format_map.end()) { + cur = Format::NCHW; + } else { + cur = iter->second; + } + auto inps = new_inp; + inps[0] = RelayoutPlaceholder::make( + inps[0], + RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) + .node(); + switch (cur) { + case Format::NCHW: + inps[1] = RelayoutPlaceholder::make( + inps[1], RelayoutPlaceholder::LayoutType:: + NCHW_TO_NCHW4) + .node(); + break; + case Format::NCHW32: + inps[1] = RelayoutPlaceholder::make( + inps[1], RelayoutPlaceholder::LayoutType:: + NCHW32_TO_NCHW4) + .node(); + break; + case Format::NCHW64: + inps[1] = RelayoutPlaceholder::make( + inps[1], RelayoutPlaceholder::LayoutType:: + NCHW64_TO_NCHW4) + .node(); + break; + default: + mgb_assert(cur == Format::NCHW4); + } + format_map.insert(std::make_pair(opr, Format::NCHW4)); + auto param = deconv.param(); + param.format = Format::NCHW4; + auto new_deconv = opr::ConvolutionBackwardData::make( + inps[0], inps[1], param, deconv.execution_policy(), + deconv.config()); + return new_deconv.node()->owner_opr(); + } else { + bool shape_changed = false; + for (const auto& i : new_inp) { + if (format_map.count(i->owner_opr()) > 0) { + shape_changed = true; + break; + } + } + mgb_assert(!shape_changed, + "EnableNCHW64Pass won't change format of output tensor " + "of non quantized deconv operator(name:%s)", + opr->cname()); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + }; + + // replace rule for elemwise like opr + auto replace_elemwise_like_opr = [&format_map](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + ThinHashMap format_size; + bool same_format = true; + bool first_touch = false; + Format format; + for (const auto& i : opr->input()) { + Format cur; + auto iter = format_map.find(i->owner_opr()); + if (iter == format_map.end()) { + cur = Format::NCHW; + } else { + cur = iter->second; + } + auto& size = format_size[cur]; + size += i->shape().total_nr_elems(); + if (!first_touch) { + first_touch = true; + format = cur; + } else { + if (format != cur) + same_format = false; + } + } + if (same_format) { + if (format != Format::NCHW) + format_map.insert(std::make_pair(opr, format)); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + + Format max_format; + size_t max_size = std::numeric_limits::min(); + for (const auto& item : format_size) { + if (item.second > max_size) { + max_format = item.first; + max_size = item.second; + } + } + static const ThinHashMap, + thin_function> + map = { +#define cb(_fmt1, _fmt2) \ + { \ + std::make_pair(Format::_fmt1, Format::_fmt2), \ + [](VarNode* in) -> VarNode* { \ + return RelayoutPlaceholder::make( \ + in, RelayoutPlaceholder::LayoutType:: \ + _fmt1##_TO_##_fmt2) \ + .node(); \ + } \ + } + cb(NCHW, NCHW4), cb(NCHW, NCHW32), cb(NCHW, NCHW64), + cb(NCHW4, NCHW), cb(NCHW4, NCHW32), cb(NCHW4, NCHW64), + cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), + cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), +#undef cb + }; + auto inps = new_inp; + for (size_t i = 0; i < opr->input().size(); ++i) { + auto iter = format_map.find(opr->input(i)->owner_opr()); + Format cur; + if (iter != format_map.end()) { + cur = iter->second; + } else { + cur = Format::NCHW; + } + if (cur != max_format) { + inps[i] = 
map.at(std::make_pair(cur, max_format))(inps[i]); + } + } + if (max_format != Format::NCHW) + format_map.insert(std::make_pair(opr, max_format)); + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + }; + // elemwise like + replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; + replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; + replace_func[opr::ElemwiseMultiType::typeinfo()] = + replace_elemwise_like_opr; + replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr; + + auto replace_warp_perspective_opr = [&format_map]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto& warp = opr->cast_final_safe(); + if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || + opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) { + Format cur; + auto iter = format_map.find(opr->input(0)->owner_opr()); + if (iter == format_map.end()) { + cur = Format::NCHW; + } else { + cur = iter->second; + } + auto inps = new_inp; + switch (cur) { + case Format::NCHW: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW_TO_NCHW64) + .node(); + break; + case Format::NCHW4: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW4_TO_NCHW64) + .node(); + break; + case Format::NCHW32: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW32_TO_NCHW64) + .node(); + break; + default: + mgb_assert(cur == Format::NCHW64); + } + format_map.insert(std::make_pair(opr, Format::NCHW64)); + auto param = warp.param(); + param.format = Format::NCHW64; + SymbolVar new_warp; + if (inps.size() == 3) { + new_warp = opr::WarpPerspectiveForward::make( + inps[0], inps[1], inps[2], param, + warp.config()); + } else { + mgb_assert(inps.size() == 4); + new_warp = opr::WarpPerspectiveForward::make( + inps[0], inps[1], inps[2], inps[3], param, + warp.config()); + } + return new_warp.node()->owner_opr(); + } else if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { + Format cur; + auto iter = format_map.find(opr->input(0)->owner_opr()); + if (iter == format_map.end()) { + cur = Format::NCHW; + } else { + cur = iter->second; + } + auto inps = new_inp; + switch (cur) { + case Format::NCHW: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW_TO_NCHW4) + .node(); + break; + case Format::NCHW32: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW32_TO_NCHW4) + .node(); + break; + case Format::NCHW64: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW64_TO_NCHW4) + .node(); + break; + default: + mgb_assert(cur == Format::NCHW4); + } + format_map.insert(std::make_pair(opr, Format::NCHW4)); + auto param = warp.param(); + param.format = Format::NCHW4; + SymbolVar new_warp; + if (inps.size() == 3) { + new_warp = opr::WarpPerspectiveForward::make( + inps[0], inps[1], inps[2], param, + warp.config()); + } else { + mgb_assert(inps.size() == 4); + new_warp = opr::WarpPerspectiveForward::make( + inps[0], inps[1], inps[2], inps[3], param, + warp.config()); + } + return new_warp.node()->owner_opr(); + } else { + bool shape_changed = false; + for (const auto& i : new_inp) { + if (format_map.count(i->owner_opr()) > 0) { + shape_changed = true; + break; + } + } + mgb_assert(!shape_changed, + "EnableNCHW64Pass won't change format of output tensor " + "of non quantized warp perspective 
operator(name:%s)", + opr->cname()); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + }; + auto replace_pooling_opr = [&format_map]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto& pooling = opr->cast_final_safe(); + if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || + opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) { + Format cur; + auto iter = format_map.find(opr->input(0)->owner_opr()); + if (iter == format_map.end()) { + cur = Format::NCHW; + } else { + cur = iter->second; + } + auto inps = new_inp; + switch (cur) { + case Format::NCHW: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW_TO_NCHW64) + .node(); + break; + case Format::NCHW4: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW4_TO_NCHW64) + .node(); + break; + case Format::NCHW32: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW32_TO_NCHW64) + .node(); + break; + default: + mgb_assert(cur == Format::NCHW64); + } + format_map.insert(std::make_pair(opr, Format::NCHW64)); + auto param = pooling.param(); + param.format = Format::NCHW64; + auto new_pool = + opr::PoolingForward::make(inps[0], param, pooling.config()); + return new_pool.node()->owner_opr(); + } else if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { + Format cur; + auto iter = format_map.find(opr->input(0)->owner_opr()); + if (iter == format_map.end()) { + cur = Format::NCHW; + } else { + cur = iter->second; + } + size_t in_channels = opr->input(0)->shape()[1]; + bool use_nchw32 = false; + auto inps = new_inp; + using LayoutType = RelayoutPlaceholder::LayoutType; + switch (cur) { + case Format::NCHW: { + use_nchw32 = in_channels % 32 == 0; + auto layout_type = use_nchw32 ? LayoutType::NCHW_TO_NCHW32 + : LayoutType::NCHW_TO_NCHW4; + inps[0] = RelayoutPlaceholder::make(inps[0], layout_type) + .node(); + break; + } + case Format::NCHW64: + inps[0] = RelayoutPlaceholder::make( + inps[0], RelayoutPlaceholder::LayoutType:: + NCHW64_TO_NCHW32) + .node(); + break; + case Format::NCHW32: + use_nchw32 = true; + break; + default: + mgb_assert(cur == Format::NCHW4); + } + Format out_format = use_nchw32 ? 
Format::NCHW32 : Format::NCHW4; + format_map.insert(std::make_pair(opr, out_format)); + auto param = pooling.param(); + param.format = out_format; + auto new_pool = + opr::PoolingForward::make(inps[0], param, pooling.config()); + return new_pool.node()->owner_opr(); + } else { + bool shape_changed = false; + for (const auto& i : new_inp) { + if (format_map.count(i->owner_opr()) > 0) { + shape_changed = true; + break; + } + } + mgb_assert(!shape_changed, + "EnableNCHW64Pass won't change format of output tensor " + "of non quantized pooling operator(name:%s)", + opr->cname()); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + }; + // format aware + replace_func[opr::WarpPerspectiveForward::typeinfo()] = + replace_warp_perspective_opr; + replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; + + // to nchw + auto replace_inps_to_nchw = [&format_map](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto inps = new_inp; + for (size_t i = 0; i < opr->input().size(); ++i) { + auto iter = format_map.find(opr->input(i)->owner_opr()); + if (iter != format_map.end()) { + auto fmt = iter->second; + switch (fmt) { + case Format::NCHW4: + inps[i] = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType:: + NCHW4_TO_NCHW) + .node(); + break; + case Format::NCHW32: + inps[i] = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType:: + NCHW32_TO_NCHW) + .node(); + break; + default: + mgb_assert(fmt == Format::NCHW64); + inps[i] = RelayoutPlaceholder::make( + inps[i], + RelayoutPlaceholder::LayoutType:: + NCHW64_TO_NCHW) + .node(); + break; + } + } + } + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + }; + + replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw; + replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw; + replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw; + replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw; + replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw; + replace_func[opr::Subtensor::typeinfo()] = replace_inps_to_nchw; + return ret; + MIDOUT_E +} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index 1ff3701e..4a31372f 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -419,6 +419,27 @@ namespace gopt { void apply(OptState& opt) const override; }; + /*! + * \brief convert tensor format to nchw64 to enable tensorcore int4 on CUDA + * we assume that the input network is in NCHW layout + */ + class EnableNCHW64Pass final : public TensorReformatPass { + public: + using Format = opr::ConvBias::Param::Format; + const char* name() const override { + return mgb_cstr_log("tensor_format_nchw64"); + } + + //! 
make nchw -> nchw64 converter opt pass + static std::unique_ptr make_nchw64_converter(); + + private: + ThinHashMap m_opr_format_map; + + VarNode* on_graph_endpoint_var(VarNode* new_var, + VarNode* orig_var) const override; + }; + } // namespace gopt } // namespace mgb diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 312594f1..2483a2ee 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -4248,7 +4248,7 @@ TEST(TestGoptInference, PaddingChannels) { }; cg::DepOprIter{cb}.add(y3_pad.node()->owner_opr()); ASSERT_EQ(oprs.size(), 3); - ASSERT_EQ(oprs[0]->output(0)->shape()[1], 20); + ASSERT_EQ(oprs[0]->output(0)->shape()[1], 32); ASSERT_EQ(oprs[1]->output(0)->shape()[1], 32); ASSERT_EQ(oprs[2]->output(0)->shape()[1], 32); HostTensorND t1, t2; @@ -4322,7 +4322,7 @@ TEST(TestGoptInference, ConcatAfterPaddingChannels) { }; cg::DepOprIter{cb}.add(y2_pad.node()->owner_opr()); ASSERT_EQ(oprs.size(), 2); - ASSERT_EQ(oprs[0]->output(0)->shape()[1], 20); + ASSERT_EQ(oprs[0]->output(0)->shape()[1], 32); ASSERT_EQ(oprs[1]->output(0)->shape()[1], 32); HostTensorND t1, t2; auto func1 = graph->compile({make_callback_copy(y2, t1)}); @@ -4335,16 +4335,16 @@ TEST(TestGoptInference, ConcatAfterPaddingChannels) { // FIXME replace cpu with gpu to enable gpu validation TEST(TestGoptInference, PaddingChannelsWithPooling) { REQUIRE_GPU(1); - auto cn = CompNode::load("cpu0"); -// cn.activate(); -// auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; -// auto sm_ver = prop.major * 10 + prop.minor; -// if (sm_ver < 61) { -// printf("This testcast ignored due to insufficient cuda cap(got: %d, " -// "expected: %d)\n", -// sm_ver, 61); -// return; -// } + auto cn = CompNode::load("gpu0"); + cn.activate(); + auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; + auto sm_ver = prop.major * 10 + prop.minor; + if (sm_ver < 61) { + printf("This testcast ignored due to insufficient cuda cap(got: %d, " + "expected: %d)\n", + sm_ver, 61); + return; + } HostTensorGenerator gen; auto graph = ComputingGraph::make(); @@ -4485,6 +4485,311 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) { func2->execute(); MGB_ASSERT_TENSOR_EQ(t1, t2); } + +TEST(TestGoptInference, EnableNCHW64Basic) { + REQUIRE_GPU(1); + auto cn = CompNode::load("cpu0"); +// cn.activate(); +// auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; +// auto sm_ver = prop.major * 10 + prop.minor; +// if (sm_ver < 61) { +// printf("This testcast ignored due to insufficient cuda cap(got: %d, " +// "expected: %d)\n", +// sm_ver, 61); +// return; +// } + + HostTensorGenerator gen; + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), + dtype); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + + auto x = mkvar("x", {16, 4, 14, 14}, dtype::QuantizedS8(2.5f)), + w = mkcvar("w", {32, 4, 3, 3}, dtype::QuantizedS8(2.5f)), + b = mkcvar("b", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = 1; + + auto 
y = opr::ConvBias::make(x, w, b, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + auto w1 = mkcvar("w1", {32, 32, 3, 3}, dtype::QuantizedS8(2.5f)), + b1 = mkcvar("b1", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)); + auto y1 = opr::ConvBias::make(y, w1, b1, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + auto w2 = mkcvar("w2", {64, 32, 3, 3}, dtype::QuantizedS8(2.5f)), + b2 = mkcvar("b2", {1, 64, 1, 1}, dtype::QuantizedS32(6.25f)); + auto y2 = opr::ConvBias::make(y1, w2, b2, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + y2 = opr::TypeCvt::make(y2, dtype::QuantizedS4{40.f}); + auto w3 = mkcvar("w3", {64, 64, 3, 3}, dtype::QuantizedS4(2.5f)), + b3 = mkcvar("b3", {1, 64, 1, 1}, dtype::QuantizedS32(100.f)); + auto y3 = opr::ConvBias::make(y2, w3, b3, param, {}, + OperatorNodeConfig{dtype::QuantizedS4{40.f}}); + y3 = opr::TypeCvt::make(y3, dtype::QuantizedS8{2.5f}); + auto w4 = mkcvar("w4", {32, 64, 3, 3}, dtype::QuantizedS8(2.5f)), + b4 = mkcvar("b4", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)); + auto y4 = opr::ConvBias::make(y3, w4, b4, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode; + auto y5 = opr::ElemwiseMultiType::make( + {y, y4}, {ElemMultiMode::QFUSE_ADD_RELU}, + OperatorNodeConfig{dtype::QuantizedS8{1.2f}}); + y5 = opr::TypeCvt::make(y5, dtype::Float32()); + SymbolVar y5_pad; + unpack_vector( + gopt::GraphOptimizer{} + .add_pass(gopt::EnableNCHW64Pass::make_nchw64_converter()) + .apply({{y5}}) + .endpoint_vars(), + y5_pad); + EXPECT_TRUE(y5.node()->shape().eq_shape(y5_pad.node()->shape())); + SmallVector oprs; + auto cb = [&oprs](cg::OperatorNodeBase* opr) { + if (opr->same_type()) { + oprs.push_back(opr); + } + }; + cg::DepOprIter{cb}.add(y5_pad.node()->owner_opr()); + ASSERT_EQ(oprs.size(), 5); + using Format = opr::ConvBiasForward::Param::Format; +#define CHECK(_i, _fmt) \ + { \ + const auto& o = oprs[_i]->cast_final(); \ + ASSERT_EQ(o.param().format, Format::_fmt); \ + } + CHECK(0, NCHW4); + CHECK(1, NCHW32); + CHECK(2, NCHW32); + CHECK(3, NCHW64); + CHECK(4, NCHW32); +#undef CHECK + HostTensorND t1, t2; + auto func1 = graph->compile({make_callback_copy(y5, t1)}); + func1->execute(); + auto func2 = graph->compile({make_callback_copy(y5_pad, t2)}); + func2->execute(); + MGB_ASSERT_TENSOR_EQ(t1, t2); +} + +TEST(TestGoptInference, EnableNCHW64PaddingChannel) { + REQUIRE_GPU(1); + auto cn = CompNode::load("cpu0"); +// cn.activate(); +// auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; +// auto sm_ver = prop.major * 10 + prop.minor; +// if (sm_ver < 61) { +// printf("This testcast ignored due to insufficient cuda cap(got: %d, " +// "expected: %d)\n", +// sm_ver, 61); +// return; +// } + + HostTensorGenerator gen; + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), + dtype); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + + auto x = mkvar("x", {16, 3, 14, 14}, dtype::QuantizedS8(2.5f)), + w = mkcvar("w", {20, 3, 3, 3}, dtype::QuantizedS8(2.5f)), + b = mkcvar("b", {1, 20, 1, 1}, dtype::QuantizedS32(6.25f)); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW; + 
param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = 1; + + auto y = opr::ConvBias::make(x, w, b, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + opr::Pooling::Param pool; + pool.format = opr::Pooling::Param::Format::NCHW; + y = opr::Pooling::make(y, pool); + + auto w1 = mkcvar("w1", {24, 20, 3, 3}, dtype::QuantizedS8(2.5f)), + b1 = mkcvar("b1", {1, 24, 1, 1}, dtype::QuantizedS32(6.25f)); + auto y1 = opr::ConvBias::make(y, w1, b1, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + auto w2 = mkcvar("w2", {20, 24, 3, 3}, dtype::QuantizedS8(2.5f)), + b2 = mkcvar("b2", {1, 20, 1, 1}, dtype::QuantizedS32(6.25f)); + auto y2 = opr::ConvBias::make(y1, w2, b2, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + y2 = opr::TypeCvt::make(y2, dtype::QuantizedS4{40.f}); + auto w3 = mkcvar("w3", {64, 20, 3, 3}, dtype::QuantizedS4(2.5f)), + b3 = mkcvar("b3", {1, 64, 1, 1}, dtype::QuantizedS32(100.f)); + auto y3 = opr::ConvBias::make(y2, w3, b3, param, {}, + OperatorNodeConfig{dtype::QuantizedS4{40.f}}); + auto w4 = mkcvar("w4", {20, 64, 3, 3}, dtype::QuantizedS4(2.5f)), + b4 = mkcvar("b4", {1, 20, 1, 1}, dtype::QuantizedS32(100.f)); + auto y4 = opr::ConvBias::make(y3, w4, b4, param, {}, + OperatorNodeConfig{dtype::QuantizedS4{40.f}}); + y4 = opr::TypeCvt::make(y4, dtype::QuantizedS8{2.5f}); + using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode; + auto y5 = opr::ElemwiseMultiType::make( + {y, y4}, {ElemMultiMode::QFUSE_ADD_RELU}, + OperatorNodeConfig{dtype::QuantizedS8{1.2f}}); + opr::ConvolutionBackwardData::Param deconv; + deconv.format = opr::ConvolutionBackwardData::Param::Format::NCHW; + deconv.stride_h = deconv.stride_w = 2; + deconv.pad_h = deconv.pad_w = 1; + auto w6 = mkcvar("w6", {20, 20, 4, 4}, dtype::QuantizedS8{2.5f}); + auto y6 = opr::ConvolutionBackwardData::make( + w6, y5, deconv, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.0f)}); + + std::shared_ptr mat = std::make_shared( + cn, TensorShape{16, 3, 3}, dtype::Float32()); + warp_perspective_mat_gen(*mat, 16, 14, 14); + auto mat_var = opr::Host2DeviceCopy::make(*graph, mat).rename("mat"); + opr::WarpPerspective::Param warp_param; + warp_param.format = opr::WarpPerspective::Param::Format::NCHW; + auto y7 = opr::WarpPerspective::make(y6, mat_var, TensorShape{14, 14}, + warp_param); + y7 = opr::TypeCvt::make(y7, dtype::Float32()); + SymbolVar y7_pad; + auto opt = gopt::OptimizeForInferenceOptions{}; + opt.enable_nchw64(); + unpack_vector(gopt::optimize_for_inference({y7}, opt), y7_pad); + EXPECT_TRUE(y7.node()->shape().eq_shape(y7_pad.node()->shape())); + SmallVector oprs; + auto cb = [&oprs](cg::OperatorNodeBase* opr) { + if (opr->same_type()) { + oprs.push_back(opr); + } + }; + cg::DepOprIter{cb}.add(y7_pad.node()->owner_opr()); + ASSERT_EQ(oprs.size(), 5); + using Format = opr::ConvBiasForward::Param::Format; +#define CHECK(_i, _fmt) \ + { \ + const auto& o = oprs[_i]->cast_final(); \ + ASSERT_EQ(o.param().format, Format::_fmt); \ + } + CHECK(0, NCHW4); + CHECK(1, NCHW32); + CHECK(2, NCHW32); + CHECK(3, NCHW64); + CHECK(4, NCHW64); +#undef CHECK + { + const auto& deconv = find_opr(y7_pad); + ASSERT_EQ(deconv.param().format, Format::NCHW4); + const auto& pool = find_opr(y7_pad); + ASSERT_EQ(pool.param().format, Format::NCHW4); + const auto& warp = find_opr(y7_pad); + ASSERT_EQ(warp.param().format, Format::NCHW4); + } + size_t nr_dimshuffle = find_opr_num(y7_pad); + HostTensorND t1, t2; + auto func1 = 
graph->compile({make_callback_copy(y7, t1)}); + func1->execute(); + auto func2 = graph->compile({make_callback_copy(y7_pad, t2)}); + func2->execute(); + MGB_ASSERT_TENSOR_EQ(t1, t2); +} + +TEST(TestGoptInference, EnableNCHW64FuseConvBiasZ) { + REQUIRE_GPU(1); + auto cn = CompNode::load("cpu0"); +// cn.activate(); +// auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop; +// auto sm_ver = prop.major * 10 + prop.minor; +// if (sm_ver < 61) { +// printf("This testcast ignored due to insufficient cuda cap(got: %d, " +// "expected: %d)\n", +// sm_ver, 61); +// return; +// } + + HostTensorGenerator gen; + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), + dtype); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + + auto x = mkvar("x", {16, 4, 14, 14}, dtype::QuantizedS8(2.5f)), + w = mkcvar("w", {32, 4, 3, 3}, dtype::QuantizedS8(2.5f)), + b = mkcvar("b", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)); + opr::ConvBias::Param param; + param.format = opr::ConvBias::Param::Format::NCHW; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = 1; + + auto y = opr::ConvBias::make(x, w, b, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + auto w1 = mkcvar("w1", {64, 32, 3, 3}, dtype::QuantizedS8(2.5f)), + b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32(6.25f)); + auto y1 = opr::ConvBias::make(y, w1, b1, param, {}, + OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); + y1 = opr::TypeCvt::make(y1, dtype::QuantizedS4{40.f}); + auto w2 = mkcvar("w2", {64, 64, 3, 3}, dtype::QuantizedS4(2.5f)), + b2 = mkcvar("b2", {1, 64, 1, 1}, dtype::QuantizedS32(100.f)); + auto y2 = opr::ConvBias::make(y1, w2, b2, param, {}, + OperatorNodeConfig{dtype::QuantizedS4{40.f}}); + auto w3 = mkcvar("w3", {64, 64, 3, 3}, dtype::QuantizedS4(2.5f)), + b3 = mkcvar("b3", {1, 64, 1, 1}, dtype::QuantizedS32(100.f)); + auto y3 = opr::ConvBias::make(y2, w3, b3, param, {}, + OperatorNodeConfig{dtype::QuantizedS4(40.f)}); + using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode; + auto y4 = opr::ElemwiseMultiType::make( + {y1, y3}, {ElemMultiMode::QFUSE_ADD_RELU}, + OperatorNodeConfig{dtype::QuantizedS4{40.f}}); + y4 = opr::TypeCvt::make(y4, dtype::Float32()); + SymbolVar y4_pad; + auto opt = gopt::OptimizeForInferenceOptions{}; + opt.enable_nchw64(); + unpack_vector(gopt::optimize_for_inference({y4}, opt), y4_pad); + EXPECT_TRUE(y4.node()->shape().eq_shape(y4_pad.node()->shape())); + size_t nr_elem_mult_type = find_opr_num(y4_pad); + ASSERT_EQ(nr_elem_mult_type, 0); + // FIXME need impl of elemwise/elemwise_multi_type on CUDA +#if 0 + HostTensorND t1, t2; + auto func1 = graph->compile({make_callback_copy(y4, t1)}); + func1->execute(); + auto func2 = graph->compile({make_callback_copy(y4_pad, t2)}); + func2->execute(); + MGB_ASSERT_TENSOR_EQ(t1, t2); +#endif +} + #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp index 880ea6a5..12e6b182 100644 --- a/src/opr/test/dnn/convolution.cpp +++ b/src/opr/test/dnn/convolution.cpp @@ -2642,7 +2642,7 @@ TEST(TestOprDNN, 
ConvBiasInt4NCHW) { cn); opr::ConvBias::Param param; param.format = opr::ConvBias::Param::Format::NCHW; - param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; param.stride_h = param.stride_w = S; param.pad_h = param.pad_w = P; Policy policy; @@ -2719,7 +2719,7 @@ TEST(TestOprDNN, ConvBiasInt4NCHW64) { cn); opr::ConvBias::Param param; param.format = opr::ConvBias::Param::Format::NCHW64; - param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; + param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; param.stride_h = param.stride_w = S; param.pad_h = param.pad_w = P; Policy policy;
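
A minimal usage sketch of the new NCHW64 layout transform, mirroring the new tests in src/gopt/test/inference.cpp; `y` is assumed here to be the endpoint SymbolVar of an already-built quantized (int8/int4) network:

    // Option A: via the inference optimization options, the same switch
    // exposed as --enable-nchw64 in load-and-run.
    auto opt = gopt::OptimizeForInferenceOptions{};
    opt.enable_nchw64();
    SymbolVar y_opt;
    unpack_vector(gopt::optimize_for_inference({y}, opt), y_opt);

    // Option B: run the converter pass directly through GraphOptimizer.
    SymbolVar y_pass;
    unpack_vector(gopt::GraphOptimizer{}
                          .add_pass(gopt::EnableNCHW64Pass::make_nchw64_converter())
                          .apply({{y}})
                          .endpoint_vars(),
                  y_pass);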