diff --git a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
index 03c7aec4..7336d1b8 100644
--- a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
+++ b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
@@ -121,7 +121,10 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     auto&& param = args.opr->param();
     bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
     bool is_version_ok = CUDNN_VERSION >= 7500;
-    bool is_dtype_ok = args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
+    bool is_dtype_ok =
+            (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
+              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||
             (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 &&
diff --git a/dnn/src/cuda/conv_bias/cudnn_conv.cpp b/dnn/src/cuda/conv_bias/cudnn_conv.cpp
index d6792fe5..9d4fa298 100644
--- a/dnn/src/cuda/conv_bias/cudnn_conv.cpp
+++ b/dnn/src/cuda/conv_bias/cudnn_conv.cpp
@@ -31,6 +31,11 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
         }
     }
 
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        return false;
+    }
+
     // FIXME: cudnn cannot handle the case when the initial value of dst tensor
     // contains nan and beta is zero, because the result of 0.f * nan is still
     // nan
diff --git a/dnn/src/cuda/conv_bias/matmul_8x8x32.cpp b/dnn/src/cuda/conv_bias/matmul_8x8x32.cpp
index f7ec0259..5e5a5ad1 100644
--- a/dnn/src/cuda/conv_bias/matmul_8x8x32.cpp
+++ b/dnn/src/cuda/conv_bias/matmul_8x8x32.cpp
@@ -24,6 +24,11 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(
     if (!is_compute_capability_required(6, 1))
         return false;
 
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
+        return false;
+    }
+
     auto dst_layout = *args.dst_layout;
     if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
         dst_layout.dtype = DType();
diff --git a/src/gopt/impl/folding_conv_typecvt.cpp b/src/gopt/impl/folding_conv_typecvt.cpp
new file mode 100644
index 00000000..05277251
--- /dev/null
+++ b/src/gopt/impl/folding_conv_typecvt.cpp
@@ -0,0 +1,151 @@
+/**
+ * \file src/gopt/impl/folding_conv_typecvt.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/gopt/inference.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/utility.h"
+#include "megbrain/serialization/opr_shallow_copy.h"
+
+#include "megdnn/opr_param_defs.h"
+
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+
+#include "megbrain/utils/hash_ct.h"
+
+#include "midout.h"
+
+#include "megbrain/gopt/reformat_manager.h"
+
+#if CUDA_VERSION >= 10020
+MIDOUT_DECL(megbrain_folding_conv_typecvt)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_folding_conv_typecvt, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    }            \
+    MIDOUT_END();
+
+using namespace mgb;
+using namespace gopt;
+using ReformatKey = ReformatManager::ReformatKey;
+
+/* ==================== FoldingConvBiasTypecvtPass ================= */
+const char* FoldingConvBiasTypecvtPass::name() const {
+    return mgb_cstr_log("folding conv bias typecvt pass");
+}
+
+void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
+    MIDOUT_B("FoldingConvBiasTypecvtPass::apply");
+    using DepType = cg::OperatorNodeProp::DepType;
+    ThinHashMap<OperatorNodeBase*,
+                SmallVector<std::pair<OperatorNodeBase*, DepType>>>
+            readers;
+    static const ThinHashSet<Typeinfo*> opr_type_list = {
+            opr::TypeCvt::typeinfo(), opr::ConvBias::typeinfo()};
+    opt.graph().iter([&readers](OperatorNodeBase* opr) {
+        for (auto&& i : opr->node_prop().dep_map()) {
+            if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
+                readers[i.first->owner_opr()].emplace_back(opr, i.second);
+            }
+        }
+    });
+
+    auto rewriter = opt.graph().make_rewriter();
+
+    auto try_conv_typecvt = [&rewriter, &readers](OperatorNodeBase* opr) {
+        ThinHashSet<OperatorNodeBase*> opr_set;
+        ThinHashSet<OperatorNodeBase*> reader_set;
+        // check typecvt
+        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
+        if (typecvt == nullptr)
+            return false;
+        auto inp_dtype_typecvt = typecvt->input(0)->dtype(),
+             out_dtype_typecvt = typecvt->output(0)->dtype();
+        bool is_s82f32 = inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
+                         out_dtype_typecvt.enumv() == DTypeEnum::Float32;
+        bool is_s82s4 =
+                inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
+                (out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
+                 out_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm);
+        bool is_s42s8 =
+                (inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
+                 inp_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm) &&
+                out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8;
+
+        if (!(is_s82f32 || is_s82s4 || is_s42s8))
+            return false;
+        opr_set.insert(opr);
+
+        // check conv bias
+        auto conv_bias =
+                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
+        if (conv_bias == nullptr)
+            return false;
+        auto inp_dtype_conv = conv_bias->input(0)->dtype(),
+             out_dtype_conv = conv_bias->output(0)->dtype();
+        bool is_s8nhwc = inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
+                         out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                         conv_bias->param().format ==
+                                 megdnn::param::ConvBias::Format::NHWC;
+        bool is_s4nhwc =
+                (inp_dtype_conv.enumv() == DTypeEnum::QuantizedS4 ||
+                 inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
+                out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                conv_bias->param().format ==
+                        megdnn::param::ConvBias::Format::NHWC;
+        if (!(is_s8nhwc || is_s4nhwc))
+            return false;
+        if (conv_bias->input().size() != 3)
+            return false;
+        opr_set.insert(conv_bias);
+        for (auto&& i : readers[conv_bias]) {
+            if (i.second & DepType::DEV_VALUE) {
+                reader_set.insert(i.first);
+            }
+        }
+        for (auto reader : reader_set) {
+            if (opr_set.count(reader) <= 0) {
+                return false;
+            }
+        }
+        auto src = rewriter.get_var(conv_bias->input(0)),
+             filter = rewriter.get_var(conv_bias->input(1)),
+             bias = rewriter.get_var(conv_bias->input(2));
+        auto new_bias =
+                (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
+                        ? opr::TypeCvt::make(bias, dtype::Float32()).node()
+                        : bias;
+        auto new_param = conv_bias->param();
+        new_param.format = megdnn::param::ConvBias::Format::NHWC;
+        auto conv_bias_typecvt = opr::ConvBias::make(
+                src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                OperatorNodeConfig{out_dtype_typecvt});
+        rewriter.replace_var(opr->output(0), conv_bias_typecvt.node(),
+                             mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
+                                          "to conv_bias(NHWC)"));
+        return true;
+    };
+
+    auto on_opr = [&try_conv_typecvt, &rewriter](OperatorNodeBase* opr) {
+        if (!try_conv_typecvt(opr)) {
+            rewriter.auto_replace_outputs(opr);
+        }
+    };
+    opt.graph().iter(on_opr);
+    rewriter.apply_inplace();
+
+    MIDOUT_E
+}
+#endif
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index a9e166d3..e07b0bb9 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -835,6 +835,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
         add_pass<FuseWarpPerspectiveDimshufflePass>();
 #if CUDA_VERSION >= 10020
         add_pass<FoldingConvBiasDimshufflePass>();
+        add_pass<FoldingConvBiasTypecvtPass>();
 #endif
     });
 #undef cb
diff --git a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
index aa25ba63..4554ae3b 100644
--- a/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
+++ b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp
@@ -57,7 +57,10 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
             TensorFormats::NCHW,    TensorFormats::NHWC,
            TensorFormats::NCHWc4,  TensorFormats::NCHWc32,
            TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
+
+    Attribute attribute = {
+            base_opr_format, base_tensor_format, Target::CUDA,
+            LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
@@ -67,8 +70,9 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
                              OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
             .add_opr_config(opr::ConvolutionForward::typeinfo(),
                             {OprFormat::NCHW, OprFormat::NCHW4})
-            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::ConvolutionBackwardData::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
diff --git a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
index 23f4410c..8d9967ff 100644
--- a/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
+++ b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
@@ -512,7 +512,7 @@ struct ConvTensorFormatsDispatcherImpl<
         const auto& conv = opr->cast_final_safe<opr::ConvolutionBackwardData>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
-        config.opr_format = OprFormat::NCHW4;
+        config.opr_format = OprFormat::NHWC;
         bool available = true;
         for (size_t i = 0; i < opr->input().size(); ++i) {
             available &=
diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h
index 71ea13fc..5b817c81 100644
--- a/src/gopt/include/megbrain/gopt/inference.h
+++ b/src/gopt/include/megbrain/gopt/inference.h
@@ -481,6 +481,12 @@ namespace gopt {
         const char* name() const override;
         void apply(OptState& opt) const override;
     };
+
+    class FoldingConvBiasTypecvtPass final : public Pass {
+    public:
+        const char* name() const override;
+        void apply(OptState& opt) const override;
+    };
 #endif
 
     /*!
diff --git a/src/gopt/test/layout_transform_pass.cpp b/src/gopt/test/layout_transform_pass.cpp
index f7ef2f7a..c41af738 100644
--- a/src/gopt/test/layout_transform_pass.cpp
+++ b/src/gopt/test/layout_transform_pass.cpp
@@ -585,6 +585,7 @@ TEST(TestLayoutTransform, DetectionHead) {
     using OprFormat = LayoutTransformContext::OprFormat;
     using OprList = LayoutTransformContext::OprList;
     using Attribute = LayoutTransformContext::Attribute;
+    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
     using Target = LayoutTransformContext::Target;
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
@@ -600,8 +601,8 @@ TEST(TestLayoutTransform, DetectionHead) {
             TensorFormats::NCHW,    TensorFormats::NHWC,
            TensorFormats::NCHWc4,  TensorFormats::NCHWc32,
            TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
-                           Target::UNSPEC};
+    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
+                           ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
@@ -611,8 +612,9 @@ TEST(TestLayoutTransform, DetectionHead) {
                              OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
             .add_opr_config(opr::ConvolutionForward::typeinfo(),
                             {OprFormat::NCHW, OprFormat::NCHW4})
-            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::ConvolutionBackwardData::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -630,6 +632,7 @@ TEST(TestLayoutTransform, DetectionHead) {
             .add_pass<ShuffleShuffleRemovePass>()
             .add_pass(FuseNCHW4Int8Preprocess::make())
             .add_pass<FoldingConvBiasDimshufflePass>()
+            .add_pass<FoldingConvBiasTypecvtPass>()
             .add_pass<ParamFusePass>()
             .add_pass<ParamMergePass>()
             .apply(SymbolVarArray{y})
@@ -656,7 +659,8 @@ TEST(TestLayoutTransform, DetectionHead) {
     /// check first conv format
     const auto& first_conv = find_opr<opr::ConvBias>(v);
     const auto& cast = first_conv.cast_final_safe<opr::ConvBias>();
-    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW4_NHWC);
+    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NHWC);
+    ASSERT_EQ(cast.output()[0]->dtype().enumv(), DTypeEnum::Quantized4Asymm);
 }
 #endif
 #endif
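
---

For reviewers, a minimal sketch of the pattern `FoldingConvBiasTypecvtPass` targets (the `is_s82s4` case). This is illustrative only and not part of the patch: the helper function, the input vars `x`/`w`/`b` (assumed to already be quantized NHWC vars), and the quantization scales are assumptions; only `opr::ConvBias::make`, `opr::TypeCvt::make`, `GraphOptimizer`, and the pass name mirror APIs actually used above.

```cpp
// Sketch (assumed setup): fold conv_bias(NHWC, qs8) -> typecvt(qs4)
// into a single conv_bias(NHWC, qs4), as the pass does above.
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"

using namespace mgb;

SymbolVar fold_example(SymbolVar x, SymbolVar w, SymbolVar b) {
    // conv_bias in NHWC producing QuantizedS8 (scales are arbitrary here) ...
    opr::ConvBias::Param param;
    param.format = opr::ConvBias::Param::Format::NHWC;
    auto conv = opr::ConvBias::make(
            x, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(12.f)});
    // ... immediately followed by a standalone TypeCvt to QuantizedS4.
    auto y = opr::TypeCvt::make(conv, dtype::QuantizedS4(40.f));

    // The pass matches the (conv_bias -> typecvt) chain; when the typecvt is
    // the conv's only DEV_VALUE reader, it re-emits one ConvBias whose
    // OperatorNodeConfig carries the typecvt's output dtype.
    auto folded = gopt::GraphOptimizer{}
                          .add_pass<gopt::FoldingConvBiasTypecvtPass>()
                          .apply(SymbolVarArray{y})
                          .endpoint_vars();
    return folded[0];
}
```

The early-return guards added to the three `is_available` checks in `dnn/src/cuda/conv_bias` are the kernel-side counterpart of this rewrite: once the fold produces a 4-bit dst dtype, those algorithms must report themselves unavailable rather than be selected for an output layout they cannot emit.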