GitOrigin-RevId: 60942aca5b
tags/v1.0.0-rc1
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
 * implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/internal/opr_header_prologue.h" | |||
@@ -314,8 +315,10 @@ public: | |||
/** | |||
* \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic) | |||
* \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw, | |||
* 4*ic) \param[in] bias (1, oc, 1, 1) \param[in] z same as dst \param[out] | |||
* dst (n, oc, oh, ow) or (n, oh, ow, oc) | |||
* 4 * ic) | |||
* \param[in] bias (1, oc, 1, 1) | |||
* \param[in] z same as dst | |||
* \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc) | |||
* | |||
* \note if the format is NCHW_WINOGRAD, the filter layout is (alphah, | |||
* alphaw, oc, ic) | |||
@@ -407,6 +410,26 @@ public: | |||
*/ | |||
static WinogradParam parse_winograd_name(const std::string& algo_name); | |||
/** | |||
 * @brief find if there is a nchw_nchwxx conv kernel optimized for the argument, | |||
 * nchw44 used for arm, nchw88 used for x86 | |||
* | |||
* @param src_dtype conv feature map data type | |||
* @param filter_dtype conv filter or weight data type | |||
* @param dst_dtype output data type | |||
* @param fm filter meta param | |||
* @param bias_mode bias mode, no_bias or broadcast or bias | |||
* @param nonline_mode identity or relu or h_swish or sigmoid | |||
 * @return true, found a kernel | |||
 * @return false, can't find any kernel | |||
*/ | |||
static bool is_nchw_nchwxx_optimized( | |||
const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
const DTypeEnum dst_dtype, | |||
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
const ConvBiasForward::BiasMode bias_mode, | |||
const param::ConvBias::NonlineMode nonline_mode); | |||
protected: | |||
CanonizedFilterMeta check_exec( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
@@ -16,10 +16,10 @@ | |||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/nchw_nchwxx_valid.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using conv_fun = std::function<void( | |||
@@ -191,22 +191,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||
bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
auto&& fm = param.filter_meta; | |||
auto fh = fm.spatial[0]; | |||
int oc = fm.ocpg; | |||
bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
(param.dst_type.enumv() == DTypeEnum::Float32))) && | |||
(fm.format == param::Convolution::Format::NCHW44); | |||
bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1; | |||
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||
(fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[0] == 2); | |||
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; | |||
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; | |||
return avaible; | |||
return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>( | |||
param.src_type.enumv(), param.filter_type.enumv(), | |||
param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
param.nonlineMode); | |||
} | |||
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( | |||
@@ -15,6 +15,7 @@ | |||
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/int8/strategy.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/nchw_nchwxx_valid.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "midout.h" | |||
@@ -214,26 +215,12 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||
ow, op); | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const { | |||
MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto OC = fm.ocpg; | |||
bool avaible = //! src and filter are qint8, dst is qint8 | |||
fm.icpg < 4 && // must be nchw input | |||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | |||
(fm.format == param::Convolution::Format::NCHW44) && | |||
(OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 && | |||
param.bias_mode != BiasMode::BIAS; | |||
return avaible; | |||
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>( | |||
param.src_type.enumv(), param.filter_type.enumv(), | |||
param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
param.nonlineMode); | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred( | |||
@@ -16,6 +16,7 @@ | |||
#include "src/arm_common/conv_bias/int8/algos.h" | |||
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/nchw_nchwxx_valid.h" | |||
#include "midout.h" | |||
@@ -174,23 +175,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||
bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
auto&& fm = param.filter_meta; | |||
auto fh = fm.spatial[0]; | |||
int oc = fm.ocpg; | |||
int ic = fm.icpg; | |||
bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | |||
(fm.format == param::Convolution::Format::NCHW44_DOT); | |||
bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4); | |||
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||
(fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[0] == 2); | |||
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; | |||
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; | |||
return avaible; | |||
return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>( | |||
param.src_type.enumv(), param.filter_type.enumv(), | |||
param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
param.nonlineMode); | |||
} | |||
size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace( | |||
@@ -16,6 +16,7 @@ | |||
#include "src/arm_common/conv_bias/int8x8x16/algos.h" | |||
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | |||
#include "src/arm_common/elemwise_op.h" | |||
#include "src/common/nchw_nchwxx_valid.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "midout.h" | |||
@@ -220,23 +221,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||
bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
auto&& fm = param.filter_meta; | |||
auto fh = fm.spatial[0]; | |||
int oc = fm.ocpg; | |||
bool ok_type = ((param.src_type.enumv() == DTypeEnum::Int8 && | |||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||
(param.dst_type.enumv() == DTypeEnum::Int16))) && | |||
(fm.format == param::Convolution::Format::NCHW44); | |||
bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1; | |||
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||
(fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 2 || fm.stride[0] == 1); | |||
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS && | |||
param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY; | |||
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; | |||
return avaible; | |||
return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>( | |||
param.src_type.enumv(), param.filter_type.enumv(), | |||
param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
param.nonlineMode); | |||
} | |||
size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace( | |||
@@ -0,0 +1,43 @@ | |||
/** | |||
* \file dnn/src/common/nchw_nchwxx_valid.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/oprs/nn.h" | |||
#include "src/common/nchw_nchwxx_valid.h" | |||
using namespace megdnn; | |||
namespace {
//! Common signature shared by every nchw_nchwxx_valid specialization so
//! they can be stored in one homogeneous table.
using NchwNchwxxFuncInterface = std::function<bool(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const ConvBiasForward::BiasMode bias_mode,
        const param::ConvBias::NonlineMode nonline_mode)>;
//! All known NCHW -> NCHWxx validity checkers; probed in order by
//! ConvBiasForward::is_nchw_nchwxx_optimized until one accepts.
static SmallVector<NchwNchwxxFuncInterface> g_func_vec{
        nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>,
        nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>,
        nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>,
        nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>,
        nchw_nchwxx_valid<NchwNchwxxType::NCHW88>,
};
}  // namespace
bool ConvBiasForward::is_nchw_nchwxx_optimized( | |||
const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
const DTypeEnum dst_dtype, | |||
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
const ConvBiasForward::BiasMode bias_mode, | |||
const param::ConvBias::NonlineMode nonline_mode) { | |||
for (auto& func : g_func_vec) { | |||
if (func(src_dtype, filter_dtype, dst_dtype, fm, bias_mode, | |||
nonline_mode)) { | |||
return true; | |||
} | |||
} | |||
return false; | |||
} |
@@ -0,0 +1,161 @@ | |||
/** | |||
* \file dnn/src/common/nchw_nchwxx_valid.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace { | |||
//! Tag values selecting a nchw_nchwxx_valid specialization; each names the
//! (dtype combination, packed format) pair of one direct conv kernel.
//! NOTE(review): kept as a plain enum — callers pass both scoped
//! (NchwNchwxxType::NCHW44_FP32) and unscoped (NCHW44_FP32) spellings.
enum NchwNchwxxType {
    NCHW44_FP32,
    NCHW44_INT8,
    NCHW44_INT8_INT8_INT16,
    NCHW44_INT8_DOT,
    NCHW88,
};
//! Check whether an optimized NCHW -> NCHWxx direct conv kernel exists for
//! the given dtypes, canonized filter meta, bias mode and activation.
//! Only the explicit specializations below are defined; the primary
//! template is intentionally left without a definition.
template <NchwNchwxxType T>
static inline bool nchw_nchwxx_valid(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const BiasMode bias_mode,
        const param::ConvBias::NonlineMode nonline_mode);
template <> | |||
inline bool nchw_nchwxx_valid<NCHW44_FP32>( | |||
const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
const DTypeEnum dst_dtype, | |||
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
const BiasMode bias_mode, | |||
const param::ConvBias::NonlineMode nonline_mode) { | |||
bool ok_type = ((src_dtype == DTypeEnum::Float32 && | |||
filter_dtype == DTypeEnum::Float32 && | |||
(dst_dtype == DTypeEnum::Float32))) && | |||
(fm.format == param::Convolution::Format::NCHW44); | |||
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY || | |||
nonline_mode == param::ConvBias::NonlineMode::RELU || | |||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | |||
bool ok_src_dst = | |||
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | |||
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||
fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[1] == 2); | |||
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
bool avaible = ok_type && ok_nonline && ok_src_dst && ok_filter && | |||
ok_slide && ok_conv; | |||
return avaible; | |||
} | |||
template <> | |||
inline bool nchw_nchwxx_valid<NCHW44_INT8>( | |||
const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
const DTypeEnum dst_dtype, | |||
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
const BiasMode bias_mode, | |||
const param::ConvBias::NonlineMode nonline_mode) { | |||
bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 && | |||
filter_dtype == DTypeEnum::QuantizedS8 && | |||
(dst_dtype == DTypeEnum::QuantizedS8))) && | |||
(fm.format == param::Convolution::Format::NCHW44); | |||
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY || | |||
nonline_mode == param::ConvBias::NonlineMode::RELU || | |||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | |||
bool ok_src_dst = | |||
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | |||
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||
fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[1] == 2); | |||
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
bool avaible = ok_type && ok_nonline && ok_src_dst && ok_filter && | |||
ok_slide && ok_conv; | |||
return avaible; | |||
} | |||
template <> | |||
inline bool nchw_nchwxx_valid<NCHW44_INT8_INT8_INT16>( | |||
const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
const DTypeEnum dst_dtype, | |||
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
const BiasMode bias_mode, | |||
const param::ConvBias::NonlineMode nonline_mode) { | |||
bool ok_type = | |||
((src_dtype == DTypeEnum::Int8 && filter_dtype == DTypeEnum::Int8 && | |||
(dst_dtype == DTypeEnum::Int16))) && | |||
(fm.format == param::Convolution::Format::NCHW44); | |||
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY; | |||
bool ok_src_dst = | |||
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | |||
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||
fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 2 || fm.stride[0] == 1); | |||
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
bool avaible = ok_type && ok_nonline && ok_src_dst && ok_filter && | |||
ok_slide && ok_conv; | |||
return avaible; | |||
} | |||
template <> | |||
inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>( | |||
const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
const DTypeEnum dst_dtype, | |||
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
const BiasMode bias_mode, | |||
const param::ConvBias::NonlineMode nonline_mode) { | |||
bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 && | |||
filter_dtype == DTypeEnum::QuantizedS8 && | |||
(dst_dtype == DTypeEnum::QuantizedS8))) && | |||
(fm.format == param::Convolution::Format::NCHW44_DOT); | |||
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY || | |||
nonline_mode == param::ConvBias::NonlineMode::RELU || | |||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | |||
bool ok_src_dst = | |||
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | |||
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||
fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[1] == 2); | |||
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
bool avaible = ok_type && ok_nonline && ok_src_dst && ok_filter && | |||
ok_slide && ok_conv; | |||
return avaible; | |||
} | |||
template <> | |||
inline bool nchw_nchwxx_valid<NCHW88>( | |||
const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
const DTypeEnum dst_dtype, | |||
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
const BiasMode bias_mode, | |||
const param::ConvBias::NonlineMode nonline_mode) { | |||
bool ok_type = ((src_dtype == DTypeEnum::Float32 && | |||
filter_dtype == DTypeEnum::Float32 && | |||
(dst_dtype == DTypeEnum::Float32))) && | |||
(fm.format == param::Convolution::Format::NCHW88); | |||
bool ok_src_dst = | |||
fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1; | |||
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1; | |||
bool avaible = ok_type && ok_src_dst && ok_slide && ok_conv; | |||
return avaible; | |||
} | |||
} // namespace | |||
} // namespace megdnn |
@@ -11,6 +11,7 @@ | |||
*/ | |||
#pragma once | |||
#include "src/common/nchw_nchwxx_valid.h" | |||
#include "src/x86/conv_bias/opr_impl.h" | |||
using namespace megdnn; | |||
@@ -29,6 +30,7 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase { | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
@@ -61,6 +63,7 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase { | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, | |||
const CpuNDRange& workspace_ids); | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { | |||
@@ -163,13 +166,19 @@ public: | |||
AlgoSelectionStrategy) const override { | |||
auto&& fm = param.filter_meta; | |||
bool ok = (fm.format == param::ConvBias::Format::NCHW88) && | |||
fm.spatial_ndim == 2 && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1; | |||
return ok; | |||
bool nchw_nchw88_ok = nchw_nchwxx_valid<NchwNchwxxType::NCHW88>( | |||
param.src_type.enumv(), param.filter_type.enumv(), | |||
param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
param.nonlineMode); | |||
bool normal_conv_ok = (fm.format == param::ConvBias::Format::NCHW88) && | |||
fm.spatial_ndim == 2 && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
fm.dilation[0] == 1 && fm.dilation[1] == 1; | |||
return nchw_nchw88_ok || normal_conv_ok; | |||
}; | |||
size_t get_workspace(const NCBKernSizeParam&) const override { return 0; } | |||
@@ -1816,155 +1816,67 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, | |||
} | |||
template <typename OprType> | |||
static inline bool nchw_nchwxx_valid(const OprType& opr, | |||
const VarNodeArray& new_inp, | |||
const size_t pack_size, bool is_dense, | |||
bool is_dot = false); | |||
template <> | |||
inline bool nchw_nchwxx_valid<opr::ConvolutionForward>( | |||
const opr::ConvolutionForward& opr, const VarNodeArray& new_inp, | |||
const size_t pack_size, bool is_dense, bool is_dot) { | |||
auto& filter_shape = new_inp[1]->shape(); | |||
auto filter_dtype = new_inp[1]->dtype(); | |||
bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
filter_dtype.enumv() == DTypeEnum::Int8; | |||
const size_t oc = filter_shape[0]; | |||
const size_t ic = filter_shape[1]; | |||
bool is_like_nchw_nchwxx = | |||
is_dense && oc % pack_size == 0 && ic < pack_size; | |||
if (!is_like_nchw_nchwxx) { | |||
static inline bool nchw_nchwxx_valid( | |||
const OprType& opr, const VarNodeArray& new_inp, const size_t pack_size, | |||
megdnn::param::ConvBias::NonlineMode nonline_mode = | |||
megdnn::param::ConvBias::NonlineMode::IDENTITY, | |||
bool is_dot = false) { | |||
auto& src_node = new_inp[0]; | |||
auto& filter_node = new_inp[1]; | |||
auto dst_node = opr.output(0); | |||
if (filter_node->shape().ndim != 4) { | |||
return false; | |||
} | |||
SmallVector<TensorLayout> layouts; | |||
//! src | |||
layouts.push_back( | |||
{new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()}); | |||
//! weight | |||
layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2], | |||
filter_shape[3], filter_shape[1], pack_size}, | |||
new_inp[1]->dtype(), | |||
new_inp[1]->format()}); | |||
auto out0 = opr.output(0); | |||
auto& out_shape = out0->shape(); | |||
//! FIXME: return false if oc is invalid | |||
layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2], | |||
out_shape[3], pack_size}, | |||
out0->dtype(), | |||
out0->format()}); | |||
auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node()) | |||
->create_operator<megdnn::ConvolutionForward>(); | |||
megdnn_conv.get()->param() = opr.param(); | |||
//! set by dtype | |||
switch (pack_size) { | |||
case 4: | |||
if (is_dot && is_int8) { | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW44_DOT; | |||
} else { | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW44; | |||
} | |||
break; | |||
case 8: | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW88; | |||
break; | |||
default: | |||
break; | |||
} | |||
bool find_valid_algo = false; | |||
auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1], | |||
layouts[2]); | |||
for (auto i : algos) { | |||
if (i->type() != nullptr) { | |||
find_valid_algo = true; | |||
megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm; | |||
fm.format = megdnn::param::Convolution::Format::NCHW; | |||
fm.should_flip = | |||
opr.param().mode == megdnn::ConvBiasForward::Mode::CONVOLUTION; | |||
fm.group = 1; | |||
fm.spatial_ndim = 2; | |||
fm.ocpg = filter_node->shape()[0]; | |||
fm.icpg = filter_node->shape()[1]; | |||
fm.spatial[0] = filter_node->shape()[2]; | |||
fm.spatial[1] = filter_node->shape()[3]; | |||
fm.stride[0] = opr.param().stride_h; | |||
fm.stride[1] = opr.param().stride_w; | |||
fm.padding[0] = opr.param().pad_h; | |||
fm.padding[1] = opr.param().pad_w; | |||
fm.dilation[0] = opr.param().dilate_h; | |||
fm.dilation[1] = opr.param().dilate_w; | |||
megdnn::ConvBiasForward::BiasMode bias_mode = | |||
megdnn::ConvBiasForward::BiasMode::NO_BIAS; | |||
if (std::is_same<OprType, opr::ConvBiasForward>::value) { | |||
auto& bias_shape = new_inp[2]->shape(); | |||
if (bias_shape.ndim == 0) { | |||
bias_mode = megdnn::ConvBiasForward::BiasMode::NO_BIAS; | |||
} else if (bias_shape.eq_shape(dst_node->shape())) { | |||
bias_mode = megdnn::ConvBiasForward::BiasMode::BIAS; | |||
} else { | |||
//! just check the ndim, the detail shape check is in check_exec | |||
mgb_assert(bias_shape.ndim == dst_node->shape().ndim); | |||
bias_mode = | |||
megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS; | |||
} | |||
} | |||
return find_valid_algo; | |||
} | |||
template <> | |||
inline bool nchw_nchwxx_valid<opr::ConvBiasForward>( | |||
const opr::ConvBiasForward& opr, const VarNodeArray& new_inp, | |||
const size_t pack_size, bool is_dense, bool is_dot) { | |||
auto& filter_shape = new_inp[1]->shape(); | |||
auto filter_dtype = new_inp[1]->dtype(); | |||
bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
filter_dtype.enumv() == DTypeEnum::Int8; | |||
const size_t oc = filter_shape[0]; | |||
const size_t ic = filter_shape[1]; | |||
bool is_like_nchw_nchwxx = | |||
is_dense && oc % pack_size == 0 && ic < pack_size; | |||
if (!is_like_nchw_nchwxx) { | |||
return false; | |||
} | |||
SmallVector<TensorLayout> layouts; | |||
//! src | |||
layouts.push_back( | |||
{new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()}); | |||
//! weight | |||
layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2], | |||
filter_shape[3], filter_shape[1], pack_size}, | |||
new_inp[1]->dtype(), | |||
new_inp[1]->format()}); | |||
auto& bias_shape = new_inp[2]->shape(); | |||
layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size, bias_shape[2], | |||
bias_shape[3], pack_size}, | |||
new_inp[2]->dtype(), | |||
new_inp[2]->format()}); | |||
auto out0 = opr.output(0); | |||
auto& out_shape = out0->shape(); | |||
//! FIXME: return false if oc is invalid | |||
layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2], | |||
out_shape[3], pack_size}, | |||
out0->dtype(), | |||
out0->format()}); | |||
// megdnn::ConvolutionForward | |||
auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node()) | |||
->create_operator<megdnn::ConvBiasForward>(); | |||
megdnn_conv.get()->param() = opr.param(); | |||
//! FIXME: set by dtype | |||
switch (pack_size) { | |||
case 4: | |||
if (is_dot && is_int8) { | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW44_DOT; | |||
} else { | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW44; | |||
} | |||
break; | |||
case 8: | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW88; | |||
break; | |||
default: | |||
break; | |||
} | |||
bool find_valid_algo = false; | |||
auto algos = megdnn_conv.get()->get_all_algorithms( | |||
layouts[0], layouts[1], layouts[2], {}, layouts[3]); | |||
for (auto i : algos) { | |||
if (i->type() != nullptr) { | |||
find_valid_algo = true; | |||
if (pack_size == 4) { | |||
if (is_dot && filter_node->dtype().enumv() == DTypeEnum::QuantizedS8) { | |||
fm.format = megdnn::param::Convolution::Format::NCHW44_DOT; | |||
} else { | |||
fm.format = megdnn::param::Convolution::Format::NCHW44; | |||
} | |||
} else if (pack_size == 8) { | |||
fm.format = megdnn::param::Convolution::Format::NCHW88; | |||
} else { | |||
mgb_assert(0, "only support nchw44 nchw88"); | |||
} | |||
return find_valid_algo; | |||
return megdnn::ConvBiasForward::is_nchw_nchwxx_optimized( | |||
src_node->dtype().enumv(), filter_node->dtype().enumv(), | |||
dst_node->dtype().enumv(), fm, bias_mode, nonline_mode); | |||
} | |||
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
using RelayoutMode = RelayoutPlaceholder::LayoutType; | |||
using TestFilterResult = std::pair<TransType, RelayoutMode>; | |||
@@ -1984,19 +1896,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
megdnn::param::Pooling::Format pooling_format = | |||
megdnn::param::Pooling::Format::NCHW88; | |||
std::string convter_pass_name = "conv_format_nchw88"; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMv7 | |||
if (pack_c_size == 8) { | |||
mgb_log_error( | |||
"runtime backend is ARM, but nchw88 only support X86, you may " | |||
"have performance loss\n"); | |||
} | |||
#elif MEGDNN_X86 | |||
if (pack_c_size == 4) { | |||
mgb_log_error( | |||
"runtime backend is X86, but nchw44 only support arm, you may " | |||
"have performance loss\n"); | |||
} | |||
#endif | |||
if (pack_c_size == 4) { | |||
weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; | |||
@@ -2053,10 +1952,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
mgb_assert(conv_opr.param().format == | |||
megdnn::param::Convolution::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
bool is_dense = conv_opr.param().sparse == | |||
megdnn::param::Convolution::Sparse::DENSE; | |||
bool valid_nchw_nchw44 = | |||
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense); | |||
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size); | |||
auto is_trans = test_trans_nchwxx( | |||
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, | |||
conv_opr.param().stride_w, valid_nchw_nchw44); | |||
@@ -2133,10 +2030,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
mgb_assert(conv_bias_opr.param().format == | |||
megdnn::param::ConvBias::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
bool is_dense = conv_bias_opr.param().sparse == | |||
megdnn::param::Convolution::Sparse::DENSE; | |||
bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp, | |||
pack_c_size, is_dense); | |||
bool valid_nchw_nchw44 = | |||
nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, | |||
conv_bias_opr.param().nonlineMode); | |||
auto is_trans = test_trans_nchwxx( | |||
conv_bias_opr.param().sparse, new_inp[1], | |||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w, | |||
@@ -2371,13 +2267,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
MIDOUT_B("EnableNchw44DotPass::make") | |||
auto ret = std::make_unique<EnableNchw44DotPass>(); | |||
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); | |||
//! First is whether the conv can trans to nchwxx, second is the filter | |||
//! trans mode | |||
#if MEGDNN_X86 | |||
mgb_log_error( | |||
"backend is X86, but nchw44_dot only support arm, you may have " | |||
"performance loss\n"); | |||
#endif | |||
//! First is whether the conv can trans to nchwxx, second is the filter | |||
//! trans mode | |||
using RelayoutMode = RelayoutPlaceholder::LayoutType; | |||
struct TestTransResult { | |||
@@ -2453,14 +2344,12 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
megdnn::param::Convolution::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to " | |||
"NCHW44_DOT"); | |||
bool is_dense = conv_opr.param().sparse == | |||
megdnn::param::Convolution::Sparse::DENSE; | |||
bool valid_nchw_nchw44 = | |||
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense); | |||
bool valid_nchw_nchw44 = nchw_nchwxx_valid( | |||
conv_opr, new_inp, pack_c_size, | |||
megdnn::param::ConvBias::NonlineMode::IDENTITY, true); | |||
auto is_trans = test_trans_nchw44_dot( | |||
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, | |||
conv_opr.param().stride_w, valid_nchw_nchw44); | |||
//! can not trans to nchwxx | |||
if (is_trans.trans_type == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -2533,10 +2422,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
mgb_assert(conv_bias_opr.param().format == | |||
megdnn::param::ConvBias::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
bool is_dense = conv_bias_opr.param().sparse == | |||
megdnn::param::Convolution::Sparse::DENSE; | |||
bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp, | |||
pack_c_size, is_dense); | |||
bool valid_nchw_nchw44 = | |||
nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, | |||
conv_bias_opr.param().nonlineMode, true); | |||
auto is_trans = test_trans_nchw44_dot( | |||
conv_bias_opr.param().sparse, new_inp[1], | |||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w, | |||
@@ -2913,7 +2913,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) { | |||
opr::Convolution::Param param_conv; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
auto w1 = mkcvar("w1", {8, 3, 3, 3}), | |||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||
conv1 = opr::Convolution::make(x, w1, param_conv, {}, | |||
OperatorNodeConfig("conv1")); | |||
//! channel wise | |||
opr::ConvBias::Param param_conv_bias; | |||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||
@@ -2954,7 +2955,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) { | |||
options.enable_nchw88(); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
} | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88, | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88, | |||
find_opr<opr::ConvBias>(y_opt).param().format); | |||
@@ -3084,13 +3086,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||
options.enable_nchw44(); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
#else | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
#endif | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
@@ -3325,17 +3322,10 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
options.enable_nchw44_dot(); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, | |||
find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format); | |||
#else | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format); | |||
#endif | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, | |||
@@ -611,11 +611,11 @@ public: | |||
"%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s " | |||
"workspace=%.2fMiB reproducible=%d", | |||
mgb_opr->dyn_typeinfo()->name, | |||
layouts[0].TensorShape::to_string().c_str(), | |||
layouts[0].to_string().c_str(), | |||
layouts[0].dtype.name(), | |||
layouts[1].TensorShape::to_string().c_str(), | |||
layouts[1].to_string().c_str(), | |||
layouts[1].dtype.name(), | |||
layouts[layouts.size() - 1].TensorShape::to_string().c_str(), | |||
layouts[layouts.size() - 1].to_string().c_str(), | |||
layouts[layouts.size() - 1].dtype.name(), | |||
algo->name(), | |||
workspace / (1024 * 1024.0), algo->is_reproducible()); | |||