GitOrigin-RevId: adc2301203
release-1.7
@@ -121,7 +121,10 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
    auto&& param = args.opr->param();
    bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
    bool is_version_ok = CUDNN_VERSION >= 7500;
    bool is_dtype_ok = args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
    bool is_dtype_ok =
            (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
    bool is_bias_ok =
            args.bias_layout->ndim == 0 ||
            (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 &&
@@ -31,6 +31,11 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
        }
    }
    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
        return false;
    }
    // FIXME: cudnn cannot handle the case when the initial value of dst tensor
    // contains nan and beta is zero, because the result of 0.f * nan is still
    // nan
@@ -24,6 +24,11 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(
    if (!is_compute_capability_required(6, 1))
        return false;
    if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm ||
        args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
        return false;
    }
    auto dst_layout = *args.dst_layout;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
@@ -0,0 +1,151 @@
/**
 * \file src/gopt/impl/folding_conv_typecvt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
 * implied.
 */

#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/utils/hash_ct.h"
#include "midout.h"
#include "megbrain/gopt/reformat_manager.h"

#if CUDA_VERSION >= 10020

MIDOUT_DECL(megbrain_folding_conv_typecvt)
#define MIDOUT_B(tag) \
    MIDOUT_BEGIN(megbrain_folding_conv_typecvt, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
    }            \
    MIDOUT_END();

using namespace mgb;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;

/* ==================== FoldingConvBiasTypecvtPass ================= */
const char* FoldingConvBiasTypecvtPass::name() const {
    return mgb_cstr_log("folding conv bias typecvt pass");
}

void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
    MIDOUT_B("FoldingConvBiasTypecvtPass::apply");
    using DepType = cg::OperatorNodeProp::DepType;
    ThinHashMap<OperatorNodeBase*,
                SmallVector<std::pair<OperatorNodeBase*, DepType>>>
            readers;
    static const ThinHashSet<Typeinfo*> opr_type_list = {
            opr::TypeCvt::typeinfo(), opr::ConvBias::typeinfo()};
    opt.graph().iter([&readers](OperatorNodeBase* opr) {
        for (auto&& i : opr->node_prop().dep_map()) {
            if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
                readers[i.first->owner_opr()].emplace_back(opr, i.second);
            }
        }
    });
    auto rewriter = opt.graph().make_rewriter();
    auto try_conv_typecvt = [&rewriter, &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check typecvt
        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
        if (typecvt == nullptr)
            return false;
        auto inp_dtype_typecvt = typecvt->input(0)->dtype(),
             out_dtype_typecvt = typecvt->output(0)->dtype();
        bool is_s82f32 = inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
                         out_dtype_typecvt.enumv() == DTypeEnum::Float32;
        bool is_s82s4 =
                inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
                (out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
                 out_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm);
        bool is_s42s8 =
                (inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
                 inp_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm) &&
                out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8;
        if (!(is_s82f32 || is_s82s4 || is_s42s8))
            return false;
        opr_set.insert(opr);
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype_conv = conv_bias->input(0)->dtype(),
             out_dtype_conv = conv_bias->output(0)->dtype();
        bool is_s8nhwc = inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
                         out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
                         conv_bias->param().format ==
                                 megdnn::param::ConvBias::Format::NHWC;
        bool is_s4nhwc =
                (inp_dtype_conv.enumv() == DTypeEnum::QuantizedS4 ||
                 inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
                out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
                conv_bias->param().format ==
                        megdnn::param::ConvBias::Format::NHWC;
        if (!(is_s8nhwc || is_s4nhwc))
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias =
                (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
                        ? opr::TypeCvt::make(bias, dtype::Float32()).node()
                        : bias;
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NHWC;
        auto conv_bias_typecvt = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                OperatorNodeConfig{out_dtype_typecvt});
        rewriter.replace_var(opr->output(0), conv_bias_typecvt.node(),
                             mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
                                          "to conv_bias(NHWC)"));
        return true;
    };
    auto on_opr = [&try_conv_typecvt, &rewriter](OperatorNodeBase* opr) {
        if (!try_conv_typecvt(opr)) {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -835,6 +835,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
        add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
        add_pass<FoldingConvBiasTypecvtPass>();
#endif
    });
#undef cb
@@ -57,7 +57,10 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
            TensorFormats::NCHW, TensorFormats::NHWC,
            TensorFormats::NCHWc4, TensorFormats::NCHWc32,
            TensorFormats::NCHWc64, TensorFormats::CHWNc4};
    Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
    Attribute attribute = {
            base_opr_format, base_tensor_format, Target::CUDA,
            LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats),
            attribute);
@@ -67,8 +70,9 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
                            OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
            .add_opr_config(opr::ConvolutionForward::typeinfo(),
                            {OprFormat::NCHW, OprFormat::NCHW4})
            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
                            {OprFormat::NCHW, OprFormat::NCHW4})
            .add_opr_config(
                    opr::ConvolutionBackwardData::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -512,7 +512,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData,
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW4;
        config.opr_format = OprFormat::NHWC;
        bool available = true;
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &=
@@ -481,6 +481,12 @@ namespace gopt {
        const char* name() const override;
        void apply(OptState& opt) const override;
    };

    class FoldingConvBiasTypecvtPass final : public Pass {
    public:
        const char* name() const override;
        void apply(OptState& opt) const override;
    };
#endif

    /*!
@@ -585,6 +585,7 @@ TEST(TestLayoutTransform, DetectionHead) {
    using OprFormat = LayoutTransformContext::OprFormat;
    using OprList = LayoutTransformContext::OprList;
    using Attribute = LayoutTransformContext::Attribute;
    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
    using Target = LayoutTransformContext::Target;
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(),
@@ -600,8 +601,8 @@ TEST(TestLayoutTransform, DetectionHead) {
            TensorFormats::NCHW, TensorFormats::NHWC,
            TensorFormats::NCHWc4, TensorFormats::NCHWc32,
            TensorFormats::NCHWc64, TensorFormats::CHWNc4};
    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
                           Target::UNSPEC};
    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
                           ReformatAttribute::AUTO_PADDING_NHWC};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats),
            attribute);
@@ -611,8 +612,9 @@ TEST(TestLayoutTransform, DetectionHead) {
                            OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
            .add_opr_config(opr::ConvolutionForward::typeinfo(),
                            {OprFormat::NCHW, OprFormat::NCHW4})
            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
                            {OprFormat::NCHW, OprFormat::NCHW4})
            .add_opr_config(
                    opr::ConvolutionBackwardData::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -630,6 +632,7 @@ TEST(TestLayoutTransform, DetectionHead) {
            .add_pass<ShuffleShuffleRemovePass>()
            .add_pass(FuseNCHW4Int8Preprocess::make())
            .add_pass<FoldingConvBiasDimshufflePass>()
            .add_pass<FoldingConvBiasTypecvtPass>()
            .add_pass<ParamFusePass>()
            .add_pass<ParamMergePass>()
            .apply(SymbolVarArray{y})
@@ -656,7 +659,8 @@ TEST(TestLayoutTransform, DetectionHead) {
    /// check first conv format
    const auto& first_conv = find_opr<opr::ConvBiasForward>(v);
    const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW4_NHWC);
    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NHWC);
    ASSERT_EQ(cast.output()[0]->dtype().enumv(), DTypeEnum::Quantized4Asymm);
}
#endif
#endif | |||