GitOrigin-RevId: adc2301203
release-1.7
@@ -121,7 +121,10 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     auto&& param = args.opr->param();
     bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
     bool is_version_ok = CUDNN_VERSION >= 7500;
-    bool is_dtype_ok = args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
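+    // require int8 input and reject 4-bit quantized outputs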
+    bool is_dtype_ok =
+            (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
+              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||
             (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 &&
@@ -31,6 +31,11 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
         }
     }
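+    // 4-bit quantized outputs are rejected here (presumably unsupported by this cudnn algo)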
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        return false;
+    }
     // FIXME: cudnn cannot handle the case when the initial value of dst tensor
     // contains nan and beta is zero, because the result of 0.f * nan is still
     // nan
@@ -24,6 +24,11 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(
     if (!is_compute_capability_required(6, 1))
         return false;
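+    // likewise, reject 4-bit quantized outputs for the 8x8x32 matmul algorithm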
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
+        return false;
+    }
     auto dst_layout = *args.dst_layout;
     if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
         dst_layout.dtype = DType();
@@ -0,0 +1,151 @@
+/**
+ * \file src/gopt/impl/folding_conv_typecvt.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
#include "megbrain/gopt/inference.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "megbrain/serialization/opr_shallow_copy.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megbrain/utils/hash_ct.h" | |||||
#include "midout.h" | |||||
#include "megbrain/gopt/reformat_manager.h" | |||||
#if CUDA_VERSION >= 10020 | |||||
MIDOUT_DECL(megbrain_folding_conv_typecvt) | |||||
#define MIDOUT_B(tag) \ | |||||
MIDOUT_BEGIN(megbrain_folding_conv_typecvt, midout_iv(MGB_HASH_STR(tag))) { | |||||
#define MIDOUT_E \ | |||||
} \ | |||||
MIDOUT_END(); | |||||
+using namespace mgb;
+using namespace gopt;
+using ReformatKey = ReformatManager::ReformatKey;
+/* ==================== FoldingConvBiasTypecvtPass ================= */
+const char* FoldingConvBiasTypecvtPass::name() const {
+    return mgb_cstr_log("folding conv bias typecvt pass");
+}
+void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
+    MIDOUT_B("FoldingConvBiasTypecvtPass::apply");
+    using DepType = cg::OperatorNodeProp::DepType;
+    ThinHashMap<OperatorNodeBase*,
+                SmallVector<std::pair<OperatorNodeBase*, DepType>>>
+            readers;
+    static const ThinHashSet<Typeinfo*> opr_type_list = {
+            opr::TypeCvt::typeinfo(), opr::ConvBias::typeinfo()};
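+    // record, for every TypeCvt/ConvBias operator, the operators that read its outputs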
+    opt.graph().iter([&readers](OperatorNodeBase* opr) {
+        for (auto&& i : opr->node_prop().dep_map()) {
+            if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
+                readers[i.first->owner_opr()].emplace_back(opr, i.second);
+            }
+        }
+    });
+    auto rewriter = opt.graph().make_rewriter();
+    auto try_conv_typecvt = [&rewriter, &readers](OperatorNodeBase* opr) {
+        ThinHashSet<OperatorNodeBase*> opr_set;
+        ThinHashSet<OperatorNodeBase*> reader_set;
+        // check typecvt
+        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
+        if (typecvt == nullptr)
+            return false;
+        auto inp_dtype_typecvt = typecvt->input(0)->dtype(),
+             out_dtype_typecvt = typecvt->output(0)->dtype();
+        bool is_s82f32 = inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
+                         out_dtype_typecvt.enumv() == DTypeEnum::Float32;
+        bool is_s82s4 =
+                inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
+                (out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
+                 out_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm);
+        bool is_s42s8 =
+                (inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
+                 inp_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm) &&
+                out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8;
+        if (!(is_s82f32 || is_s82s4 || is_s42s8))
+            return false;
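+        // only fold s8->f32, s8->4-bit and 4-bit->s8 conversions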
+        opr_set.insert(opr);
+        // check conv bias
+        auto conv_bias =
+                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
+        if (conv_bias == nullptr)
+            return false;
+        auto inp_dtype_conv = conv_bias->input(0)->dtype(),
+             out_dtype_conv = conv_bias->output(0)->dtype();
+        bool is_s8nhwc = inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
+                         out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                         conv_bias->param().format ==
+                                 megdnn::param::ConvBias::Format::NHWC;
+        bool is_s4nhwc =
+                (inp_dtype_conv.enumv() == DTypeEnum::QuantizedS4 ||
+                 inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
+                out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                conv_bias->param().format ==
+                        megdnn::param::ConvBias::Format::NHWC;
+        if (!(is_s8nhwc || is_s4nhwc))
+            return false;
+        if (conv_bias->input().size() != 3)
+            return false;
+        opr_set.insert(conv_bias);
+        for (auto&& i : readers[conv_bias]) {
+            if (i.second & DepType::DEV_VALUE) {
+                reader_set.insert(i.first);
+            }
+        }
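+        // folding is only legal when every reader of conv_bias lies inside the matched pattern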
+        for (auto reader : reader_set) {
+            if (opr_set.count(reader) <= 0) {
+                return false;
+            }
+        }
+        auto src = rewriter.get_var(conv_bias->input(0)),
+             filter = rewriter.get_var(conv_bias->input(1)),
+             bias = rewriter.get_var(conv_bias->input(2));
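+        // when the fused conv produces float32, the quantized bias is converted to float as well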
+        auto new_bias =
+                (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
+                        ? opr::TypeCvt::make(bias, dtype::Float32()).node()
+                        : bias;
+        auto new_param = conv_bias->param();
+        new_param.format = megdnn::param::ConvBias::Format::NHWC;
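+        // rebuild the conv_bias so that it directly emits the typecvt's output dtype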
+        auto conv_bias_typecvt = opr::ConvBias::make(
+                src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                OperatorNodeConfig{out_dtype_typecvt});
+        rewriter.replace_var(opr->output(0), conv_bias_typecvt.node(),
+                             mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
+                                          "to conv_bias(NHWC)"));
+        return true;
+    };
+    auto on_opr = [&try_conv_typecvt, &rewriter](OperatorNodeBase* opr) {
+        if (!try_conv_typecvt(opr)) {
+            rewriter.auto_replace_outputs(opr);
+        }
+    };
+    opt.graph().iter(on_opr);
+    rewriter.apply_inplace();
+    MIDOUT_E
+}
+#endif
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -835,6 +835,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
         add_pass<FuseWarpPerspectiveDimshufflePass>();
 #if CUDA_VERSION >= 10020
         add_pass<FoldingConvBiasDimshufflePass>();
+        add_pass<FoldingConvBiasTypecvtPass>();
 #endif
     });
 #undef cb
@@ -57,7 +57,10 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
             TensorFormats::NCHW, TensorFormats::NHWC,
             TensorFormats::NCHWc4, TensorFormats::NCHWc32,
             TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
+    Attribute attribute = {
+            base_opr_format, base_tensor_format, Target::CUDA,
+            LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats),
             attribute);
@@ -67,8 +70,9 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
                             OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
             .add_opr_config(opr::ConvolutionForward::typeinfo(),
                             {OprFormat::NCHW, OprFormat::NCHW4})
-            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::ConvolutionBackwardData::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -512,7 +512,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData,
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
-        config.opr_format = OprFormat::NCHW4;
+        config.opr_format = OprFormat::NHWC;
         bool available = true;
         for (size_t i = 0; i < opr->input().size(); ++i) {
             available &=
@@ -481,6 +481,12 @@ namespace gopt {
         const char* name() const override;
         void apply(OptState& opt) const override;
     };
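+    /*!
+     * \brief fold a ConvBias followed by a TypeCvt into a single ConvBias
+     * that directly produces the TypeCvt's output dtype
+     */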
+    class FoldingConvBiasTypecvtPass final : public Pass {
+    public:
+        const char* name() const override;
+        void apply(OptState& opt) const override;
+    };
 #endif
     /*!
@@ -585,6 +585,7 @@ TEST(TestLayoutTransform, DetectionHead) {
     using OprFormat = LayoutTransformContext::OprFormat;
     using OprList = LayoutTransformContext::OprList;
     using Attribute = LayoutTransformContext::Attribute;
+    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
     using Target = LayoutTransformContext::Target;
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
@@ -600,8 +601,8 @@ TEST(TestLayoutTransform, DetectionHead) {
             TensorFormats::NCHW, TensorFormats::NHWC,
             TensorFormats::NCHWc4, TensorFormats::NCHWc32,
             TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
-                           Target::UNSPEC};
+    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
+                           ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats),
             attribute);
@@ -611,8 +612,9 @@ TEST(TestLayoutTransform, DetectionHead) {
                             OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
             .add_opr_config(opr::ConvolutionForward::typeinfo(),
                             {OprFormat::NCHW, OprFormat::NCHW4})
-            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::ConvolutionBackwardData::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -630,6 +632,7 @@ TEST(TestLayoutTransform, DetectionHead) {
                     .add_pass<ShuffleShuffleRemovePass>()
                     .add_pass(FuseNCHW4Int8Preprocess::make())
                     .add_pass<FoldingConvBiasDimshufflePass>()
+                    .add_pass<FoldingConvBiasTypecvtPass>()
                     .add_pass<ParamFusePass>()
                     .add_pass<ParamMergePass>()
                     .apply(SymbolVarArray{y})
@@ -656,7 +659,8 @@ TEST(TestLayoutTransform, DetectionHead) {
     /// check first conv format
     const auto& first_conv = find_opr<opr::ConvBiasForward>(v);
     const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
-    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW4_NHWC);
+    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NHWC);
+    ASSERT_EQ(cast.output()[0]->dtype().enumv(), DTypeEnum::Quantized4Asymm);
 }
 #endif
 #endif