|
- /**
- * \file dnn/src/arm_common/conv_bias/opr_impl.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
-
- #include "src/arm_common/conv_bias/int8/algos.h"
- #include "src/arm_common/conv_bias/int8x8x16/algos.h"
- #include "src/arm_common/conv_bias/quint8/algos.h"
-
- #include "src/arm_common/conv_bias/opr_impl.h"
- #include "src/common/metahelper.h"
- #include "src/common/utils.h"
- #include "src/naive/handle.h"
-
- #include "src/arm_common/convolution/opr_impl.h"
- #include "src/arm_common/matrix_mul/opr_impl.h"
- #include "src/common/opr_delegate.h"
-
- #include "include/megdnn/oprs/nn.h"
- #include "src/arm_common/conv_bias/f16/algos.h"
- #include "src/arm_common/conv_bias/fp32/algos.h"
- #include "src/arm_common/conv_bias/int8/stride1.h"
- #include "src/arm_common/conv_bias/int8/stride2.h"
- #include "src/arm_common/conv_bias/quint8/stride1.h"
- #include "src/arm_common/conv_bias/quint8/stride2.h"
- #include "src/arm_common/convolution/opr_impl.h"
-
- using namespace megdnn;
- using namespace arm_common;
-
namespace {
//! Never read or written: only its address is taken (see
//! ConvBiasImpl::sm_arm_common_algo_type), serving as a unique tag that
//! identifies algorithms belonging to the arm_common backend.
uint8_t arm_common_algo_type_storage;
}  // anonymous namespace
-
- class ConvBiasImpl::AlgoPack : NonCopyableObj {
- AlgoQU8DirectStride2 qu8_direct_stride2;
- AlgoQU8DirectStride1 qu8_direct_stride1;
- AlgoS8DirectStride2 s8_direct_stride2;
- AlgoS8DirectNCHW44 s8_direct_nchw44;
- AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
- AlgoS8DirectStride1 s8_direct_stride1;
- AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
- AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
- AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44;
-
- #if __ARM_FEATURE_DOTPROD
- AlgoDotS8DirectStride1 ds8_direct_stride1;
- AlgoDotS8DirectStride2 ds8_direct_stride2;
- AlgoDotU8DirectStride1 du8_direct_stride1;
- AlgoDotU8DirectStride2 du8_direct_stride2;
-
- AlgoDotS8Direct_NCHW44 ds8_direct_nchw44;
- AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44;
- #endif
-
- AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
- AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
- AlgoF32DirectNCHW44 f32_direct_nchw44;
-
- AlgoF32Direct f32_direct;
- AlgoF32DirectStride2 f32_direct_stride2;
- AlgoF32DirectStride1 f32_direct_stride1;
-
- AlgoI8x8x16Direct i8x8x16_direct;
- AlgoI8x8x16Stride2 i8x8x16_stride2;
- AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2;
- AlgoI8x8x16DirectNCHWNCHW44 i8x8x16_nchw_nchw44;
- #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- AlgoF16Direct f16_direct;
- AlgoF16DirectStride1 f16_direct_stride1;
- #endif
-
- SmallVector<std::unique_ptr<AlgoBase>> refhold;
-
- public:
- AlgoPack() {
- #if __ARM_FEATURE_DOTPROD
- direct_algos.emplace_back(&ds8_direct_stride1);
- direct_algos.emplace_back(&ds8_direct_stride2);
- direct_algos.emplace_back(&du8_direct_stride1);
- direct_algos.emplace_back(&du8_direct_stride2);
-
- direct_algos.emplace_back(&ds8_direct_nchw44);
- direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
- #endif
- direct_algos.emplace_back(&qu8_direct_stride2);
- direct_algos.emplace_back(&qu8_direct_stride1);
- direct_algos.emplace_back(&s8_direct_stride2);
- direct_algos.emplace_back(&s8_direct_nchw44);
- direct_algos.emplace_back(&s8_direct_nchw_nchw44);
- direct_algos.emplace_back(&s8_direct_stride1);
-
- direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44);
- direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
- direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
-
- #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- direct_algos.emplace_back(&f16_direct_stride1);
- direct_algos.emplace_back(&f16_direct);
- #endif
- direct_algos.emplace_back(&i8x8x16_direct);
- direct_algos.emplace_back(&i8x8x16_stride2_filter2);
- direct_algos.emplace_back(&i8x8x16_stride2);
- direct_algos.emplace_back(&i8x8x16_nchw_nchw44);
-
- direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
- direct_algos.emplace_back(&f32_chanel_wise_nchw44);
- direct_algos.emplace_back(&f32_direct_nchw44);
-
- direct_algos.emplace_back(&f32_direct_stride1);
- direct_algos.emplace_back(&f32_direct_stride2);
- direct_algos.emplace_back(&f32_direct);
-
- static CpuOprDelegationStorage<2> storage;
- auto matmul_opr = storage.get<MatrixMul, 0>();
- auto&& matmul_algos =
- static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
- ->algo_pack();
- for (auto&& algo : matmul_algos) {
- if (algo->type() == nullptr)
- continue;
- for (uint32_t tile_size : {16, 8, 24, 32}) {
- refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP32WinogradF63(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP32WinogradF54(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP32WinogradF45(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- //! uncomment this when low precision mode is done
- #if 0
- refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- #endif
- #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- refhold.emplace_back(new AlgoFP16WinogradF23(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP16WinogradF45(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP16WinogradF63(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- #endif
- refhold.emplace_back(new AlgoS8WinogradF23_8x8(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44(
- static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
- tile_size));
- winograd_algos.emplace_back(refhold.back().get());
- }
- }
- }
- SmallVector<AlgoBase*> direct_algos;
- SmallVector<AlgoBase*> winograd_algos;
- };
-
- SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
- static AlgoPack sl_algo_pack;
- auto&& algos = fallback::ConvBiasImpl::algo_pack();
- algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
- sl_algo_pack.direct_algos.end());
- algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(),
- sl_algo_pack.winograd_algos.end());
- return std::move(algos);
- }
-
//! Backend tag for arm_common algorithms: only the pointer identity matters,
//! so it simply points at a file-local byte of storage.
void* const ConvBiasImpl::sm_arm_common_algo_type =
        &arm_common_algo_type_storage;
-
- bool ConvBiasImpl::is_matmul_quantized_prefer(
- const ConvBiasImpl::NCBKernSizeParam& param) const {
- fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param(
- param, 0, param::MatrixMul::Format::DEFAULT, {}, 0,
- BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY);
- conv_ncb_param.dst_type = param.bias_type;
- conv_ncb_param.filter_meta.group = 1;
-
- bool conv_direct_unusable = false;
- if (param.dst_type.enumv() == DTypeEnum::QuantizedS8 ||
- param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
- conv_direct_unusable =
- !arm_common::direct_int8_stride1::can_conv_direct_stride1_int8(
- conv_ncb_param) &&
- !arm_common::direct_int8_stride2::can_conv_direct_stride2_int8(
- conv_ncb_param);
- } else if (param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) {
- conv_direct_unusable =
- !arm_common::direct_quint8_stride1::
- can_conv_direct_stride1_quint8(conv_ncb_param) &&
- !arm_common::direct_quint8_stride2::
- can_conv_direct_stride2_quint8(conv_ncb_param);
- }
- return conv_direct_unusable;
- }
-
- const char* ConvBiasImpl::get_algorithm_set_name() const {
- // arm common version 0
- return "AC0";
- }
-
- // vim: syntax=cpp.doxygen
|