diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp deleted file mode 100644 index 3f822c22..00000000 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp +++ /dev/null @@ -1,725 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -#include - -#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" -#include "src/arm_common/conv_bias/postprocess_helper.h" -#include "src/arm_common/simd_macro/neon_helper.h" - -#include "midout.h" - -MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs1) - -using namespace megdnn; -using namespace arm_common; -using namespace fp32; -using namespace conv_stride1; - -using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; -using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; - -void conv_stride1::do_conv_2x2_stride1( - const float* src, const float* filter, float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { - const size_t tail_step = IW - OW; - //! unroll of 2 - size_t ic = 0; - for (; ic + 1 < IC; ic += 2) { - const float* src_ptr = src + IW * IH * ic; - const float* src_ptr1 = src_ptr + IW * IH; - float* outptr = dst; - - const float* r00 = src_ptr; - const float* r01 = src_ptr + IW; - const float* r10 = src_ptr1; - const float* r11 = src_ptr1 + IW; - - const float* k0 = filter + ic * 4; - const float* k1 = k0 + 4; - - MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_LOADU(k0); - MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_LOADU(k1); - rep(h, OH) { - int width = OW >> 2; - - rep(i, width) { - MEGDNN_SIMD_TYPE _r000 = MEGDNN_SIMD_LOADU(r00); - MEGDNN_SIMD_TYPE _r010 = MEGDNN_SIMD_LOADU(r01); - MEGDNN_SIMD_TYPE _r001 = MEGDNN_SIMD_LOADU(r00 + 1); - MEGDNN_SIMD_TYPE _r011 = MEGDNN_SIMD_LOADU(r01 + 1); - - MEGDNN_SIMD_TYPE _r100 = MEGDNN_SIMD_LOADU(r10); - MEGDNN_SIMD_TYPE _r110 = MEGDNN_SIMD_LOADU(r11); - MEGDNN_SIMD_TYPE _r101 = MEGDNN_SIMD_LOADU(r10 + 1); - MEGDNN_SIMD_TYPE _r111 = MEGDNN_SIMD_LOADU(r11 + 1); - - MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); - - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000, MEGDNN_SIMD_GET_LOW(_k0), 0); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001, MEGDNN_SIMD_GET_LOW(_k0), 1); - _sum = MEGDNN_SIMD_VMLAQ_LANE( - _sum, _r010, MEGDNN_SIMD_GET_HIGH(_k0), 0); - _sum = MEGDNN_SIMD_VMLAQ_LANE( - _sum, _r011, MEGDNN_SIMD_GET_HIGH(_k0), 1); - - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100, MEGDNN_SIMD_GET_LOW(_k1), 0); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101, MEGDNN_SIMD_GET_LOW(_k1), 1); - _sum = MEGDNN_SIMD_VMLAQ_LANE( - _sum, _r110, MEGDNN_SIMD_GET_HIGH(_k1), 0); - _sum = MEGDNN_SIMD_VMLAQ_LANE( - _sum, _r111, MEGDNN_SIMD_GET_HIGH(_k1), 1); - - MEGDNN_SIMD_STOREU(outptr, _sum); - - r00 += 4; - r01 += 4; - r10 += 4; - r11 += 4; - outptr += 4; - } - - r00 += tail_step; - r01 += tail_step; - r10 += tail_step; - r11 += tail_step; - } - } - for (; ic < IC; ic++) { - const float* src_ptr = src + IW * IH * ic; - float* outptr = dst; - - const float* r0 = src_ptr; - const float* r1 = src_ptr + IW; - - const float* k0 = filter + ic * 4; - - MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_SET1(k0[0]); - MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_SET1(k0[1]); - 
MEGDNN_SIMD_TYPE _k2 = MEGDNN_SIMD_SET1(k0[2]); - MEGDNN_SIMD_TYPE _k3 = MEGDNN_SIMD_SET1(k0[3]); - rep(h, OH) { - int width = OW >> 2; - - rep(i, width) { - MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); - MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); - MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_LOADU(r0 + 1); - MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_LOADU(r1 + 1); - - MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); - MEGDNN_SIMD_TYPE _sum2; - - _sum = MEGDNN_SIMD_FMADD(_r00, _k0, _sum); - _sum2 = MEGDNN_SIMD_MUL(_r01, _k1); - _sum = MEGDNN_SIMD_FMADD(_r10, _k2, _sum); - _sum2 = MEGDNN_SIMD_FMADD(_r11, _k3, _sum2); - - _sum = MEGDNN_SIMD_ADD(_sum, _sum2); - - MEGDNN_SIMD_STOREU(outptr, _sum); - - r0 += 4; - r1 += 4; - outptr += 4; - } - - r0 += tail_step; - r1 += tail_step; - } - } -} - -void conv_stride1::do_conv_3x3_stride1( - const float* src, const float* filter, float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { - const size_t tail_step = IW - OW; - - rep(ic, IC) { - const float* src_ptr = src + IW * IH * ic; - float* outptr = dst; - float* outptr2 = outptr + OW; - - const float* r0 = src_ptr; - const float* r1 = src_ptr + IW; - const float* r2 = src_ptr + IW * 2; - const float* r3 = src_ptr + IW * 3; - - const float* k0 = filter; - const float* k1 = filter + 3; - const float* k2 = filter + 5; - - MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); - MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); - MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); - MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); - - size_t h = 0; - for (; h + 1 < OH; h += 2) { - int width = OW >> 2; - - rep(i, width) { - MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); - MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); - MEGDNN_SIMD_TYPE _sum3 = MEGDNN_SIMD_LOADU(outptr2); - MEGDNN_SIMD_TYPE _sum4 = MEGDNN_SIMD_SET1(0.f); - - MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); - MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); - MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); - MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); - - MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); - MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); - MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); - MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); - - MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); - MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); - MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); - MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); - - MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); - MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU_2(r3 + 4); - MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1); - MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2); - - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); - - _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r10, _k0123, 0); - _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r11, _k0123, 1); - _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r12, _k0123, 2); - _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r20, _k3456, 0); - _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r21, _k3456, 1); - _sum4 = 
MEGDNN_SIMD_FMA_LANE(_sum4, _r22, _k3456, 2); - _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r30, _k6789, 0); - _sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r31, _k6789, 1); - _sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r32, _k6789, 2); - - _sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); - _sum3 = MEGDNN_SIMD_ADD(_sum3, _sum4); - - MEGDNN_SIMD_STOREU(outptr, _sum1); - MEGDNN_SIMD_STOREU(outptr2, _sum3); - - r0 += 4; - r1 += 4; - r2 += 4; - r3 += 4; - outptr += 4; - outptr2 += 4; - } - - r0 += tail_step + IW; - r1 += tail_step + IW; - r2 += tail_step + IW; - r3 += tail_step + IW; - - outptr += OW; - outptr2 += OW; - } - - for (; h < OH; h++) { - int width = OW >> 2; - - rep(i, width) { - MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); - MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); - - MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); - MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); - MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); - MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); - - MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); - MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); - MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); - MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); - - MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); - MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); - MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); - MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); - - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); - _sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); - - _sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); - - MEGDNN_SIMD_STOREU(outptr, _sum1); - - r0 += 4; - r1 += 4; - r2 += 4; - outptr += 4; - } - r0 += tail_step; - r1 += tail_step; - r2 += tail_step; - } - - filter += 9; - } -} - -void conv_stride1::do_conv_5x5_stride1( - const float* src, const float* filter, float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { - const size_t tail_step = IW - OW; - - rep(ic, IC) { - const float* src_ptr = src + IW * IH * ic; - float* outptr = dst; - float* outptr2 = outptr + OW; - - const float* r0 = src_ptr; - const float* r1 = src_ptr + IW; - const float* r2 = src_ptr + IW * 2; - const float* r3 = src_ptr + IW * 3; - const float* r4 = src_ptr + IW * 4; - const float* r5 = src_ptr + IW * 5; - - MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); - MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); - MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); - MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); - MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); - MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); - MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); - - size_t h = 0; - for (; h + 1 < OH; h += 2) { - int width = OW >> 2; - - rep(i, width) { - MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); - MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_LOADU(outptr2); - - MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); - MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); - MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); - MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); - 
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); - - MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); - MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); - MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); - MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); - MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); - - MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); - MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); - MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); - MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); - MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); - - MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); - MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); - MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); - MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); - MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); - - MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); - MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); - MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); - MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); - MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); - - MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); - MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); - MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); - MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); - MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); - - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k0123, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r11, _k0123, 1); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k0123, 2); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r13, _k0123, 3); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r14, _k4567, 0); - - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r20, _k4567, 1); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k4567, 2); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r22, _k4567, 3); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r23, _k891011, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r24, _k891011, 1); - - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r30, _k891011, 2); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r31, _k891011, 3); - 
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r32, _k12131415, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r33, _k12131415, 1); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r34, _k12131415, 2); - - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r40, _k12131415, 3); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r41, _k16171819, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r42, _k16171819, 1); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r43, _k16171819, 2); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r44, _k16171819, 3); - - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r50, _k20212223, 0); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r51, _k20212223, 1); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r52, _k20212223, 2); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r53, _k20212223, 3); - _sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r54, _k24242424, 0); - - MEGDNN_SIMD_STOREU(outptr, _sum); - MEGDNN_SIMD_STOREU(outptr2, _sum2); - - r0 += 4; - r1 += 4; - r2 += 4; - r3 += 4; - r4 += 4; - r5 += 4; - outptr += 4; - outptr2 += 4; - } - - r0 += tail_step + IW; - r1 += tail_step + IW; - r2 += tail_step + IW; - r3 += tail_step + IW; - r4 += tail_step + IW; - r5 += tail_step + IW; - - outptr += OW; - outptr2 += OW; - } - - for (; h < OH; h++) { - int width = OW >> 2; - - rep(i, width) { - MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); - - MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); - MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); - MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); - MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); - MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); - - MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); - MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); - MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); - MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); - MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); - - MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); - MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); - MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); - MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); - MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); - - MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); - MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); - MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); - MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); - MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); - - MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); - MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); - MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); - MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); - MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); 
- _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); - - MEGDNN_SIMD_STOREU(outptr, _sum); - - r0 += 4; - r1 += 4; - r2 += 4; - r3 += 4; - r4 += 4; - outptr += 4; - } - - r0 += tail_step; - r1 += tail_step; - r2 += tail_step; - r3 += tail_step; - r4 += tail_step; - } - - filter += 25; - } -} - -void conv_stride1::do_conv_7x7_stride1( - const float* src, const float* filter, float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { - const size_t tail_step = IW - OW; - - rep(ic, IC) { - const float* src_ptr = src + IW * IH * ic; - float* outptr = dst; - - const float* r0 = src_ptr; - const float* r1 = src_ptr + IW; - const float* r2 = src_ptr + IW * 2; - const float* r3 = src_ptr + IW * 3; - const float* r4 = src_ptr + IW * 4; - const float* r5 = src_ptr + IW * 5; - const float* r6 = src_ptr + IW * 6; - - const float* k0 = filter; - const float* k1 = filter + 7; - const float* k2 = filter + 14; - const float* k3 = filter + 21; - const float* k4 = filter + 28; - const float* k5 = filter + 35; - const float* k6 = filter + 42; - - for (size_t i = 0; i < OH; i++) { - int width = OW >> 2; - - rep(i, width) { - MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); - - MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); - MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); - - MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3 - MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7 - MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11 - MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4 - MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5 - MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6 - MEGDNN_SIMD_TYPE _r05 = MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8 - MEGDNN_SIMD_TYPE _r06 = MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9 - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); - - MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); - MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); - - MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); - MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); - MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 8); - MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); - MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); - MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); - MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r14, _r10n, 1); - MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r14, _r10n, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); - _sum = 
MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); - - MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); - MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); - - MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); - MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); - MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 8); - MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); - MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); - MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); - MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r24, _r20n, 1); - MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r24, _r20n, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); - - MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); - MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); - - MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); - MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); - MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 8); - MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); - MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); - MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); - MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r34, _r30n, 1); - MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r34, _r30n, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); - - MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); - MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); - - MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); - MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); - MEGDNN_SIMD_TYPE _r40n = MEGDNN_SIMD_LOADU(r4 + 8); - MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); - MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); - MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); - MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r44, _r40n, 1); - MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r44, _r40n, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); - - MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); - MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); - - MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); - MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); - MEGDNN_SIMD_TYPE _r50n = MEGDNN_SIMD_LOADU(r5 + 8); - MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); - MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); - MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); - MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r54, _r50n, 1); - MEGDNN_SIMD_TYPE 
_r56 = MEGDNN_SIMD_EXT(_r54, _r50n, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); - - MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); - MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU_3(k6 + 4); - - MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); - MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); - MEGDNN_SIMD_TYPE _r60n = MEGDNN_SIMD_LOADU(r6 + 8); - MEGDNN_SIMD_TYPE _r61 = MEGDNN_SIMD_EXT(_r60, _r64, 1); - MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r64, 2); - MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r60, _r64, 3); - MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r64, _r60n, 1); - MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r64, _r60n, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k46474849, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k46474849, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k46474849, 2); - - MEGDNN_SIMD_STOREU(outptr, _sum); - - r0 += 4; - r1 += 4; - r2 += 4; - r3 += 4; - r4 += 4; - r5 += 4; - r6 += 4; - outptr += 4; - } - - r0 += tail_step; - r1 += tail_step; - r2 += tail_step; - r3 += tail_step; - r4 += tail_step; - r5 += tail_step; - r6 += tail_step; - } - filter += 49; - } -} - -#include "src/common/simd_macro/epilogue.h" -// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp deleted file mode 100644 index 24856f54..00000000 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp +++ /dev/null @@ -1,512 +0,0 @@ -/** - * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- */ - -#include - -#include "./do_conv_stride2.h" -#include "midout.h" -#include "src/arm_common/conv_bias/postprocess_helper.h" -#include "src/arm_common/simd_macro/neon_helper.h" - -MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2) - -using namespace megdnn; -using namespace arm_common; -using namespace fp32; -using namespace conv_stride2; - -using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; -using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; - -void conv_stride2::do_conv_2x2_stride2( - const float* src, const float* filter, float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { - const size_t tail_step = IW - 2 * OW + IW; - - rep(ic, IC) { - const float* src_ptr = src + IW * IH * ic; - float* outptr = dst; - - const float* r0 = src_ptr; - const float* r1 = src_ptr + IW; - - const float* k0 = filter; - - MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); - rep(h, OH) { - int nn = OW >> 2; - - rep(i, nn) { - MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); - - MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); - - MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 - MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 - - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); - - MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); - - MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; - MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; - - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k0123, 2); - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k0123, 3); - - MEGDNN_SIMD_STOREU(outptr, _outp); - - r0 += 8; - r1 += 8; - outptr += 4; - } - - r0 += tail_step; - r1 += tail_step; - } - - filter += 4; - } -} - -void conv_stride2::do_conv_3x3_stride2( - const float* src, const float* filter, float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { - const size_t tail_step = IW - 2 * OW + IW; - - rep(ic, IC) { - const float* src_ptr = src + IW * IH * ic; - float* outptr = dst; - - const float* r0 = src_ptr; - const float* r1 = src_ptr + IW; - const float* r2 = src_ptr + IW * 2; - - const float* k0 = filter; - const float* k1 = filter + 3; - const float* k2 = filter + 5; - - MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); - MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); - MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); - MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); - rep(h, OH) { - int nn = OW >> 2; - - rep(i, nn) { - MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); - - MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); - MEGDNN_SIMD_TYPE2 _r0n = MEGDNN_SIMD_LOAD2(r0 + 8); - - MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 - MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 - MEGDNN_SIMD_TYPE _r02 = - MEGDNN_SIMD_EXT(_r00, _r0n.val[0], 1); // 2 4 6 8 - - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r02, _k0123, 2); - - MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); - MEGDNN_SIMD_TYPE2 _r1n = MEGDNN_SIMD_LOAD2(r1 + 8); - - MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; - MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; - MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1n.val[0], 1); - - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k3456, 0); - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k3456, 1); - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r12, _k3456, 2); - - MEGDNN_SIMD_TYPE2 _r2 = MEGDNN_SIMD_LOAD2(r2); - MEGDNN_SIMD_TYPE2 _r2n = MEGDNN_SIMD_LOAD2(r2 + 8); - - MEGDNN_SIMD_TYPE _r20 = _r2.val[0]; - MEGDNN_SIMD_TYPE _r21 = _r2.val[1]; - 
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2n.val[0], 1); - - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r20, _k6789, 0); - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r21, _k6789, 1); - _outp = MEGDNN_SIMD_FMA_LANE(_outp, _r22, _k6789, 2); - - MEGDNN_SIMD_STOREU(outptr, _outp); - - r0 += 8; - r1 += 8; - r2 += 8; - outptr += 4; - } - - r0 += tail_step; - r1 += tail_step; - r2 += tail_step; - } - - filter += 9; - } -} - -void conv_stride2::do_conv_5x5_stride2( - const float* src, const float* filter, float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { - const size_t tail_step = IW - 2 * OW + IW; - - rep(ic, IC) { - const float* src_ptr = src + IW * IH * ic; - float* outptr = dst; - - const float* r0 = src_ptr; - const float* r1 = src_ptr + IW; - const float* r2 = src_ptr + IW * 2; - const float* r3 = src_ptr + IW * 3; - const float* r4 = src_ptr + IW * 4; - - MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); - MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); - MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); - MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); - MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); - MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); - MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); - - for (size_t i = 0; i < OH; i++) { - int nn = OW >> 2; - - rep(i, nn) { - MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); - - MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); - MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); - MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 - MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 - MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 - MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 - MEGDNN_SIMD_TYPE _r02 = - MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 - MEGDNN_SIMD_TYPE _r03 = - MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 - MEGDNN_SIMD_TYPE _r04 = - MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 - - MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); - MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); - MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; - MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; - MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; - MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; - MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); - MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); - MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); - - MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); - MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); - MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; - MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; - MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; - MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; - MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); - MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); - MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); - - MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); - MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); - MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; - MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; - MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; - MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; - MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); - MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); - MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); 
- - MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); - MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); - MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; - MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; - MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; - MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; - MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); - MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); - MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); - - MEGDNN_SIMD_STOREU(outptr, _sum); - - r0 += 8; - r1 += 8; - r2 += 8; - r3 += 8; - r4 += 8; - outptr += 4; - } - - r0 += tail_step; - r1 += tail_step; - r2 += tail_step; - r3 += tail_step; - r4 += tail_step; - } - - filter += 25; - } -} - -void conv_stride2::do_conv_7x7_stride2( - const float* src, const float* filter, float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { - const size_t tail_step = IW - 2 * OW + IW; - - rep(ic, IC) { - const float* src_ptr = src + IW * IH * ic; - float* outptr = dst; - - const float* r0 = src_ptr; - const float* r1 = src_ptr + IW; - const float* r2 = src_ptr + IW * 2; - const float* r3 = src_ptr + IW * 3; - const float* r4 = src_ptr + IW * 4; - const float* r5 = src_ptr + IW * 5; - const float* r6 = src_ptr + IW * 6; - - const float* k0 = filter; - const float* k1 = filter + 7; - const float* k2 = filter + 14; - const float* k3 = filter + 21; - const float* k4 = filter + 28; - const float* k5 = filter + 35; - const float* k6 = filter + 42; - - for (size_t i = 0; i < OH; i++) { - int nn = OW >> 2; - - rep(i, nn) { - MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); - - MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); - MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); - - MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); - MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); - MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 - MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 - MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 - MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 - MEGDNN_SIMD_TYPE 
_r02 = - MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 - MEGDNN_SIMD_TYPE _r03 = - MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 - MEGDNN_SIMD_TYPE _r04 = - MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 - MEGDNN_SIMD_TYPE _r05 = - MEGDNN_SIMD_EXT(_r01, _r0_9111315, 2); // 5 7 9 11 - MEGDNN_SIMD_TYPE _r06 = - MEGDNN_SIMD_EXT(_r00, _r0_8101214, 3); // 6 8 10 12 - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); - - MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); - MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); - - MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); - MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); - MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; - MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; - MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; - MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; - MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); - MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); - MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); - MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 2); - MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); - - MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); - MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); - - MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); - MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); - MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; - MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; - MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; - MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; - MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); - MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); - MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); - MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 2); - MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); - - MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); - MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); - - MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); - MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); - MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; - MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; - MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; - MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; - MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); 
- MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); - MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); - MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 2); - MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); - - MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); - MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); - - MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); - MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); - MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; - MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; - MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; - MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; - MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); - MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); - MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); - MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 2); - MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); - - MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); - MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); - - MEGDNN_SIMD_TYPE2 _r50_02461357 = MEGDNN_SIMD_LOAD2(r5); - MEGDNN_SIMD_TYPE2 _r50nx2 = MEGDNN_SIMD_LOAD2(r5 + 8); - MEGDNN_SIMD_TYPE _r5_8101214 = _r50nx2.val[0]; - MEGDNN_SIMD_TYPE _r5_9111315 = _r50nx2.val[1]; - MEGDNN_SIMD_TYPE _r50 = _r50_02461357.val[0]; - MEGDNN_SIMD_TYPE _r51 = _r50_02461357.val[1]; - MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 1); - MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 1); - MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 2); - MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 2); - MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); - - MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); - MEGDNN_SIMD_TYPE _k45464748 = MEGDNN_SIMD_LOADU(k6 + 3); - - MEGDNN_SIMD_TYPE2 _r60_02461357 = MEGDNN_SIMD_LOAD2(r6); - MEGDNN_SIMD_TYPE2 _r60nx2 = MEGDNN_SIMD_LOAD2(r6 + 8); - MEGDNN_SIMD_TYPE _r6_8101214 = _r60nx2.val[0]; - MEGDNN_SIMD_TYPE _r6_9111315 = _r60nx2.val[1]; - MEGDNN_SIMD_TYPE _r60 = _r60_02461357.val[0]; - MEGDNN_SIMD_TYPE _r61 = _r60_02461357.val[1]; - MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 1); - MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 1); - 
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 2); - MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 2); - MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 3); - - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k45464748, 1); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k45464748, 2); - _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k45464748, 3); - - MEGDNN_SIMD_STOREU(outptr, _sum); - - r0 += 8; - r1 += 8; - r2 += 8; - r3 += 8; - r4 += 8; - r5 += 8; - r6 += 8; - outptr += 4; - } - - r0 += tail_step; - r1 += tail_step; - r2 += tail_step; - r3 += tail_step; - r4 += tail_step; - r5 += tail_step; - r6 += tail_step; - } - filter += 49; - } -} -// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 0b82b96d..5e9d2cfc 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -28,7 +28,6 @@ #include "include/megdnn/oprs/nn.h" #include "src/arm_common/conv_bias/f16/algos.h" -#include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/int8/stride1.h" #include "src/arm_common/conv_bias/int8/stride2.h" #include "src/arm_common/conv_bias/quint8/stride1.h" @@ -69,14 +68,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; #endif - AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; - AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; - AlgoF32DirectNCHW44 f32_direct_nchw44; - - AlgoF32Direct f32_direct; - AlgoF32DirectStride2 f32_direct_stride2; - AlgoF32DirectStride1 f32_direct_stride1; - AlgoI8x8x16Direct i8x8x16_direct; AlgoI8x8x16Stride2 i8x8x16_stride2; AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; @@ -127,14 +118,6 @@ public: m_direct_algos.emplace_back(&i8x8x16_stride2); m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44); - m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); - m_direct_algos.emplace_back(&f32_chanel_wise_nchw44); - m_direct_algos.emplace_back(&f32_direct_nchw44); - - m_direct_algos.emplace_back(&f32_direct_stride1); - m_direct_algos.emplace_back(&f32_direct_stride2); - m_direct_algos.emplace_back(&f32_direct); - static CpuOprDelegationStorage<2> storage; auto matmul_opr = storage.get(); using MatmulFormat = param::MatrixMul::Format; @@ -145,22 +128,6 @@ public: if (is_fallback_or_naive(algo)) continue; for (uint32_t tile_size : {16, 8, 24, 32}) { - refhold.emplace_back(new AlgoFP32WinogradF23_4x4( - static_cast(algo), - tile_size)); - m_winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF63_4x4( - static_cast(algo), - tile_size)); - m_winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( - static_cast(algo), - tile_size)); - m_winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( - static_cast(algo), - tile_size)); - m_winograd_algos.emplace_back(refhold.back().get()); //! 
uncomment this when low precision mode is done #if 0 refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( @@ -175,27 +142,6 @@ public: m_winograd_algos.emplace_back(refhold.back().get()); } } - matmul_algos = static_cast(matmul_opr) - ->select_algo_type( - {AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); - for (auto&& algo : matmul_algos) { - if (is_fallback_or_naive(algo)) - continue; - for (uint32_t tile_size : {16, 8, 24, 32}) { - refhold.emplace_back(new AlgoFP32WinogradF63( - static_cast(algo), - tile_size)); - m_winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF54( - static_cast(algo), - tile_size)); - m_winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF45( - static_cast(algo), - tile_size)); - m_winograd_algos.emplace_back(refhold.back().get()); - } - } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC matmul_algos = static_cast(matmul_opr) diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 65c19f49..e6ae4046 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -49,15 +49,6 @@ private: class AlgoS8DirectNCHWNCHW44; class AlgoQU8DirectStride1; class AlgoQU8DirectStride2; - class AlgoFP32WinogradF23_4x4; - class AlgoFP32WinogradF63; - class AlgoFP32WinogradF63_4x4; - class AlgoFP32WinogradF54; - class AlgoFP32WinogradF45; - - class AlgoFP32WinogradF23_4x4_NCHW44; - class AlgoFP32WinogradF63_4x4_NCHW44; - class AlgoFP32WinogradF73_4x4_NCHW44; class AlgoS8ChanWiseStride1NCHW44; class AlgoS8ChanWiseStride2NCHW44; @@ -78,12 +69,6 @@ private: class AlgoDotS8Direct_NCHW44; #endif - class AlgoF32Direct; - class AlgoF32DirectStride1; - class AlgoF32DirectStride2; - class AlgoF32DirectNCHWNCHW44; - class AlgoF32ChannelWiseNCHW44; - class AlgoF32DirectNCHW44; class AlgoI8x8x16Direct; class AlgoI8x8x16Stride2; diff --git a/dnn/src/fallback/conv_bias/direct/multi_thread_common.h b/dnn/src/fallback/conv_bias/direct/multi_thread_common.h index 4303c408..d53ab3f8 100644 --- a/dnn/src/fallback/conv_bias/direct/multi_thread_common.h +++ b/dnn/src/fallback/conv_bias/direct/multi_thread_common.h @@ -10,6 +10,8 @@ */ #pragma once +#include "megbrain_build_config.h" + #include "src/fallback/conv_bias/opr_impl.h" #include "src/fallback/matrix_mul/opr_impl.h" diff --git a/dnn/src/fallback/conv_bias/gi/block_helper.h b/dnn/src/fallback/conv_bias/gi/block_helper.h new file mode 100644 index 00000000..3bad5cb0 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/block_helper.h @@ -0,0 +1,37 @@ +/** + * \file dnn/src/fallback/conv_bias/gi/block_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/common/utils.h" +namespace megdnn { +namespace { +// block_helper is used to calculate oh block size +static inline int l2_block_helper( + const int nthread, const int amount, const int size_per_unit) { + //!
TODO: opt config or dynamic config l2_cache_size for different ARCH + constexpr int l2_cache_size = 256 * 1024; + const int block_per_thread = div_ceil(amount, nthread); + const int best_block = + std::min(amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); + const int max_block_num = div_ceil(block_per_thread, best_block); + const int min_block_num = std::max(max_block_num - 1, 1); + const int max_block = div_ceil(block_per_thread, max_block_num); + const int min_block = div_ceil(block_per_thread, min_block_num); + const int max_loss = std::abs(max_block_num * max_block - block_per_thread); + const int min_loss = std::abs(min_block_num * min_block - block_per_thread); + int block = max_loss > min_loss ? min_block : max_block; + return block; +} + +} // namespace +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.cpp b/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp similarity index 93% rename from dnn/src/arm_common/conv_bias/fp32/algos.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/algos.cpp index e65869fa..b9b106e4 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/algos.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/algos.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/algos.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,23 +10,22 @@ * implied. */ -#include "src/arm_common/conv_bias/fp32/algos.h" -#include "src/arm_common/conv_bias/fp32/direct.h" -#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" -#include "src/arm_common/conv_bias/fp32/do_conv_stride2.h" -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/conv_bias/img2col_helper.h" -#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/fallback/conv_bias/gi/fp32/algos.h" #include "src/common/opr_delegate.h" #include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/direct/multi_thread_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct.h" +#include "src/fallback/conv_bias/gi/fp32/do_conv_stride1.h" +#include "src/fallback/conv_bias/gi/fp32/do_conv_stride2.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/postprocess_helper.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_fp32) +MIDOUT_DECL(megdnn_fallback_winograd_fp32) using namespace megdnn; -using namespace arm_common; +using namespace fallback; /* ======================= AlgoFP32WinogradF23_4x4 ======================== */ @@ -34,10 +33,10 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); - MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 0) { + MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 0, 0) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; - using Strategy = winograd::winograd_2x3_4x4_f; + using Strategy = winograd::winograd_gi_2x3_4x4_f; using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = @@ -62,8 +61,8 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( } MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( - AlgoFP32WinogradF23_4x4, winograd::winograd_2x3_4x4_f, - megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); + AlgoFP32WinogradF23_4x4, 
winograd::winograd_gi_2x3_4x4_f, + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); /* ======================= AlgoFP32WinogradF63 ======================== */ @@ -71,7 +70,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); - MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) { + MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 1, 0) { using Strategy = winograd::winograd_6x3_1x1_f; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = @@ -95,7 +94,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f, - megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); /* ======================= AlgoFP32WinogradF54 ======================== */ @@ -103,7 +102,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); - MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) { + MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 2, 0) { using Strategy = winograd::winograd_5x4_1x1_f; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = @@ -127,7 +126,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF54, winograd::winograd_5x4_1x1_f, - megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); /* ======================= AlgoFP32WinogradF45 ======================== */ @@ -135,7 +134,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); - MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) { + MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 3, 0) { using Strategy = winograd::winograd_4x5_1x1_f; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = @@ -159,7 +158,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF45, winograd::winograd_4x5_1x1_f, - megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); /* ======================= AlgoFP32WinogradF63_4x4 ======================== */ @@ -167,7 +166,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); - MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 0) { + MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 4, 0) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; using Strategy = winograd::winograd_6x3_4x4_f; @@ -197,7 +196,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF63_4x4, winograd::winograd_6x3_4x4_f, - megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); /* =================== AlgoFP32WinogradF23_4x4_NCHW44 =================== */ @@ -206,7 +205,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( AlgoSelectionStrategy /*algo_selection_strategy*/) const { 
MEGDNN_MARK_USED_VAR(param); MIDOUT_BEGIN( - megdnn_arm_common_winograd_fp32, + megdnn_fallback_winograd_fp32, midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; @@ -236,7 +235,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF23_4x4_NCHW44, winograd::winograd_F23_mk4_f_nchw44, - megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); /* =================== AlgoFP32WinogradF63_4x4_NCHW44 ===================== */ @@ -245,7 +244,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); MIDOUT_BEGIN( - megdnn_arm_common_winograd_fp32, + megdnn_fallback_winograd_fp32, midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; @@ -276,7 +275,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF63_4x4_NCHW44, winograd::winograd_F63_mk4_f_nchw44, - megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); /* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */ @@ -284,7 +283,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MIDOUT_BEGIN( - megdnn_arm_common_winograd_fp32, + megdnn_fallback_winograd_fp32, midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; @@ -314,14 +313,14 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( AlgoFP32WinogradF73_4x4_NCHW44, winograd::winograd_F73_mk4_f_nchw44, - megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); + megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); /* ===================== direct algo ===================== */ -MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); +MIDOUT_DECL(megdnn_fallback_conv_bias_f32_kimpl); bool ConvBiasImpl::AlgoF32Direct::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; auto SH = fm.stride[0], SW = fm.stride[1]; @@ -341,7 +340,7 @@ bool ConvBiasImpl::AlgoF32Direct::usable( return false; } size_t ConvBiasImpl::AlgoF32Direct::get_workspace(const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) { bool large_group = param.filter_meta.group >= param.nr_threads; auto wbundle = fallback::MultithreadDirectConvCommon::get_bundle( param, large_group); @@ -426,7 +425,7 @@ SmallVector ConvBiasImpl::AlgoF32Direct::get_kimpls( SmallVector ConvBiasImpl::AlgoF32Direct::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) { return get_kimpls(param); } MIDOUT_END(); @@ -435,7 +434,7 @@ SmallVector ConvBiasImpl::AlgoF32Direct::dispatch_kerns( /* ===================== stride-1 algo ===================== */ bool 
ConvBiasImpl::AlgoF32DirectStride1::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 1) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.filter_meta.format == param::ConvBias::Format::NCHW && @@ -452,7 +451,7 @@ bool ConvBiasImpl::AlgoF32DirectStride1::usable( size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 1) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = fallback::MultithreadDirectConvCommon::get_bundle_stride( @@ -548,7 +547,7 @@ SmallVector ConvBiasImpl::AlgoF32DirectStride1::get_kimpl SmallVector ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 2) { return get_kimpls(param); } MIDOUT_END(); @@ -559,7 +558,7 @@ SmallVector ConvBiasImpl::AlgoF32DirectStride1::dispatch_ bool ConvBiasImpl::AlgoF32DirectStride2::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.filter_meta.format == param::ConvBias::Format::NCHW && @@ -575,7 +574,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( } size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 1) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = fallback::MultithreadDirectConvCommon::get_bundle_stride( @@ -670,7 +669,7 @@ SmallVector ConvBiasImpl::AlgoF32DirectStride2::get_kimpl SmallVector ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) { + MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 2) { return get_kimpls(param); } MIDOUT_END(); diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/fallback/conv_bias/gi/fp32/algos.h similarity index 91% rename from dnn/src/arm_common/conv_bias/fp32/algos.h rename to dnn/src/fallback/conv_bias/gi/fp32/algos.h index b2c7d53e..c142123d 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/algos.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/algos.h + * \file dnn/src/fallback/conv_bias/gi/fp32/algos.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
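The hunks above only retag the midout instrumentation (megdnn_arm_common_* to megdnn_fallback_*) around usable(), get_workspace() and dispatch_kerns(); the bracketing pattern itself is unchanged. For readers new to that pattern, a minimal sketch follows. The tag name, the (0, 1) ids and the workspace formula are hypothetical; only the MIDOUT_DECL / MIDOUT_BEGIN / MIDOUT_END usage and the fallback return after MIDOUT_END() mirror what the diff shows.

// Minimal sketch of the midout bracketing used above; tag name, ids and the
// workspace formula are hypothetical, only the macro usage follows the diff.
#include <cstddef>
#include "midout.h"

MIDOUT_DECL(megdnn_fallback_conv_bias_f32_example)  // hypothetical tag

size_t example_get_workspace(size_t batch, size_t spatial_size) {
    MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_example, 0, 1) {
        // a real implementation builds a WorkspaceBundle here and returns its
        // total size; this is a stand-in computation
        return batch * spatial_size * sizeof(float);
    }
    MIDOUT_END();
    return 0;  // fallback return after MIDOUT_END(), as in the bodies above
}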
@@ -12,11 +12,11 @@ #pragma once -#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/opr_impl.h" #include "src/fallback/matrix_mul/opr_impl.h" namespace megdnn { -namespace arm_common { +namespace fallback { class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase { public: AlgoFP32WinogradF23_4x4( @@ -31,7 +31,7 @@ public: } AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_4X4_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { @@ -50,7 +50,7 @@ public: return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { @@ -67,7 +67,7 @@ public: } AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { @@ -86,7 +86,7 @@ public: return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F54_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { @@ -105,7 +105,7 @@ public: return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F45_FP32) }; //===================== NCHW44 Winograd Support =====================// @@ -124,7 +124,7 @@ public: } AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) }; class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { @@ -142,7 +142,7 @@ public: } AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) }; class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { @@ -160,7 +160,7 @@ public: } AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) }; // ================================================================= // @@ -180,7 +180,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_FP32) }; class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { @@ -199,7 +199,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, 
AlgoCategory::DIRECT}; } - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_STRD1_FP32) }; class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { @@ -218,7 +218,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_STRD2_FP32) }; class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { @@ -238,7 +238,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_NCHW44_FP32) }; class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { @@ -258,7 +258,7 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_NCHW_NCHW44_FP32) }; class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { @@ -277,10 +277,10 @@ public: ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; } - MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32) + MEGDNN_DECL_ALGO_TYPE(GI_COMMON_CHWNWISE_NCHW44_F32) }; -} // namespace arm_common +} // namespace fallback } // namespace megdnn #undef MEGDNN_WINOGRAD_ALGO_FUN_DECLARE diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp similarity index 60% rename from dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp index 2da295c8..919a8d8c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,29 +10,22 @@ * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/gi/utils.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" #pragma GCC diagnostic ignored "-Wunused-parameter" using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { -#if defined(__ARM_FEATURE_FMA) -#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) -#else -#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) -#endif - template -static inline void shift_src(float32x4_t rsrc[3][4]) { - float32x4_t t[4]; +static inline void shift_src(GI_FLOAT32_t rsrc[3][4]) { + GI_FLOAT32_t t[4]; t[0] = rsrc[0][(shift + 0) % 4]; t[1] = rsrc[0][(shift + 1) % 4]; @@ -63,9 +56,9 @@ static inline void shift_src(float32x4_t rsrc[3][4]) { } template -static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) { +static inline GI_FLOAT32_t load_bias(const float* bias, const GI_FLOAT32_t& init) { if (bias_mode == BiasMode::BIAS) { - return vld1q_f32(bias); + return GiLoadFloat32(bias); } else { return init; } @@ -76,35 +69,35 @@ struct compute_element { template static inline void call( const float*& src0, const float*& src1, const float*& src2, float*& dst, - const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4], - float32x4_t rfilter[3][3], const Op& op) { + const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4], + GI_FLOAT32_t rfilter[3][3], const Op& op) { #define RSRC(i, j) rsrc[i][((j) + bw) % 4] - float32x4_t rdst = load_bias(bias, init); + GI_FLOAT32_t rdst = load_bias(bias, init); if (has_top) { - RSRC(0, 3) = vld1q_f32(src0 + 8); + RSRC(0, 3) = GiLoadFloat32(src0 + 8); } - { RSRC(1, 3) = vld1q_f32(src1 + 8); } + { RSRC(1, 3) = GiLoadFloat32(src1 + 8); } if (has_bottom) { - RSRC(2, 3) = vld1q_f32(src2 + 8); + RSRC(2, 3) = GiLoadFloat32(src2 + 8); } if (has_top) { - rdst = Vfmaq_f32(rdst, RSRC(0, 0), rfilter[0][0]); - rdst = Vfmaq_f32(rdst, RSRC(0, 1), rfilter[0][1]); - rdst = Vfmaq_f32(rdst, RSRC(0, 2), rfilter[0][2]); + rdst = GiMlaqFloat32(rdst, RSRC(0, 0), rfilter[0][0]); + rdst = GiMlaqFloat32(rdst, RSRC(0, 1), rfilter[0][1]); + rdst = GiMlaqFloat32(rdst, RSRC(0, 2), rfilter[0][2]); } { - rdst = Vfmaq_f32(rdst, RSRC(1, 0), rfilter[1][0]); - rdst = Vfmaq_f32(rdst, RSRC(1, 1), rfilter[1][1]); - rdst = Vfmaq_f32(rdst, RSRC(1, 2), rfilter[1][2]); + rdst = GiMlaqFloat32(rdst, RSRC(1, 0), rfilter[1][0]); + rdst = GiMlaqFloat32(rdst, RSRC(1, 1), rfilter[1][1]); + rdst = GiMlaqFloat32(rdst, RSRC(1, 2), rfilter[1][2]); } if (has_bottom) { - rdst = Vfmaq_f32(rdst, RSRC(2, 0), rfilter[2][0]); - rdst = Vfmaq_f32(rdst, RSRC(2, 1), rfilter[2][1]); - rdst = Vfmaq_f32(rdst, RSRC(2, 2), rfilter[2][2]); + rdst = GiMlaqFloat32(rdst, RSRC(2, 0), rfilter[2][0]); + rdst = GiMlaqFloat32(rdst, RSRC(2, 1), rfilter[2][1]); + rdst = GiMlaqFloat32(rdst, RSRC(2, 2), rfilter[2][2]); } - vst1q_f32(dst, op(rdst)); + GiStoreFloat32(dst, op(rdst)); if (has_top) { src0 += 4; @@ -131,27 +124,27 @@ template struct compute_element_right { template static inline void call( - float*& dst, const float*& bias, const float32x4_t& init, - float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) { - float32x4_t rdst = load_bias(bias, init); + float*& dst, 
const float*& bias, const GI_FLOAT32_t& init, + GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) { + GI_FLOAT32_t rdst = load_bias(bias, init); if (has_top) { - rdst = Vfmaq_f32(rdst, rsrc[0][0], rfilter[0][0]); - rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][1]); - rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][2]); + rdst = GiMlaqFloat32(rdst, rsrc[0][0], rfilter[0][0]); + rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][1]); + rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][2]); } { - rdst = Vfmaq_f32(rdst, rsrc[1][0], rfilter[1][0]); - rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][1]); - rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][2]); + rdst = GiMlaqFloat32(rdst, rsrc[1][0], rfilter[1][0]); + rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][1]); + rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][2]); } if (has_bottom) { - rdst = Vfmaq_f32(rdst, rsrc[2][0], rfilter[2][0]); - rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][1]); - rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][2]); + rdst = GiMlaqFloat32(rdst, rsrc[2][0], rfilter[2][0]); + rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][1]); + rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][2]); } - vst1q_f32(dst, op(rdst)); + GiStoreFloat32(dst, op(rdst)); dst += 4; bias += 4; @@ -162,24 +155,24 @@ template struct compute_element_right_pad { template static inline void call( - float*& dst, const float*& bias, const float32x4_t& init, - float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) { - float32x4_t rdst = load_bias(bias, init); + float*& dst, const float*& bias, const GI_FLOAT32_t& init, + GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) { + GI_FLOAT32_t rdst = load_bias(bias, init); if (has_top) { - rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][0]); - rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][1]); + rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][0]); + rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][1]); } { - rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][0]); - rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][1]); + rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][0]); + rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][1]); } if (has_bottom) { - rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][0]); - rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][1]); + rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][0]); + rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][1]); } - vst1q_f32(dst, op(rdst)); + GiStoreFloat32(dst, op(rdst)); dst += 4; bias += 4; } @@ -190,22 +183,22 @@ struct compute_row { template static inline void call( const float*& src0, const float*& src1, const float*& src2, float*& dst, - const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4], - float32x4_t rfilter[3][3], int W, const Op& op) { + const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4], + GI_FLOAT32_t rfilter[3][3], int W, const Op& op) { if (has_top) { - rsrc[0][0] = vdupq_n_f32(0); - rsrc[0][1] = vld1q_f32(src0 + 0); - rsrc[0][2] = vld1q_f32(src0 + 4); + rsrc[0][0] = GiZeroFloat32(); + rsrc[0][1] = GiLoadFloat32(src0 + 0); + rsrc[0][2] = GiLoadFloat32(src0 + 4); } { - rsrc[1][0] = vdupq_n_f32(0); - rsrc[1][1] = vld1q_f32(src1 + 0); - rsrc[1][2] = vld1q_f32(src1 + 4); + rsrc[1][0] = GiZeroFloat32(); + rsrc[1][1] = GiLoadFloat32(src1 + 0); + rsrc[1][2] = GiLoadFloat32(src1 + 4); } if (has_bottom) { - rsrc[2][0] = vdupq_n_f32(0); - rsrc[2][1] = vld1q_f32(src2 + 0); - rsrc[2][2] = vld1q_f32(src2 + 4); + rsrc[2][0] = GiZeroFloat32(); + rsrc[2][1] = GiLoadFloat32(src2 + 0); + 
rsrc[2][2] = GiLoadFloat32(src2 + 4); } int w = 0; @@ -256,27 +249,27 @@ void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1( int W) { Op op; - float32x4_t init = vdupq_n_f32(0); + GI_FLOAT32_t init = GiZeroFloat32(); if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { - init = vld1q_f32(bias); + init = GiLoadFloat32(bias); } const float* src0 = src - W * 4; const float* src1 = src; const float* src2 = src + W * 4; - float32x4_t rfilter[3][3]; - rfilter[0][0] = vld1q_f32(filter + 0); - rfilter[0][1] = vld1q_f32(filter + 4); - rfilter[0][2] = vld1q_f32(filter + 8); - rfilter[1][0] = vld1q_f32(filter + 12); - rfilter[1][1] = vld1q_f32(filter + 16); - rfilter[1][2] = vld1q_f32(filter + 20); - rfilter[2][0] = vld1q_f32(filter + 24); - rfilter[2][1] = vld1q_f32(filter + 28); - rfilter[2][2] = vld1q_f32(filter + 32); - - float32x4_t rsrc[3][4]; + GI_FLOAT32_t rfilter[3][3]; + rfilter[0][0] = GiLoadFloat32(filter + 0); + rfilter[0][1] = GiLoadFloat32(filter + 4); + rfilter[0][2] = GiLoadFloat32(filter + 8); + rfilter[1][0] = GiLoadFloat32(filter + 12); + rfilter[1][1] = GiLoadFloat32(filter + 16); + rfilter[1][2] = GiLoadFloat32(filter + 20); + rfilter[2][0] = GiLoadFloat32(filter + 24); + rfilter[2][1] = GiLoadFloat32(filter + 28); + rfilter[2][2] = GiLoadFloat32(filter + 32); + + GI_FLOAT32_t rsrc[3][4]; compute_row::call( src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h similarity index 81% rename from dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h rename to dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h index 77e193d8..e94d0a55 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h + * \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
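The kernel hunks above swap the NEON-specific float32x4_t / vld1q_f32 / vmlaq_f32 / vdupq_n_f32 / vst1q_f32 calls (and the local Vfmaq_f32 fma/mla macro) for the generic-intrinsic GI_FLOAT32_t / GiLoadFloat32 / GiMlaqFloat32 / GiZeroFloat32 / GiStoreFloat32 layer, so the same kernel source can build on non-NEON targets. A rough sketch of how such a wrapper layer can be laid out is below; the NEON branch follows the one-to-one substitutions visible in the diff, while the scalar branch is an assumed stand-in for illustration and is not the repository's actual gi_float.h.

// Illustrative wrapper layer, assuming NEON when available and a plain scalar
// fallback otherwise. Names and call signatures follow the GI calls used in
// the hunks above; the bodies are a sketch, not the contents of gi_float.h.
#include <cstddef>

#if defined(__ARM_NEON)
#include <arm_neon.h>
typedef float32x4_t GI_FLOAT32_t;
static inline GI_FLOAT32_t GiLoadFloat32(const float* p) { return vld1q_f32(p); }
static inline void GiStoreFloat32(float* p, GI_FLOAT32_t v) { vst1q_f32(p, v); }
static inline GI_FLOAT32_t GiZeroFloat32() { return vdupq_n_f32(0.f); }
static inline GI_FLOAT32_t GiMlaqFloat32(GI_FLOAT32_t d, GI_FLOAT32_t n, GI_FLOAT32_t m) {
#if defined(__ARM_FEATURE_FMA)
    return vfmaq_f32(d, n, m);  // fused multiply-add, replaces the old Vfmaq_f32 macro
#else
    return vmlaq_f32(d, n, m);
#endif
}
#else
typedef struct {
    float v[4];
} GI_FLOAT32_t;  // four packed floats, mirroring one NEON q-register
static inline GI_FLOAT32_t GiLoadFloat32(const float* p) {
    GI_FLOAT32_t r;
    for (size_t i = 0; i < 4; ++i) r.v[i] = p[i];
    return r;
}
static inline void GiStoreFloat32(float* p, GI_FLOAT32_t x) {
    for (size_t i = 0; i < 4; ++i) p[i] = x.v[i];
}
static inline GI_FLOAT32_t GiZeroFloat32() {
    GI_FLOAT32_t r;
    for (size_t i = 0; i < 4; ++i) r.v[i] = 0.f;
    return r;
}
static inline GI_FLOAT32_t GiMlaqFloat32(GI_FLOAT32_t d, GI_FLOAT32_t n, GI_FLOAT32_t m) {
    for (size_t i = 0; i < 4; ++i) d.v[i] += n.v[i] * m.v[i];  // d + n * m per lane
    return d;
}
#endif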
@@ -12,11 +12,11 @@ #pragma once -#include "src/arm_common/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/opr_impl.h" namespace megdnn { -namespace arm_common { +namespace fallback { namespace channel_wise_nchw44_float { template @@ -25,7 +25,7 @@ void do_conv_kern_3x3_stride1_padding1( int W); } // namespace channel_wise_nchw44_float -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp similarity index 73% rename from dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp index 46700ba4..5a635179 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,29 +10,22 @@ * implied. */ -#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/gi/utils.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" #pragma GCC diagnostic ignored "-Wunused-parameter" using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { -#if defined(__ARM_FEATURE_FMA) -#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) -#else -#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) -#endif - template -static inline void shift_src(float32x4_t rsrc[6]) { - float32x4_t t[6]; +static inline void shift_src(GI_FLOAT32_t rsrc[6]) { + GI_FLOAT32_t t[6]; t[0] = rsrc[(shift + 0) % 6]; t[1] = rsrc[(shift + 1) % 6]; @@ -48,18 +41,18 @@ static inline void shift_src(float32x4_t rsrc[6]) { rsrc[5] = t[5]; } -static inline void load_filter(const float* filter, float32x4_t rfilter[5]) { - rfilter[0] = vld1q_f32(filter + 0); - rfilter[1] = vld1q_f32(filter + 4); - rfilter[2] = vld1q_f32(filter + 8); - rfilter[3] = vld1q_f32(filter + 12); - rfilter[4] = vld1q_f32(filter + 16); +static inline void load_filter(const float* filter, GI_FLOAT32_t rfilter[5]) { + rfilter[0] = GiLoadFloat32(filter + 0); + rfilter[1] = GiLoadFloat32(filter + 4); + rfilter[2] = GiLoadFloat32(filter + 8); + rfilter[3] = GiLoadFloat32(filter + 12); + rfilter[4] = GiLoadFloat32(filter + 16); } template -static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) { +static inline GI_FLOAT32_t load_bias(const float* bias, const GI_FLOAT32_t& init) { if (bias_mode == BiasMode::BIAS) { - return vld1q_f32(bias); + return GiLoadFloat32(bias); } else { return init; } @@ -69,27 +62,28 @@ template static inline void call( - const float*& src, float*& dst, const float*& bias, const float32x4_t& init, - float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) { + const float*& src, float*& dst, const float*& 
bias, + const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], + const Op& op) { #define RSRC(i) rsrc[((i) + bw) % 6] - float32x4_t rdst; + GI_FLOAT32_t rdst; if (need_load_bias) { rdst = load_bias(bias, init); } else { - rdst = vld1q_f32(dst); + rdst = GiLoadFloat32(dst); } - RSRC(5) = vld1q_f32(src + 12); + RSRC(5) = GiLoadFloat32(src + 12); - rdst = Vfmaq_f32(rdst, RSRC(0), rfilter[0]); - rdst = Vfmaq_f32(rdst, RSRC(1), rfilter[1]); - rdst = Vfmaq_f32(rdst, RSRC(2), rfilter[2]); - rdst = Vfmaq_f32(rdst, RSRC(3), rfilter[3]); - rdst = Vfmaq_f32(rdst, RSRC(4), rfilter[4]); + rdst = GiMlaqFloat32(rdst, RSRC(0), rfilter[0]); + rdst = GiMlaqFloat32(rdst, RSRC(1), rfilter[1]); + rdst = GiMlaqFloat32(rdst, RSRC(2), rfilter[2]); + rdst = GiMlaqFloat32(rdst, RSRC(3), rfilter[3]); + rdst = GiMlaqFloat32(rdst, RSRC(4), rfilter[4]); if (need_do_op) { rdst = op(rdst); } - vst1q_f32(dst, rdst); + GiStoreFloat32(dst, rdst); src += 4; dst += 4; @@ -110,29 +104,29 @@ template static inline void call( - float*& dst, const float*& bias, const float32x4_t& init, - float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) { - float32x4_t rdst; + float*& dst, const float*& bias, const GI_FLOAT32_t& init, + GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], const Op& op) { + GI_FLOAT32_t rdst; if (need_load_bias) { rdst = load_bias(bias, init); } else { - rdst = vld1q_f32(dst); + rdst = GiLoadFloat32(dst); } - rdst = Vfmaq_f32(rdst, rsrc[0 + padding], rfilter[0]); - rdst = Vfmaq_f32(rdst, rsrc[1 + padding], rfilter[1]); - rdst = Vfmaq_f32(rdst, rsrc[2 + padding], rfilter[2]); + rdst = GiMlaqFloat32(rdst, rsrc[0 + padding], rfilter[0]); + rdst = GiMlaqFloat32(rdst, rsrc[1 + padding], rfilter[1]); + rdst = GiMlaqFloat32(rdst, rsrc[2 + padding], rfilter[2]); if (padding < 2) { - rdst = Vfmaq_f32(rdst, rsrc[3 + padding], rfilter[3]); + rdst = GiMlaqFloat32(rdst, rsrc[3 + padding], rfilter[3]); } if (padding < 1) { - rdst = Vfmaq_f32(rdst, rsrc[4 + padding], rfilter[4]); + rdst = GiMlaqFloat32(rdst, rsrc[4 + padding], rfilter[4]); } if (need_do_op) { rdst = op(rdst); } - vst1q_f32(dst, rdst); + GiStoreFloat32(dst, rdst); dst += 4; bias += 4; @@ -143,13 +137,13 @@ template struct compute_row_src_1x5 { template static inline void call( - const float* src, float* dst, const float* bias, const float32x4_t& init, - float32x4_t rsrc[6], float32x4_t rfilter[5], int W, const Op& op) { - rsrc[0] = vdupq_n_f32(0); - rsrc[1] = vdupq_n_f32(0); - rsrc[2] = vld1q_f32(src + 0); - rsrc[3] = vld1q_f32(src + 4); - rsrc[4] = vld1q_f32(src + 8); + const float* src, float* dst, const float* bias, const GI_FLOAT32_t& init, + GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], int W, const Op& op) { + rsrc[0] = GiZeroFloat32(); + rsrc[1] = GiZeroFloat32(); + rsrc[2] = GiLoadFloat32(src + 0); + rsrc[3] = GiLoadFloat32(src + 4); + rsrc[4] = GiLoadFloat32(src + 8); int w = 0; @@ -190,8 +184,8 @@ struct compute_row { template static inline void call( const float*& src, float*& dst, const float* filter, const float*& bias, - const float32x4_t& init, float32x4_t rsrc[6], float32x4_t rfilter[5], int W, - const Op& op) { + const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], + int W, const Op& op) { if (top_padding < 1) { load_filter(filter + 0, rfilter); compute_row_src_1x5::call( @@ -235,13 +229,13 @@ void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2( int W) { Op op; - float32x4_t init = vdupq_n_f32(0); + GI_FLOAT32_t init = GiZeroFloat32(); if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { - 
init = vld1q_f32(bias); + init = GiLoadFloat32(bias); } - float32x4_t rsrc[6]; - float32x4_t rfilter[5]; + GI_FLOAT32_t rsrc[6]; + GI_FLOAT32_t rfilter[5]; compute_row<2, 0, bias_mode>::call( src, dst, filter, bias, init, rsrc, rfilter, W, op); diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h similarity index 81% rename from dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h rename to dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h index d3bd5fc3..9042bce9 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h + * \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -12,11 +12,11 @@ #pragma once -#include "src/arm_common/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/opr_impl.h" namespace megdnn { -namespace arm_common { +namespace fallback { namespace channel_wise_nchw44_float { template @@ -25,7 +25,7 @@ void do_conv_kern_5x5_stride1_padding2( int W); } // namespace channel_wise_nchw44_float -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp similarity index 96% rename from dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp index 60f2fdb2..6d444d8b 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,14 +10,14 @@ * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/algos.h" -#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" +#include "src/fallback/conv_bias/gi/fp32/algos.h" +#include "src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" #include "midout.h" using namespace megdnn; -using namespace arm_common; +using namespace fallback; using conv_fun = std::function -void load_vec(float32x4_t* dst, const float* src); +void load_vec(GI_FLOAT32_t* dst, const float* src); -#define cb(i) dst[i] = vld1q_f32(src + i * 4); -#define LOAD_MACRO(n) \ - template <> \ - inline void load_vec(float32x4_t * dst, const float* src) { \ - UNROLL_CALL_NOWRAPPER(n, cb); \ +#define cb(i) dst[i] = GiLoadFloat32(src + i * 4); +#define LOAD_MACRO(n) \ + template <> \ + inline void load_vec(GI_FLOAT32_t * dst, const float* src) { \ + UNROLL_CALL_NOWRAPPER(n, cb); \ } LOAD_MACRO(2); LOAD_MACRO(3); @@ -46,14 +45,14 @@ LOAD_MACRO(9); #undef LOAD_MACRO template -void compute_vec(float32x4_t& dst, float32x4_t* src, float32x4_t* filter); +void compute_vec(GI_FLOAT32_t& dst, GI_FLOAT32_t* src, GI_FLOAT32_t* filter); -#define cb(i) dst = vmlaq_f32(dst, src[i], filter[i]); -#define COMPUTE_MACRO(n) \ - template <> \ - inline void compute_vec( \ - float32x4_t & dst, float32x4_t * src, float32x4_t * filter) { \ - UNROLL_CALL_NOWRAPPER(n, cb); \ +#define cb(i) dst = GiMlaqFloat32(dst, src[i], filter[i]); +#define COMPUTE_MACRO(n) \ + template <> \ + inline void compute_vec( \ + GI_FLOAT32_t & dst, GI_FLOAT32_t * src, GI_FLOAT32_t * filter) { \ + UNROLL_CALL_NOWRAPPER(n, cb); \ } COMPUTE_MACRO(2); COMPUTE_MACRO(3); @@ -64,20 +63,20 @@ COMPUTE_MACRO(5); template struct load_bias_vec; -#define cb_bias(i) dst[i] = vld1q_f32((bptr) + i * 4); +#define cb_bias(i) dst[i] = GiLoadFloat32((bptr) + i * 4); #define cb_init(i) dst[i] = init; -#define INIT_BIAS_MACRO(n) \ - template \ - struct load_bias_vec { \ - static void impl( \ - float32x4_t* dst, const float32x4_t& init, const float* bptr) { \ - if (bias_mode == BiasMode::BIAS) { \ - UNROLL_CALL_NOWRAPPER(n, cb_bias); \ - } else { \ - UNROLL_CALL_NOWRAPPER(n, cb_init); \ - } \ - } \ +#define INIT_BIAS_MACRO(n) \ + template \ + struct load_bias_vec { \ + static void impl( \ + GI_FLOAT32_t* dst, const GI_FLOAT32_t& init, const float* bptr) { \ + if (bias_mode == BiasMode::BIAS) { \ + UNROLL_CALL_NOWRAPPER(n, cb_bias); \ + } else { \ + UNROLL_CALL_NOWRAPPER(n, cb_init); \ + } \ + } \ }; INIT_BIAS_MACRO(1); @@ -91,7 +90,7 @@ INIT_BIAS_MACRO(4); #define COMPUTE_PADDING_KERNEL() \ do { \ int iw = ow * stride - PW; \ - float32x4_t result; \ + GI_FLOAT32_t result; \ load_bias_vec::impl(&result, init, bias + oh * OW * 4 + ow * 4); \ for (int kh = 0; kh < fh; kh++) { \ if (kh + ih < 0 || kh + ih >= static_cast(IH)) \ @@ -100,7 +99,8 @@ INIT_BIAS_MACRO(4); if (kw + iw < 0 || kw + iw >= static_cast(IW)) \ continue; \ const float* sptr = src + (kh + ih) * IW * 4 + (kw + iw) * 4; \ - result = vmlaq_f32(result, kernel[kh * fh + kw], vld1q_f32(sptr)); \ + result = GiMlaqFloat32( \ + result, kernel[kh * fh + kw], GiLoadFloat32(sptr)); \ } \ } \ float* output = dst + oh * OW * 4 + ow * 4; \ @@ -113,7 +113,7 @@ struct PaddingCompute { const float* src, const float* bias, float* dst, const int fh, const int stride, const size_t IH, const size_t IW, const size_t OH, const size_t OW, const size_t PH, const size_t PW, - const float32x4_t* kernel, const float32x4_t& init) { + const 
GI_FLOAT32_t* kernel, const GI_FLOAT32_t& init) { size_t oh_start = (PH + stride - 1) / stride; size_t ow_start = (PW + stride - 1) / stride; size_t oh_end = (IH + PH - fh) / stride + 1; @@ -148,7 +148,7 @@ struct PaddingComputeK3P1 { static void compute( const float* src, const float* bias, float* dst, const size_t stride, const size_t IH, const size_t IW, const size_t OH, const size_t OW, - const float32x4_t* kernel, const float32x4_t& init) { + const GI_FLOAT32_t* kernel, const GI_FLOAT32_t& init) { constexpr size_t PH = 1, PW = 1, FH = 3; size_t oh_start = (PH + stride - 1) / stride; size_t ow_start = (PW + stride - 1) / stride; @@ -162,39 +162,39 @@ struct PaddingComputeK3P1 { Op op; // line one left { - float32x4_t result; + GI_FLOAT32_t result; load_bias_vec::impl(&result, init, bias); - result = vmlaq_f32(result, kernel[4], vld1q_f32(src)); - result = vmlaq_f32(result, kernel[5], vld1q_f32(src + 4)); - result = vmlaq_f32(result, kernel[7], vld1q_f32(src + IW * 4)); - result = vmlaq_f32(result, kernel[8], vld1q_f32(src + IW * 4 + 4)); + result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(src)); + result = GiMlaqFloat32(result, kernel[5], GiLoadFloat32(src + 4)); + result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(src + IW * 4)); + result = GiMlaqFloat32(result, kernel[8], GiLoadFloat32(src + IW * 4 + 4)); float* output = dst; op(result, output); } // line one mid for (size_t ow = ow_start; ow < ow_end; ow++) { int iw = ow * stride - PW; - float32x4_t result; + GI_FLOAT32_t result; load_bias_vec::impl(&result, init, bias + ow * 4); const float* sptr = src + iw * 4; - result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr)); - result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4)); - result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + 8)); - result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4)); - result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + IW * 4 + 8)); + result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr)); + result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + 4)); + result = GiMlaqFloat32(result, kernel[5], GiLoadFloat32(sptr + 8)); + result = GiMlaqFloat32(result, kernel[6], GiLoadFloat32(sptr + IW * 4)); + result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(sptr + IW * 4 + 4)); + result = GiMlaqFloat32(result, kernel[8], GiLoadFloat32(sptr + IW * 4 + 8)); float* output = dst + ow * 4; op(result, output); } // line one right if (OW != ow_end) { - float32x4_t result; + GI_FLOAT32_t result; load_bias_vec::impl(&result, init, bias + (OW - 1) * 4); const float* sptr = src + (ow_end * stride - PW) * 4; - result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr)); - result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4)); - result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4)); + result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr)); + result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + 4)); + result = GiMlaqFloat32(result, kernel[6], GiLoadFloat32(sptr + IW * 4)); + result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(sptr + IW * 4 + 4)); float* output = dst + ow_end * 4; op(result, output); } @@ -203,30 +203,36 @@ struct PaddingComputeK3P1 { int ih = oh * stride - PH; // left { - float32x4_t result; + GI_FLOAT32_t result; load_bias_vec::impl(&result, init, bias + oh * OW * 4); const float* sptr = src + ih * IW * 4; - result = vmlaq_f32(result, kernel[1], 
vld1q_f32(sptr)); - result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4)); - result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4)); - result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4)); - result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + 2 * IW * 4 + 4)); + result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr)); + result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 4)); + result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + IW * 4)); + result = GiMlaqFloat32( + result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 4)); + result = GiMlaqFloat32( + result, kernel[7], GiLoadFloat32(sptr + 2 * IW * 4)); + result = GiMlaqFloat32( + result, kernel[8], GiLoadFloat32(sptr + 2 * IW * 4 + 4)); float* output = dst + oh * OW * 4; op(result, output); } // right if (OW != ow_end) { - float32x4_t result; + GI_FLOAT32_t result; load_bias_vec::impl( &result, init, bias + oh * OW * 4 + (OW - 1) * 4); const float* sptr = src + ih * IW * 4 + (ow_end * stride - PW) * 4; - result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); - result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); - result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); - result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + 2 * IW * 4)); - result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4 + 4)); + result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); + result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); + result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); + result = GiMlaqFloat32( + result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); + result = GiMlaqFloat32( + result, kernel[6], GiLoadFloat32(sptr + 2 * IW * 4)); + result = GiMlaqFloat32( + result, kernel[7], GiLoadFloat32(sptr + 2 * IW * 4 + 4)); float* output = dst + oh * OW * 4 + ow_end * 4; op(result, output); } @@ -235,43 +241,47 @@ struct PaddingComputeK3P1 { if (OH != oh_end) { size_t oh = OH - 1; { - float32x4_t result; + GI_FLOAT32_t result; load_bias_vec::impl(&result, init, bias + oh * OW * 4); const float* sptr = src + (oh_end * stride - PH) * IW * 4; - result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr)); - result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4)); - result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4)); + result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr)); + result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 4)); + result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + IW * 4)); + result = GiMlaqFloat32( + result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 4)); float* output = dst + oh_end * OW * 4; op(result, output); } // last line mid for (size_t ow = ow_start; ow < ow_end; ow++) { int iw = ow * stride - PW; - float32x4_t result; + GI_FLOAT32_t result; load_bias_vec::impl( &result, init, bias + oh * OW * 4 + ow * 4); const float* sptr = src + (oh_end * stride - PH) * IW * 4 + iw * 4; - result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); - result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); - result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 8)); - result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); - result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 
8)); + result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); + result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); + result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 8)); + result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); + result = GiMlaqFloat32( + result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); + result = GiMlaqFloat32( + result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 8)); float* output = dst + oh_end * OW * 4 + ow * 4; op(result, output); } // last line right if (OW != ow_end) { - float32x4_t result; + GI_FLOAT32_t result; load_bias_vec::impl( &result, init, bias + oh * OW * 4 + (OW - 1) * 4); const float* sptr = src + (oh_end * stride - PH) * IW * 4 + (ow_end * stride - PW) * 4; - result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); - result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); - result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); + result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); + result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); + result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); + result = GiMlaqFloat32( + result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); float* output = dst + oh_end * OW * 4 + ow_end * 4; op(result, output); } @@ -286,12 +296,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( const float* src, const float* filter, const float* bias, float* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, const size_t PH, const size_t PW) { - float32x4_t kernel[4]; + GI_FLOAT32_t kernel[4]; load_vec<4>(kernel, filter); Op op; - float32x4_t init = vdupq_n_f32(0.f); + GI_FLOAT32_t init = GiZeroFloat32(); if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { - init = vld1q_f32(bias); + init = GiLoadFloat32(bias); } size_t oh_start = PH; size_t ow_start = PW; @@ -315,12 +325,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( size_t iw = ow - ow_start; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2][4]; + GI_FLOAT32_t dst_v[2][4]; load_bias_vec::impl( dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t src_v[3][5]; + GI_FLOAT32_t src_v[3][5]; load_vec<5>(src_v[0], input); COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); load_vec<5>(src_v[1], input + IW * 4); @@ -338,12 +348,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( size_t iw = ow - ow_start; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2]; + GI_FLOAT32_t dst_v[2]; load_bias_vec::impl( &dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t src_v[3][2]; + GI_FLOAT32_t src_v[3][2]; load_vec<2>(src_v[0], input); compute_vec<2>(dst_v[0], &src_v[0][0], &kernel[0]); load_vec<2>(src_v[1], input + IW * 4); @@ -363,10 +373,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( size_t iw = ow - ow_start; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[1][4]; + GI_FLOAT32_t dst_v[1][4]; load_bias_vec::impl( dst_v[0], init, bias + oh * OW * 4 + ow * 4); - float32x4_t src_v[2][5]; + GI_FLOAT32_t src_v[2][5]; load_vec<5>(src_v[0], input); COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); 
load_vec<5>(src_v[1], input + IW * 4); @@ -379,10 +389,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( size_t iw = ow - ow_start; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v; + GI_FLOAT32_t dst_v; load_bias_vec::impl( &dst_v, init, bias + oh * OW * 4 + ow * 4); - float32x4_t src_v[2][2]; + GI_FLOAT32_t src_v[2][2]; load_vec<2>(src_v[0], input); compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); load_vec<2>(src_v[1], input + IW * 4); @@ -405,12 +415,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( return; } - float32x4_t kernel[9]; + GI_FLOAT32_t kernel[9]; load_vec<9>(kernel, filter); Op op; - float32x4_t init = vdupq_n_f32(0.f); + GI_FLOAT32_t init = GiZeroFloat32(); if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { - init = vld1q_f32(bias); + init = GiLoadFloat32(bias); } size_t oh_start = PH; size_t ow_start = PW; @@ -428,12 +438,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( size_t iw = ow - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2][4]; + GI_FLOAT32_t dst_v[2][4]; load_bias_vec::impl( dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t src_v[2][6]; + GI_FLOAT32_t src_v[2][6]; load_vec<6>(src_v[0], input); compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); compute_vec<3>(dst_v[0][1], &src_v[0][1], &kernel[0]); @@ -472,12 +482,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( size_t iw = ow - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2]; + GI_FLOAT32_t dst_v[2]; load_bias_vec::impl( &dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t src_v[2][3]; + GI_FLOAT32_t src_v[2][3]; load_vec<3>(src_v[0], input); compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); load_vec<3>(src_v[1], input + IW * 4); @@ -500,10 +510,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( size_t iw = ow - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[4]; + GI_FLOAT32_t dst_v[4]; load_bias_vec::impl( &dst_v[0], init, bias + oh * OW * 4 + ow * 4); - float32x4_t src_v[2][6]; + GI_FLOAT32_t src_v[2][6]; load_vec<6>(src_v[0], input); compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); compute_vec<3>(dst_v[1], &src_v[0][1], &kernel[0]); @@ -526,10 +536,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( size_t iw = ow - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v; + GI_FLOAT32_t dst_v; load_bias_vec::impl( &dst_v, init, bias + oh * OW * 4 + ow * 4); - float32x4_t src_v[3][3]; + GI_FLOAT32_t src_v[3][3]; load_vec<3>(src_v[0], input); compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); load_vec<3>(src_v[1], input + IW * 4); @@ -553,9 +563,9 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( } Op op; - float32x4_t init = vdupq_n_f32(0.f); + GI_FLOAT32_t init = GiZeroFloat32(); if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { - init = vld1q_f32(bias); + init = GiLoadFloat32(bias); } size_t oh_start = PH; size_t ow_start = PW; @@ -564,7 +574,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( if (PH || PW) { PaddingCompute::compute( src, bias, dst, 5, 1, IH, IW, OH, OW, PH, PW, - 
reinterpret_cast(filter), init); + reinterpret_cast(filter), init); } size_t oh = oh_start; for (; oh + 1 < oh_end; oh += 2) { @@ -574,13 +584,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( size_t iw = ow - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2][2]; + GI_FLOAT32_t dst_v[2][2]; load_bias_vec::impl( dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t kernel[2][5]; - float32x4_t src_v[2][6]; + GI_FLOAT32_t kernel[2][5]; + GI_FLOAT32_t src_v[2][6]; #define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ load_vec<5>(kernel0, filter + i * 5 * 4); \ load_vec<6>(src, input + i * IW * 4); \ @@ -613,13 +623,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( size_t iw = ow - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2][1]; + GI_FLOAT32_t dst_v[2][1]; load_bias_vec::impl( dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t kernel[2][5]; - float32x4_t src_v[2][5]; + GI_FLOAT32_t kernel[2][5]; + GI_FLOAT32_t src_v[2][5]; #define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ load_vec<5>(kernel0, filter + i * 5 * 4); \ load_vec<6>(src, input + i * IW * 4); \ @@ -652,11 +662,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( size_t iw = ow - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[1][2]; + GI_FLOAT32_t dst_v[1][2]; load_bias_vec::impl( dst_v[0], init, bias + oh * OW * 4 + ow * 4); - float32x4_t kernel[2][5]; - float32x4_t src_v[2][6]; + GI_FLOAT32_t kernel[2][5]; + GI_FLOAT32_t src_v[2][6]; #define COMPUTE_5X5_2(i, dst, src, kernel) \ load_vec<5>(kernel, filter + i * 5 * 4); \ load_vec<6>(src, input + i * IW * 4); \ @@ -679,11 +689,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( size_t iw = ow - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v; + GI_FLOAT32_t dst_v; load_bias_vec::impl( &dst_v, init, bias + oh * OW * 4 + ow * 4); - float32x4_t kernel[2][5]; - float32x4_t src_v[2][5]; + GI_FLOAT32_t kernel[2][5]; + GI_FLOAT32_t src_v[2][5]; #define COMPUTE_5X5_1(i, dst, src, kernel) \ load_vec<5>(kernel, filter + i * 5 * 4); \ load_vec<6>(src, input + i * IW * 4); \ @@ -709,12 +719,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( const float* src, const float* filter, const float* bias, float* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, const size_t PH, const size_t PW) { - float32x4_t kernel[4]; + GI_FLOAT32_t kernel[4]; load_vec<4>(kernel, filter); Op op; - float32x4_t init = vdupq_n_f32(0.f); + GI_FLOAT32_t init = GiZeroFloat32(); if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { - init = vld1q_f32(bias); + init = GiLoadFloat32(bias); } size_t oh_start = (PH + 1) / 2; size_t ow_start = (PW + 1) / 2; @@ -737,10 +747,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( size_t iw = ow * 2 - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[4]; + GI_FLOAT32_t dst_v[4]; load_bias_vec::impl( &dst_v[0], init, bias + oh * OW * 4 + ow * 4); - float32x4_t src_v[2][8]; + GI_FLOAT32_t src_v[2][8]; load_vec<8>(src_v[0], input); COMPUTE_2X2(dst_v, src_v[0], &kernel[0]); 
load_vec<8>(src_v[1], input + IW * 4); @@ -753,10 +763,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( size_t iw = ow * 2 - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v; + GI_FLOAT32_t dst_v; load_bias_vec::impl( &dst_v, init, bias + oh * OW * 4 + ow * 4); - float32x4_t src_v[2][2]; + GI_FLOAT32_t src_v[2][2]; load_vec<2>(src_v[0], input); compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); load_vec<2>(src_v[1], input + IW * 4); @@ -773,12 +783,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( const float* src, const float* filter, const float* bias, float* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, const size_t PH, const size_t PW) { - float32x4_t kernel[9]; + GI_FLOAT32_t kernel[9]; load_vec<9>(kernel, filter); Op op; - float32x4_t init = vdupq_n_f32(0.f); + GI_FLOAT32_t init = GiZeroFloat32(); if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { - init = vld1q_f32(bias); + init = GiLoadFloat32(bias); } size_t oh_start = (PH + 1) / 2; size_t ow_start = (PW + 1) / 2; @@ -799,12 +809,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( size_t iw = ow * 2 - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2][2]; + GI_FLOAT32_t dst_v[2][2]; load_bias_vec::impl( dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t src_v[2][5]; + GI_FLOAT32_t src_v[2][5]; load_vec<5>(src_v[0], input); compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[0]); @@ -830,12 +840,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( size_t iw = ow * 2 - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2]; + GI_FLOAT32_t dst_v[2]; load_bias_vec::impl( &dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t src_v[2][3]; + GI_FLOAT32_t src_v[2][3]; load_vec<3>(src_v[0], input); compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); load_vec<3>(src_v[1], input + IW * 4); @@ -859,10 +869,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( size_t iw = ow * 2 - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2]; + GI_FLOAT32_t dst_v[2]; load_bias_vec::impl( &dst_v[0], init, bias + oh * OW * 4 + ow * 4); - float32x4_t src_v[3][5]; + GI_FLOAT32_t src_v[3][5]; load_vec<5>(src_v[0], input); compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); compute_vec<3>(dst_v[1], &src_v[0][2], &kernel[0]); @@ -878,10 +888,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( size_t iw = ow * 2 - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v; + GI_FLOAT32_t dst_v; load_bias_vec::impl( &dst_v, init, bias + oh * OW * 4 + ow * 4); - float32x4_t src_v[3][3]; + GI_FLOAT32_t src_v[3][3]; load_vec<3>(src_v[0], input); compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); load_vec<3>(src_v[1], input + IW * 4); @@ -899,9 +909,9 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( const size_t IH, const size_t IW, const size_t OH, const size_t OW, const size_t PH, const size_t PW) { Op op; - float32x4_t init = vdupq_n_f32(0.f); + GI_FLOAT32_t init = GiZeroFloat32(); if (bias_mode == 
BiasMode::BROADCAST_CHANNEL_BIAS) { - init = vld1q_f32(bias); + init = GiLoadFloat32(bias); } constexpr size_t stride = 2; size_t oh_start = (PH + stride - 1) / stride; @@ -911,7 +921,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( if (PH || PW) { PaddingCompute::compute( src, bias, dst, 5, stride, IH, IW, OH, OW, PH, PW, - reinterpret_cast(filter), init); + reinterpret_cast(filter), init); } size_t oh = oh_start; for (; oh + 1 < oh_end; oh += 2) { @@ -921,13 +931,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( size_t iw = ow * stride - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2][2]; + GI_FLOAT32_t dst_v[2][2]; load_bias_vec::impl( dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t kernel[3][5]; - float32x4_t src_v[2][7]; + GI_FLOAT32_t kernel[3][5]; + GI_FLOAT32_t src_v[2][7]; #define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ load_vec<5>(kernel0, filter + i * 5 * 4); \ load_vec<7>(src, input + i * IW * 4); \ @@ -965,13 +975,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( size_t iw = ow * stride - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v[2]; + GI_FLOAT32_t dst_v[2]; load_bias_vec::impl( &dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); - float32x4_t kernel[3][5]; - float32x4_t src_v[2][5]; + GI_FLOAT32_t kernel[3][5]; + GI_FLOAT32_t src_v[2][5]; #define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ load_vec<5>(kernel0, filter + i * 5 * 4); \ load_vec<5>(src, input + i * IW * 4); \ @@ -1010,11 +1020,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( size_t iw = ow * stride - PW; const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; - float32x4_t dst_v; + GI_FLOAT32_t dst_v; load_bias_vec::impl( &dst_v, init, bias + oh * OW * 4 + ow * 4); - float32x4_t kernel[2][5]; - float32x4_t src_v[2][5]; + GI_FLOAT32_t kernel[2][5]; + GI_FLOAT32_t src_v[2][5]; #define COMPUTE_5X5_1(i, dst, src, kernel) \ load_vec<5>(kernel, filter + i * 5 * 4); \ load_vec<6>(src, input + i * IW * 4); \ diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h similarity index 87% rename from dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h rename to dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h index 617241bd..e9075805 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h + * \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
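All the stride-1 and stride-2 kernels above address memory as base + (h * W + w) * 4 because the data is in NCHW44 layout: channels are grouped into packs of four and each spatial position stores its four channels contiguously, so one GiLoadFloat32 fetches the same (h, w) for the whole pack. A small scalar reference of that addressing and of the channel-wise (depthwise) 2x2 stride-1 product is sketched below; the helper names are hypothetical and bias, activation and padding are deliberately left out.

#include <cstddef>

// Hypothetical helper showing NCHW44 addressing as used by the kernels above:
// within one sample, element (c, h, w) lives at
//     ((c / 4) * H * W + h * W + w) * 4 + c % 4,
// so the four floats read by one GiLoadFloat32 at offset
// ((c / 4) * H * W + h * W + w) * 4 are channels c .. c + 3 of the same pixel.
static inline size_t nchw44_offset(size_t c, size_t h, size_t w, size_t H, size_t W) {
    return ((c / 4) * H * W + h * W + w) * 4 + c % 4;
}

// Scalar reference of the per-pack 2x2 stride-1 channel-wise product that
// do_conv_kern_stride1_2x2 vectorizes with GiMlaqFloat32. Valid region only:
// no padding, bias or activation; one channel pack of four.
static void depthwise_2x2_nchw44_ref(
        const float* src, const float* filter, float* dst, size_t IW, size_t OH,
        size_t OW) {
    for (size_t oh = 0; oh < OH; ++oh) {
        for (size_t ow = 0; ow < OW; ++ow) {
            for (size_t lane = 0; lane < 4; ++lane) {  // the four packed channels
                float acc = 0.f;
                for (size_t kh = 0; kh < 2; ++kh) {
                    for (size_t kw = 0; kw < 2; ++kw) {
                        // filter holds (kh * 2 + kw) packs of four floats, matching
                        // load_vec<4>(kernel, filter) in the code above
                        acc += src[((oh + kh) * IW + (ow + kw)) * 4 + lane] *
                               filter[(kh * 2 + kw) * 4 + lane];
                    }
                }
                dst[(oh * OW + ow) * 4 + lane] = acc;
            }
        }
    }
}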
@@ -12,11 +12,11 @@ #pragma once -#include "src/arm_common/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/opr_impl.h" namespace megdnn { -namespace arm_common { +namespace fallback { namespace channel_wise_nchw44_float { #define KERN(stride, i) \ @@ -37,7 +37,7 @@ KERN(stride2, 5) #undef KERN } // namespace channel_wise_nchw44_float -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct.cpp similarity index 60% rename from dnn/src/arm_common/conv_bias/fp32/direct.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct.cpp index c4300f1f..f5bf0862 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/direct.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/direct.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,18 +9,18 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/arm_common/conv_bias/fp32/direct.h" +#include "src/fallback/conv_bias/gi/fp32/direct.h" #include #include "include/megdnn/oprs.h" #include "midout.h" -#include "src/arm_common/conv_bias/postprocess_helper.h" -#include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" -MIDOUT_DECL(megdnn_arm_conv_f32) +#include "src/fallback/conv_bias/gi/postprocess_helper.h" +#include "src/fallback/general_intrinsic/gi_float.h" +MIDOUT_DECL(megdnn_gi_conv_f32) using namespace megdnn; -using namespace arm_common; +using namespace fallback; using namespace fp32; using namespace conv_bias; @@ -34,65 +34,65 @@ struct do_pixel_proxy { const int ow); }; -#define cb_load(i) data = vld1q_lane_f32(dst + i, data, i); -#define LOAD_OUT \ - if (width < 4) { \ - auto load_less_4 = [](float* dst, float32x4_t& data) { \ - if (width == 1u) { \ - UNROLL_CALL_NOWRAPPER(1, cb_load); \ - } else if (width == 2u) { \ - UNROLL_CALL_NOWRAPPER(2, cb_load); \ - } else if (width == 3u) { \ - UNROLL_CALL_NOWRAPPER(3, cb_load); \ - } \ - }; \ - if (height >= 1) \ - load_less_4(dst + 0 * OW, out0); \ - if (height >= 2) \ - load_less_4(dst + 1 * OW, out1); \ - if (height >= 3) \ - load_less_4(dst + 2 * OW, out2); \ - if (height >= 4) \ - load_less_4(dst + 3 * OW, out3); \ - } else { \ - if (height > 0) \ - out0 = vld1q_f32(dst + 0 * OW); \ - if (height > 1) \ - out1 = vld1q_f32(dst + 1 * OW); \ - if (height > 2) \ - out2 = vld1q_f32(dst + 2 * OW); \ - if (height > 3) \ - out3 = vld1q_f32(dst + 3 * OW); \ - } -#define cb_store(i) vst1q_lane_f32(dst + i, data, i); -#define STORE_OUT \ +#define cb_load(i) data = GiLd1qLaneFloat32(dst + i, data, i); +#define LOAD_OUT \ if (width < 4) { \ - auto store_less_4 = [](float* dst, float32x4_t& data) { \ + auto load_less_4 = [](float* dst, GI_FLOAT32_t& data) { \ if (width == 1u) { \ - UNROLL_CALL_NOWRAPPER(1, cb_store); \ + UNROLL_CALL_NOWRAPPER(1, cb_load); \ } else if (width == 2u) { \ - UNROLL_CALL_NOWRAPPER(2, cb_store); \ + UNROLL_CALL_NOWRAPPER(2, cb_load); \ } else if (width == 3u) { \ - UNROLL_CALL_NOWRAPPER(3, cb_store); \ + UNROLL_CALL_NOWRAPPER(3, cb_load); \ } \ }; \ if (height >= 1) \ - store_less_4(dst + 0 * OW, out0); \ + load_less_4(dst + 0 * OW, out0); \ if (height >= 2) \ - store_less_4(dst 
+ 1 * OW, out1); \ + load_less_4(dst + 1 * OW, out1); \ if (height >= 3) \ - store_less_4(dst + 2 * OW, out2); \ + load_less_4(dst + 2 * OW, out2); \ if (height >= 4) \ - store_less_4(dst + 3 * OW, out3); \ + load_less_4(dst + 3 * OW, out3); \ } else { \ - if (height >= 1) \ - vst1q_f32(dst + 0 * OW, out0); \ - if (height >= 2) \ - vst1q_f32(dst + 1 * OW, out1); \ - if (height >= 3) \ - vst1q_f32(dst + 2 * OW, out2); \ - if (height >= 4) \ - vst1q_f32(dst + 3 * OW, out3); \ + if (height > 0) \ + out0 = GiLoadFloat32(dst + 0 * OW); \ + if (height > 1) \ + out1 = GiLoadFloat32(dst + 1 * OW); \ + if (height > 2) \ + out2 = GiLoadFloat32(dst + 2 * OW); \ + if (height > 3) \ + out3 = GiLoadFloat32(dst + 3 * OW); \ + } +#define cb_store(i) GiStoreLane##i##Float32(dst + i, data); +#define STORE_OUT \ + if (width < 4) { \ + auto store_less_4 = [](float* dst, GI_FLOAT32_t& data) { \ + if (width == 1u) { \ + UNROLL_CALL_NOWRAPPER(1, cb_store); \ + } else if (width == 2u) { \ + UNROLL_CALL_NOWRAPPER(2, cb_store); \ + } else if (width == 3u) { \ + UNROLL_CALL_NOWRAPPER(3, cb_store); \ + } \ + }; \ + if (height >= 1) \ + store_less_4(dst + 0 * OW, out0); \ + if (height >= 2) \ + store_less_4(dst + 1 * OW, out1); \ + if (height >= 3) \ + store_less_4(dst + 2 * OW, out2); \ + if (height >= 4) \ + store_less_4(dst + 3 * OW, out3); \ + } else { \ + if (height >= 1) \ + GiStoreFloat32(dst + 0 * OW, out0); \ + if (height >= 2) \ + GiStoreFloat32(dst + 1 * OW, out1); \ + if (height >= 3) \ + GiStoreFloat32(dst + 2 * OW, out2); \ + if (height >= 4) \ + GiStoreFloat32(dst + 3 * OW, out3); \ } template @@ -104,33 +104,33 @@ struct do_pixel_proxy<1, height, width> { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp; + GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; for (int fw = 0; fw < FW; ++fw) { const float* src_dd = src + fw; - kr0 = vdupq_n_f32(filter[0 * FW + fw]); + kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); if (height > 0) - inp = vld1q_f32(src_dd + 0 * IW); + inp = GiLoadFloat32(src_dd + 0 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr0); + out0 = GiMlaqFloat32(out0, inp, kr0); if (height > 1) - inp = vld1q_f32(src_dd + 1 * IW); + inp = GiLoadFloat32(src_dd + 1 * IW); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr0); + out1 = GiMlaqFloat32(out1, inp, kr0); if (height > 2) - inp = vld1q_f32(src_dd + 2 * IW); + inp = GiLoadFloat32(src_dd + 2 * IW); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr0); + out2 = GiMlaqFloat32(out2, inp, kr0); if (height > 3) - inp = vld1q_f32(src_dd + 3 * IW); + inp = GiLoadFloat32(src_dd + 3 * IW); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr0); + out3 = GiMlaqFloat32(out3, inp, kr0); } STORE_OUT; } @@ -145,45 +145,45 @@ struct do_pixel_proxy<2, height, width> { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp; + GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; for (int fw = 0; fw < FW; ++fw) { const float* src_dd = src + fw; - kr0 = vdupq_n_f32(filter[0 * FW + fw]); - kr1 = vdupq_n_f32(filter[1 * FW + fw]); + kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); + kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); if (height > 0) - inp = vld1q_f32(src_dd + 0 * IW); + inp = GiLoadFloat32(src_dd + 0 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr0); + out0 = GiMlaqFloat32(out0, inp, kr0); if (height > 0) - inp = 
vld1q_f32(src_dd + 1 * IW); + inp = GiLoadFloat32(src_dd + 1 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr1); + out0 = GiMlaqFloat32(out0, inp, kr1); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr0); + out1 = GiMlaqFloat32(out1, inp, kr0); if (height > 1) - inp = vld1q_f32(src_dd + 2 * IW); + inp = GiLoadFloat32(src_dd + 2 * IW); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr1); + out1 = GiMlaqFloat32(out1, inp, kr1); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr0); + out2 = GiMlaqFloat32(out2, inp, kr0); if (height > 2) - inp = vld1q_f32(src_dd + 3 * IW); + inp = GiLoadFloat32(src_dd + 3 * IW); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr1); + out2 = GiMlaqFloat32(out2, inp, kr1); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr0); + out3 = GiMlaqFloat32(out3, inp, kr0); if (height > 3) - inp = vld1q_f32(src_dd + 4 * IW); + inp = GiLoadFloat32(src_dd + 4 * IW); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr1); + out3 = GiMlaqFloat32(out3, inp, kr1); } STORE_OUT; } @@ -198,57 +198,57 @@ struct do_pixel_proxy<3, height, width> { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp; + GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; for (int fw = 0; fw < FW; ++fw) { const float* src_dd = src + fw; - kr0 = vdupq_n_f32(filter[0 * FW + fw]); - kr1 = vdupq_n_f32(filter[1 * FW + fw]); - kr2 = vdupq_n_f32(filter[2 * FW + fw]); + kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); + kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); + kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); if (height > 0) - inp = vld1q_f32(src_dd + 0 * IW); + inp = GiLoadFloat32(src_dd + 0 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr0); + out0 = GiMlaqFloat32(out0, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 1 * IW); + inp = GiLoadFloat32(src_dd + 1 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr1); + out0 = GiMlaqFloat32(out0, inp, kr1); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr0); + out1 = GiMlaqFloat32(out1, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 2 * IW); + inp = GiLoadFloat32(src_dd + 2 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr2); + out0 = GiMlaqFloat32(out0, inp, kr2); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr1); + out1 = GiMlaqFloat32(out1, inp, kr1); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr0); + out2 = GiMlaqFloat32(out2, inp, kr0); if (height > 1) - inp = vld1q_f32(src_dd + 3 * IW); + inp = GiLoadFloat32(src_dd + 3 * IW); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr2); + out1 = GiMlaqFloat32(out1, inp, kr2); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr1); + out2 = GiMlaqFloat32(out2, inp, kr1); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr0); + out3 = GiMlaqFloat32(out3, inp, kr0); if (height > 2) - inp = vld1q_f32(src_dd + 4 * IW); + inp = GiLoadFloat32(src_dd + 4 * IW); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr2); + out2 = GiMlaqFloat32(out2, inp, kr2); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr1); + out3 = GiMlaqFloat32(out3, inp, kr1); if (height > 3) - inp = vld1q_f32(src_dd + 5 * IW); + inp = GiLoadFloat32(src_dd + 5 * IW); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr2); + out3 = GiMlaqFloat32(out3, inp, kr2); } STORE_OUT; } @@ -263,69 +263,69 @@ struct do_pixel_proxy<4, height, width> { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp; + GI_FLOAT32_t out0{0}, out1{0}, out2{0}, 
out3{0}, kr0, kr1, kr2, kr3, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; for (int fw = 0; fw < FW; ++fw) { const float* src_dd = src + fw; - kr0 = vdupq_n_f32(filter[0 * FW + fw]); - kr1 = vdupq_n_f32(filter[1 * FW + fw]); - kr2 = vdupq_n_f32(filter[2 * FW + fw]); - kr3 = vdupq_n_f32(filter[3 * FW + fw]); + kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); + kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); + kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); + kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); if (height > 0) - inp = vld1q_f32(src_dd + 0 * IW); + inp = GiLoadFloat32(src_dd + 0 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr0); + out0 = GiMlaqFloat32(out0, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 1 * IW); + inp = GiLoadFloat32(src_dd + 1 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr1); + out0 = GiMlaqFloat32(out0, inp, kr1); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr0); + out1 = GiMlaqFloat32(out1, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 2 * IW); + inp = GiLoadFloat32(src_dd + 2 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr2); + out0 = GiMlaqFloat32(out0, inp, kr2); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr1); + out1 = GiMlaqFloat32(out1, inp, kr1); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr0); + out2 = GiMlaqFloat32(out2, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 3 * IW); + inp = GiLoadFloat32(src_dd + 3 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr3); + out0 = GiMlaqFloat32(out0, inp, kr3); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr2); + out1 = GiMlaqFloat32(out1, inp, kr2); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr1); + out2 = GiMlaqFloat32(out2, inp, kr1); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr0); + out3 = GiMlaqFloat32(out3, inp, kr0); if (height > 1) - inp = vld1q_f32(src_dd + 4 * IW); + inp = GiLoadFloat32(src_dd + 4 * IW); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr3); + out1 = GiMlaqFloat32(out1, inp, kr3); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr2); + out2 = GiMlaqFloat32(out2, inp, kr2); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr1); + out3 = GiMlaqFloat32(out3, inp, kr1); if (height > 2) - inp = vld1q_f32(src_dd + 5 * IW); + inp = GiLoadFloat32(src_dd + 5 * IW); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr3); + out2 = GiMlaqFloat32(out2, inp, kr3); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr2); + out3 = GiMlaqFloat32(out3, inp, kr2); if (height > 3) - inp = vld1q_f32(src_dd + 6 * IW); + inp = GiLoadFloat32(src_dd + 6 * IW); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr3); + out3 = GiMlaqFloat32(out3, inp, kr3); } STORE_OUT; } @@ -340,81 +340,81 @@ struct do_pixel_proxy<5, height, width> { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp; + GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; for (int fw = 0; fw < FW; ++fw) { const float* src_dd = src + fw; - kr0 = vdupq_n_f32(filter[0 * FW + fw]); - kr1 = vdupq_n_f32(filter[1 * FW + fw]); - kr2 = vdupq_n_f32(filter[2 * FW + fw]); - kr3 = vdupq_n_f32(filter[3 * FW + fw]); - kr4 = vdupq_n_f32(filter[4 * FW + fw]); + kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); + kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); + kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); + kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); + kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); if (height > 0) - inp = vld1q_f32(src_dd + 0 * IW); + inp = 
GiLoadFloat32(src_dd + 0 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr0); + out0 = GiMlaqFloat32(out0, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 1 * IW); + inp = GiLoadFloat32(src_dd + 1 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr1); + out0 = GiMlaqFloat32(out0, inp, kr1); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr0); + out1 = GiMlaqFloat32(out1, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 2 * IW); + inp = GiLoadFloat32(src_dd + 2 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr2); + out0 = GiMlaqFloat32(out0, inp, kr2); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr1); + out1 = GiMlaqFloat32(out1, inp, kr1); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr0); + out2 = GiMlaqFloat32(out2, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 3 * IW); + inp = GiLoadFloat32(src_dd + 3 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr3); + out0 = GiMlaqFloat32(out0, inp, kr3); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr2); + out1 = GiMlaqFloat32(out1, inp, kr2); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr1); + out2 = GiMlaqFloat32(out2, inp, kr1); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr0); + out3 = GiMlaqFloat32(out3, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 4 * IW); + inp = GiLoadFloat32(src_dd + 4 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr4); + out0 = GiMlaqFloat32(out0, inp, kr4); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr3); + out1 = GiMlaqFloat32(out1, inp, kr3); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr2); + out2 = GiMlaqFloat32(out2, inp, kr2); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr1); + out3 = GiMlaqFloat32(out3, inp, kr1); if (height > 1) - inp = vld1q_f32(src_dd + 5 * IW); + inp = GiLoadFloat32(src_dd + 5 * IW); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr4); + out1 = GiMlaqFloat32(out1, inp, kr4); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr3); + out2 = GiMlaqFloat32(out2, inp, kr3); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr2); + out3 = GiMlaqFloat32(out3, inp, kr2); if (height > 2) - inp = vld1q_f32(src_dd + 6 * IW); + inp = GiLoadFloat32(src_dd + 6 * IW); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr4); + out2 = GiMlaqFloat32(out2, inp, kr4); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr3); + out3 = GiMlaqFloat32(out3, inp, kr3); if (height > 3) - inp = vld1q_f32(src_dd + 7 * IW); + inp = GiLoadFloat32(src_dd + 7 * IW); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr4); + out3 = GiMlaqFloat32(out3, inp, kr4); } STORE_OUT; } @@ -429,94 +429,94 @@ struct do_pixel_proxy<6, height, width> { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, + GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; for (int fw = 0; fw < FW; ++fw) { const float* src_dd = src + fw; - kr0 = vdupq_n_f32(filter[0 * FW + fw]); - kr1 = vdupq_n_f32(filter[1 * FW + fw]); - kr2 = vdupq_n_f32(filter[2 * FW + fw]); - kr3 = vdupq_n_f32(filter[3 * FW + fw]); - kr4 = vdupq_n_f32(filter[4 * FW + fw]); - kr5 = vdupq_n_f32(filter[5 * FW + fw]); + kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); + kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); + kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); + kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); + kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); + kr5 = GiBroadcastFloat32(filter[5 * FW + fw]); if (height > 0) - inp = vld1q_f32(src_dd + 0 * IW); + inp = GiLoadFloat32(src_dd 
+ 0 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr0); + out0 = GiMlaqFloat32(out0, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 1 * IW); + inp = GiLoadFloat32(src_dd + 1 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr1); + out0 = GiMlaqFloat32(out0, inp, kr1); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr0); + out1 = GiMlaqFloat32(out1, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 2 * IW); + inp = GiLoadFloat32(src_dd + 2 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr2); + out0 = GiMlaqFloat32(out0, inp, kr2); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr1); + out1 = GiMlaqFloat32(out1, inp, kr1); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr0); + out2 = GiMlaqFloat32(out2, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 3 * IW); + inp = GiLoadFloat32(src_dd + 3 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr3); + out0 = GiMlaqFloat32(out0, inp, kr3); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr2); + out1 = GiMlaqFloat32(out1, inp, kr2); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr1); + out2 = GiMlaqFloat32(out2, inp, kr1); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr0); + out3 = GiMlaqFloat32(out3, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 4 * IW); + inp = GiLoadFloat32(src_dd + 4 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr4); + out0 = GiMlaqFloat32(out0, inp, kr4); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr3); + out1 = GiMlaqFloat32(out1, inp, kr3); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr2); + out2 = GiMlaqFloat32(out2, inp, kr2); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr1); + out3 = GiMlaqFloat32(out3, inp, kr1); if (height > 0) - inp = vld1q_f32(src_dd + 5 * IW); + inp = GiLoadFloat32(src_dd + 5 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr5); + out0 = GiMlaqFloat32(out0, inp, kr5); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr4); + out1 = GiMlaqFloat32(out1, inp, kr4); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr3); + out2 = GiMlaqFloat32(out2, inp, kr3); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr2); + out3 = GiMlaqFloat32(out3, inp, kr2); if (height > 1) - inp = vld1q_f32(src_dd + 6 * IW); + inp = GiLoadFloat32(src_dd + 6 * IW); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr5); + out1 = GiMlaqFloat32(out1, inp, kr5); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr4); + out2 = GiMlaqFloat32(out2, inp, kr4); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr3); + out3 = GiMlaqFloat32(out3, inp, kr3); if (height > 2) - inp = vld1q_f32(src_dd + 7 * IW); + inp = GiLoadFloat32(src_dd + 7 * IW); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr5); + out2 = GiMlaqFloat32(out2, inp, kr5); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr4); + out3 = GiMlaqFloat32(out3, inp, kr4); if (height > 3) - inp = vld1q_f32(src_dd + 8 * IW); + inp = GiLoadFloat32(src_dd + 8 * IW); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr5); + out3 = GiMlaqFloat32(out3, inp, kr5); } STORE_OUT; } @@ -531,106 +531,106 @@ struct do_pixel_proxy<7, height, width> { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, + GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, kr6, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; for (int fw = 0; fw < FW; ++fw) { const float* src_dd = src + fw; - kr0 = vdupq_n_f32(filter[0 * FW + fw]); - kr1 = vdupq_n_f32(filter[1 * FW + fw]); - kr2 = vdupq_n_f32(filter[2 * FW + fw]); - kr3 = vdupq_n_f32(filter[3 * FW + fw]); - kr4 = 
vdupq_n_f32(filter[4 * FW + fw]); - kr5 = vdupq_n_f32(filter[5 * FW + fw]); - kr6 = vdupq_n_f32(filter[6 * FW + fw]); + kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); + kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); + kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); + kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); + kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); + kr5 = GiBroadcastFloat32(filter[5 * FW + fw]); + kr6 = GiBroadcastFloat32(filter[6 * FW + fw]); if (height > 0) - inp = vld1q_f32(src_dd + 0 * IW); + inp = GiLoadFloat32(src_dd + 0 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr0); + out0 = GiMlaqFloat32(out0, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 1 * IW); + inp = GiLoadFloat32(src_dd + 1 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr1); + out0 = GiMlaqFloat32(out0, inp, kr1); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr0); + out1 = GiMlaqFloat32(out1, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 2 * IW); + inp = GiLoadFloat32(src_dd + 2 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr2); + out0 = GiMlaqFloat32(out0, inp, kr2); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr1); + out1 = GiMlaqFloat32(out1, inp, kr1); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr0); + out2 = GiMlaqFloat32(out2, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 3 * IW); + inp = GiLoadFloat32(src_dd + 3 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr3); + out0 = GiMlaqFloat32(out0, inp, kr3); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr2); + out1 = GiMlaqFloat32(out1, inp, kr2); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr1); + out2 = GiMlaqFloat32(out2, inp, kr1); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr0); + out3 = GiMlaqFloat32(out3, inp, kr0); if (height > 0) - inp = vld1q_f32(src_dd + 4 * IW); + inp = GiLoadFloat32(src_dd + 4 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr4); + out0 = GiMlaqFloat32(out0, inp, kr4); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr3); + out1 = GiMlaqFloat32(out1, inp, kr3); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr2); + out2 = GiMlaqFloat32(out2, inp, kr2); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr1); + out3 = GiMlaqFloat32(out3, inp, kr1); if (height > 0) - inp = vld1q_f32(src_dd + 5 * IW); + inp = GiLoadFloat32(src_dd + 5 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr5); + out0 = GiMlaqFloat32(out0, inp, kr5); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr4); + out1 = GiMlaqFloat32(out1, inp, kr4); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr3); + out2 = GiMlaqFloat32(out2, inp, kr3); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr2); + out3 = GiMlaqFloat32(out3, inp, kr2); if (height > 0) - inp = vld1q_f32(src_dd + 6 * IW); + inp = GiLoadFloat32(src_dd + 6 * IW); if (height > 0) - out0 = vmlaq_f32(out0, inp, kr6); + out0 = GiMlaqFloat32(out0, inp, kr6); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr5); + out1 = GiMlaqFloat32(out1, inp, kr5); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr4); + out2 = GiMlaqFloat32(out2, inp, kr4); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr3); + out3 = GiMlaqFloat32(out3, inp, kr3); if (height > 1) - inp = vld1q_f32(src_dd + 7 * IW); + inp = GiLoadFloat32(src_dd + 7 * IW); if (height > 1) - out1 = vmlaq_f32(out1, inp, kr6); + out1 = GiMlaqFloat32(out1, inp, kr6); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr5); + out2 = GiMlaqFloat32(out2, inp, kr5); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr4); + out3 = GiMlaqFloat32(out3, inp, kr4); if (height > 2) - inp = vld1q_f32(src_dd + 8 * IW); + inp 
= GiLoadFloat32(src_dd + 8 * IW); if (height > 2) - out2 = vmlaq_f32(out2, inp, kr6); + out2 = GiMlaqFloat32(out2, inp, kr6); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr5); + out3 = GiMlaqFloat32(out3, inp, kr5); if (height > 3) - inp = vld1q_f32(src_dd + 9 * IW); + inp = GiLoadFloat32(src_dd + 9 * IW); if (height > 3) - out3 = vmlaq_f32(out3, inp, kr6); + out3 = GiMlaqFloat32(out3, inp, kr6); } STORE_OUT; } @@ -836,31 +836,31 @@ void conv_bias::kern_direct( } while (0) switch (FH) { case 1: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(0)) { GAO(1); } MIDOUT_END(); break; case 2: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(1)) { GAO(2); } MIDOUT_END(); break; case 3: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(2)) { GAO(3); } MIDOUT_END(); break; case 4: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(3)) { GAO(4); } MIDOUT_END(); break; case 5: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(4)) { GAO(5); } MIDOUT_END(); break; case 6: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(5)) { GAO(6); } MIDOUT_END(); break; case 7: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(6)) { GAO(7); } MIDOUT_END(); break; } @@ -872,31 +872,31 @@ void conv_bias::kern_direct( } while (0) switch (FH) { case 1: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(0)) { GAO(1); } MIDOUT_END(); break; case 2: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(1)) { GAO(2); } MIDOUT_END(); break; case 3: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(2)) { GAO(3); } MIDOUT_END(); break; case 4: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(3)) { GAO(4); } MIDOUT_END(); break; case 5: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(4)) { GAO(5); } MIDOUT_END(); break; case 6: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(5)) { GAO(6); } MIDOUT_END(); break; case 7: - MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } + MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(6)) { GAO(7); } MIDOUT_END(); break; } diff --git a/dnn/src/arm_common/conv_bias/fp32/direct.h b/dnn/src/fallback/conv_bias/gi/fp32/direct.h similarity index 87% rename from dnn/src/arm_common/conv_bias/fp32/direct.h rename to dnn/src/fallback/conv_bias/gi/fp32/direct.h index 57941be8..fa596390 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/direct.h + * \file dnn/src/fallback/conv_bias/gi/fp32/direct.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -13,7 +13,7 @@ #include namespace megdnn { -namespace arm_common { +namespace fallback { namespace fp32 { namespace conv_bias { @@ -23,7 +23,7 @@ void kern_direct( } // namespace conv_bias } // namespace fp32 -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp similarity index 58% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp index 16e30e2c..bd2a332c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,12 +11,12 @@ * implied. */ -#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" -#include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/opr_impl.h" +#include "src/fallback/general_intrinsic/gi_float.h" namespace megdnn { -namespace arm_common { +namespace fallback { namespace conv_bias { template <> void pack_src_fp32_nchw44<1>( @@ -51,23 +51,23 @@ static inline void odd_even_split_iw8_even( const int src_offset = src_idx * ic_step; const int even_offset = iw_idx / 2 * ic_step; const int odd_offset = (odd_start + iw_idx / 2) * ic_step; - float32x4_t temp[8]; - temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); - temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); - temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); - temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); - temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); - temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); - temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); - temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); - vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]); - vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]); - vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]); - vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]); - vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]); - vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]); - vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); - vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); + GI_FLOAT32_t temp[8]; + temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step); + temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step); + temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step); + temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step); + temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step); + temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step); + temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step); + temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step); + GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[0]); + GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[2]); + 
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[4]); + GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[6]); + GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[1]); + GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[3]); + GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[5]); + GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[7]); } static inline void odd_even_split_iw8_odd( @@ -77,23 +77,23 @@ static inline void odd_even_split_iw8_odd( const int src_offset = src_idx * ic_step; const int even_offset = (iw_idx + 1) / 2 * ic_step; const int odd_offset = (odd_start + iw_idx / 2) * ic_step; - float32x4_t temp[8]; - temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); - temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); - temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); - temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); - temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); - temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); - temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); - temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); - vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]); - vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]); - vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]); - vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]); - vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]); - vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]); - vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]); - vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]); + GI_FLOAT32_t temp[8]; + temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step); + temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step); + temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step); + temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step); + temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step); + temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step); + temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step); + temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step); + GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[0]); + GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[2]); + GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[4]); + GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[6]); + GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[1]); + GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[3]); + GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[5]); + GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[7]); } } // namespace @@ -104,7 +104,7 @@ void pack_src_fp32_nchw44<2>( const int pad_top, const int pad_bottom, const int ic, const int ic_stride) { constexpr int ic_step = 4; int odd_start = megdnn::div_ceil(iw2, 2); - float32x4_t zero_v = vdupq_n_f32(0.f); + GI_FLOAT32_t zero_v = GiZeroFloat32(); MEGDNN_MARK_USED_VAR(ph); bool even_start = pw % 2 == 0; rep_step(ic_idx, ic, ic_step) { @@ -115,9 +115,10 @@ void pack_src_fp32_nchw44<2>( int iw_idx = 0; rep(idx, pw) { if (iw_idx % 2 == 0) { - vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); + GiStoreFloat32(sptr_base + iw_idx / 2 * ic_step, zero_v); } else { - vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); + GiStoreFloat32( + sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); } ++iw_idx; } @@ -136,21 +137,22 @@ void pack_src_fp32_nchw44<2>( } for (; src_idx < iw; ++src_idx) { if (iw_idx % 2 == 0) { - vst1q_f32( + GiStoreFloat32( 
sptr_base + iw_idx / 2 * ic_step, - vld1q_f32(sptr + src_idx * ic_step)); + GiLoadFloat32(sptr + src_idx * ic_step)); } else { - vst1q_f32( + GiStoreFloat32( sptr_base + (odd_start + iw_idx / 2) * ic_step, - vld1q_f32(sptr + src_idx * ic_step)); + GiLoadFloat32(sptr + src_idx * ic_step)); } ++iw_idx; } rep(idx, pad_right) { if (iw_idx % 2 == 0) { - vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); + GiStoreFloat32(sptr_base + iw_idx / 2 * ic_step, zero_v); } else { - vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); + GiStoreFloat32( + sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); } ++iw_idx; } @@ -163,7 +165,7 @@ void pack_src_fp32_nchw44<2>( } } // namespace conv_bias -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp index 200769d6..a9663a73 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_BIAS(2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp index c6c974a5..a06afba5 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp index 6f075a54..61a9eb99 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_NO_BIAS(2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp index a9728847..e6e3107e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_BIAS(2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp index ae899e2c..d3f9e5fd 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp index 94c09aea..079ccf8b 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_NO_BIAS(2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp index 0047c51e..685bb919 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_BIAS(3); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp index c273dede..fc69ce2f 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(3); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp index 719dbd1d..979a479b 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_NO_BIAS(3); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp index 01209f9c..e904cee0 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_BIAS(3); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp index 7bed53e2..254698aa 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(3); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp index 9aa190df..ccfac7a6 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_NO_BIAS(3); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp index 5cbcb78a..c58b59d2 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_BIAS(5); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp index bcf92bab..ceafb966 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(5); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp index d944b02b..8ed94109 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_NO_BIAS(5); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp index a75f159a..482465a8 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_BIAS(5); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp index ff9653ea..16f64558 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(5); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp index a2705bde..02a5b6db 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_NO_BIAS(5); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp index 47cbf3d7..788fca50 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_BIAS(7); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp index f8fa2c29..cda9bc24 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(7); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp index c7824aad..27a20d0c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" INSTANTIATION_CONV_S1_NO_BIAS(7); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp index fd603f3b..30112c90 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_BIAS(7); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp index bd1e5c29..bc2b73cc 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(7); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp similarity index 70% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp index 0caee33c..9ae8eb3e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" INSTANTIATION_CONV_S2_NO_BIAS(7); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h similarity index 85% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h index 37b893a2..ee367b8c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -12,16 +12,15 @@ */ #include "megdnn/arch.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" -#include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" -#include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" +#include "src/fallback/conv_bias/gi/intrinsic_helper.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { template < @@ -39,13 +38,13 @@ struct ShiftCalHelper { }; #define cb2(step, lane, ow_block) \ - c[0][step] = vfmaq_laneq_f32( \ + c[0][step] = GiSimdFmaLane( \ c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ - c[1][step] = vfmaq_laneq_f32( \ + c[1][step] = GiSimdFmaLane( \ c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); -#define cb(step, lane, ow_block) \ - c[0][step] = vfmaq_laneq_f32( \ +#define cb(step, lane, ow_block) \ + c[0][step] = GiSimdFmaLane( \ c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); #define SHIFT_CAL_HELPER(ow_block, remain_w) \ @@ -122,7 +121,7 @@ public: template < BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, int ow_block> -struct KerNeonXXs1Nchw44FP32 { +struct KerGiXXs1Nchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -130,7 +129,7 @@ struct KerNeonXXs1Nchw44FP32 { }; template -struct KerNeonXXs1Nchw44FP32 { +struct KerGiXXs1Nchw44FP32 { static void impl( const float32_t* src_ptr_origin, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -147,20 +146,20 @@ struct KerNeonXXs1Nchw44FP32 { const int ld_src_ic = ih * iw; const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][ow_block]; + GI_FLOAT32_t c[c_dim][ow_block]; init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { - float32x4_t 
src[ow_block]; - float32x4_t weight[c_dim][ic_step]; - load_helper(src, src_ptr, 0); - load_helper( + GI_FLOAT32_t src[ow_block]; + GI_FLOAT32_t weight[c_dim][ic_step]; + load_helper(src, src_ptr, 0); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); - load_helper( + src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; @@ -172,7 +171,7 @@ struct KerNeonXXs1Nchw44FP32 { }; template -struct KerNeonXXs1Nchw44FP32 { +struct KerGiXXs1Nchw44FP32 { static void impl( const float32_t* src_ptr_origin, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -189,24 +188,24 @@ struct KerNeonXXs1Nchw44FP32 { const int ld_src_ic = ih * iw; const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][ow_block]; + GI_FLOAT32_t c[c_dim][ow_block]; init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { - float32x4_t src[ow_block]; - float32x4_t weight[c_dim][ic_step]; - load_helper(src, src_ptr, 0); - load_helper( + GI_FLOAT32_t src[ow_block]; + GI_FLOAT32_t weight[c_dim][ic_step]; + load_helper(src, src_ptr, 0); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); - load_helper( + src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); - load_helper( + src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; @@ -217,7 +216,7 @@ struct KerNeonXXs1Nchw44FP32 { } }; template -struct KerNeonXXs1Nchw44FP32 { +struct KerGiXXs1Nchw44FP32 { static void impl( const float32_t* src_ptr_origin, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -234,36 +233,36 @@ struct KerNeonXXs1Nchw44FP32 { const int ld_src_ic = ih * iw; const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][ow_block]; + GI_FLOAT32_t c[c_dim][ow_block]; init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { - float32x4_t src[ow_block]; - float32x4_t weight[c_dim][ic_step]; - load_helper(src, src_ptr, 0); - load_helper( + GI_FLOAT32_t src[ow_block]; + GI_FLOAT32_t weight[c_dim][ic_step]; + load_helper(src, src_ptr, 0); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); - load_helper( + src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); - load_helper( + src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); + load_helper( weight, weight_ptr, 
ld_weight_oc); cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); - load_helper( + src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); - load_helper( + src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; @@ -275,7 +274,7 @@ struct KerNeonXXs1Nchw44FP32 { }; template -struct KerNeonXXs1Nchw44FP32 { +struct KerGiXXs1Nchw44FP32 { static void impl( const float32_t* src_ptr_origin, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -292,46 +291,46 @@ struct KerNeonXXs1Nchw44FP32 { const int ld_src_ic = ih * iw; const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][ow_block]; + GI_FLOAT32_t c[c_dim][ow_block]; init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { - float32x4_t src[ow_block]; - float32x4_t weight[c_dim][ic_step]; - load_helper(src, src_ptr, 0); - load_helper( + GI_FLOAT32_t src[ow_block]; + GI_FLOAT32_t weight[c_dim][ic_step]; + load_helper(src, src_ptr, 0); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); - load_helper( + src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); - load_helper( + src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); - load_helper( + src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); - load_helper( + src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); - load_helper( + src[4] = GiLoadFloat32(src_ptr + (ow_block + 4) * ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<5, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); - load_helper( + src[5] = GiLoadFloat32(src_ptr + (ow_block + 5) * ic_step); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<6, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; @@ -352,10 +351,10 @@ void conv_bias::conv_direct_fp32_nchw44( constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; -#if MEGDNN_ARMV7 - constexpr int big_oc_step = 4; -#else +#if MEGDNN_AARCH64 constexpr int big_oc_step = 8; +#else + constexpr int big_oc_step = 4; #endif constexpr int oc_step = 4; constexpr int ih_step = 1; @@ -381,9 +380,9 @@ void conv_bias::conv_direct_fp32_nchw44( switch (ow_remain) 
{ #define cb(step) \ case step: \ - kern_big_oc_remain = KerNeonXXs1Nchw44FP32< \ + kern_big_oc_remain = KerGiXXs1Nchw44FP32< \ bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ - kern_small_oc_remain = KerNeonXXs1Nchw44FP32< \ + kern_small_oc_remain = KerGiXXs1Nchw44FP32< \ bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ break; @@ -402,7 +401,7 @@ void conv_bias::conv_direct_fp32_nchw44( oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; - KerNeonXXs1Nchw44FP32< + KerGiXXs1Nchw44FP32< bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: impl(src + src_offset, filter + weight_offset, bias + bias_offset, dst + dst_offset, ic, ih, iw, @@ -434,7 +433,7 @@ void conv_bias::conv_direct_fp32_nchw44( oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; - KerNeonXXs1Nchw44FP32< + KerGiXXs1Nchw44FP32< bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: impl(src + src_offset, filter + weight_offset, bias + bias_offset, dst + dst_offset, ic, ih, iw, diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h similarity index 84% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h index e52654a0..f6af99b1 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
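[Editor's note, not part of the patch.] Besides the KerNeonXXs1 -> KerGiXXs1 rename, conv_direct_fp32_nchw44 for stride 1 (and, further down, stride 2) flips its block-size guard: it now keys on MEGDNN_AARCH64 rather than MEGDNN_ARMV7, so every non-AArch64 build of the generic-intrinsic code, not just ARMv7, takes the conservative 4-wide output-channel block. The register-count rationale in the comments below is an inference, not something stated in the diff:

/* Block-size selection as it now reads in both the s1 and s2 common headers. */
#if MEGDNN_AARCH64
    constexpr int big_oc_step = 8;  /* 32 SIMD registers: two oc groups in flight */
#else
    constexpr int big_oc_step = 4;  /* ARMv7 and generic fallback targets: one oc group */
#endif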
@@ -12,16 +12,15 @@ */ #include "megdnn/arch.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" -#include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" -#include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" +#include "src/fallback/conv_bias/gi/intrinsic_helper.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { template < @@ -39,13 +38,13 @@ struct ShiftCalHelper { }; #define cb2(step, lane, ow_block) \ - c[0][step] = vfmaq_laneq_f32( \ + c[0][step] = GiSimdFmaLane( \ c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ - c[1][step] = vfmaq_laneq_f32( \ + c[1][step] = GiSimdFmaLane( \ c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); -#define cb(step, lane, ow_block) \ - c[0][step] = vfmaq_laneq_f32( \ +#define cb(step, lane, ow_block) \ + c[0][step] = GiSimdFmaLane( \ c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); #define SHIFT_CAL_HELPER(ow_block, remain_w) \ @@ -122,7 +121,7 @@ public: template < BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, int ow_block> -struct KerNeonXXs2Nchw44FP32 { +struct KerGiXXs2Nchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -130,7 +129,7 @@ struct KerNeonXXs2Nchw44FP32 { }; template -struct KerNeonXXs2Nchw44FP32 { +struct KerGiXXs2Nchw44FP32 { static void impl( const float32_t* src_ptr_origin, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -147,36 +146,36 @@ struct KerNeonXXs2Nchw44FP32 { const int ld_src_ic = ih * iw; const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][ow_block]; + GI_FLOAT32_t c[c_dim][ow_block]; init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; - float32x4_t src[ow_block]; - float32x4_t weight[c_dim][4]; + GI_FLOAT32_t src[ow_block]; + GI_FLOAT32_t weight[c_dim][4]; /////////row 0///////////// - load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, 0); - load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr_odd, 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; /////////row 1///////////// - load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, 0); - load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr_odd, 0); + load_helper<4, 1 * 
ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; @@ -188,7 +187,7 @@ struct KerNeonXXs2Nchw44FP32 { }; template -struct KerNeonXXs2Nchw44FP32 { +struct KerGiXXs2Nchw44FP32 { static void impl( const float32_t* src_ptr_origin, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -205,62 +204,62 @@ struct KerNeonXXs2Nchw44FP32 { const int ld_src_ic = ih * iw; const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][ow_block]; + GI_FLOAT32_t c[c_dim][ow_block]; init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; - float32x4_t src[ow_block]; - float32x4_t weight[c_dim][4]; + GI_FLOAT32_t src[ow_block]; + GI_FLOAT32_t weight[c_dim][4]; /////////row 0///////////// - load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + ow_block * simd_len); - load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, 0); - load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr_odd, 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; /////////row 1///////////// - load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + ow_block * simd_len); - load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, 0); - load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr_odd, 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; //////////row 2///////////// - load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + ow_block * simd_len); + src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); - load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, 
src, weight); - load_helper(src, src_ptr_odd, 0); - load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr_odd, 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; @@ -272,7 +271,7 @@ struct KerNeonXXs2Nchw44FP32 { }; template -struct KerNeonXXs2Nchw44FP32 { +struct KerGiXXs2Nchw44FP32 { static void impl( const float32_t* src_ptr_origin, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -289,7 +288,7 @@ struct KerNeonXXs2Nchw44FP32 { const int ld_src_ic = ih * iw; const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][ow_block]; + GI_FLOAT32_t c[c_dim][ow_block]; init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { @@ -297,28 +296,28 @@ struct KerNeonXXs2Nchw44FP32 { const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { - float32x4_t src[ow_block]; - float32x4_t weight[c_dim][4]; + GI_FLOAT32_t src[ow_block]; + GI_FLOAT32_t weight[c_dim][4]; // even element - load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + ow_block * simd_len); - load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); - load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len); + load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); // odd element - load_helper(src, src_ptr_odd, 0); - load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr_odd, 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); - load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len); + load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -337,7 +336,7 @@ struct KerNeonXXs2Nchw44FP32 { * src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9 **/ template -struct KerNeonXXs2Nchw44FP32 { +struct KerGiXXs2Nchw44FP32 { static void impl( const float32_t* src_ptr_origin, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -354,7 +353,7 @@ struct KerNeonXXs2Nchw44FP32 { const int ld_src_ic = ih * iw; const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][ow_block]; + GI_FLOAT32_t c[c_dim][ow_block]; init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { @@ -362,36 +361,36 @@ struct KerNeonXXs2Nchw44FP32 { const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * 
ld_src_ic; for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { - float32x4_t src[ow_block]; - float32x4_t weight[c_dim][4]; + GI_FLOAT32_t src[ow_block]; + GI_FLOAT32_t weight[c_dim][4]; // even element - load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr + ow_block * simd_len); - load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); + load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); - load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len); + load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); - load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * simd_len); + load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); // odd element - load_helper(src, src_ptr_odd, 0); - load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( + load_helper(src, src_ptr_odd, 0); + load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); - load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len); + load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); - load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( + src[1] = GiLoadFloat32(src_ptr_odd + (ow_block + 1) * simd_len); + load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>( weight, weight_ptr, ld_weight_oc); cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -414,10 +413,10 @@ void conv_bias::conv_direct_fp32_nchw44( constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; -#if MEGDNN_ARMV7 - constexpr int big_oc_step = 4; -#else +#if MEGDNN_AARCH64 constexpr int big_oc_step = 8; +#else + constexpr int big_oc_step = 4; #endif constexpr int oc_step = 4; constexpr int ih_step = 1; @@ -444,9 +443,9 @@ void conv_bias::conv_direct_fp32_nchw44( switch (ow_remain) { #define cb(step) \ case step: \ - kern_big_oc_remain = KerNeonXXs2Nchw44FP32< \ + kern_big_oc_remain = KerGiXXs2Nchw44FP32< \ bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ - kern_small_oc_remain = KerNeonXXs2Nchw44FP32< \ + kern_small_oc_remain = KerGiXXs2Nchw44FP32< \ bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ break; @@ -469,7 +468,7 @@ void conv_bias::conv_direct_fp32_nchw44( oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? 
dst_offset : oc_idx; - KerNeonXXs2Nchw44FP32< + KerGiXXs2Nchw44FP32< bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: impl(src + src_offset, filter + weight_offset, bias + bias_offset, dst + dst_offset, ic, ih, iw, @@ -510,7 +509,7 @@ void conv_bias::conv_direct_fp32_nchw44( oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; - KerNeonXXs2Nchw44FP32< + KerGiXXs2Nchw44FP32< bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: impl(src + src_offset, filter + weight_offset, bias + bias_offset, dst + dst_offset, ic, ih, iw, diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp index e82c464e..5e0ca627 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BIAS(2, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp index fbdb5ec7..f9e46534 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
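[Editor's note, not part of the patch.] The kernel-body hunks above are a mechanical intrinsic translation: float32x4_t becomes GI_FLOAT32_t, vld1q_f32 becomes GiLoadFloat32, vfmaq_laneq_f32 becomes GiSimdFmaLane, and the load-helper tag Vld1q_f32 becomes Vld1qF32S; the loop structure, lane order, and pointer arithmetic are untouched. A minimal sketch of the lane-broadcast FMA pattern in its GI spelling, assuming only the three GI names that appear in the diff (the helper function itself is illustrative, not from the patch):

/* One NCHW44 weight group (4 vectors) applied to one source vector,
 * accumulated into a single output register. */
static inline void gi_fma_group_sketch(
        GI_FLOAT32_t& acc, const float* src, const float* weight) {
    GI_FLOAT32_t s = GiLoadFloat32(src);
    GI_FLOAT32_t w0 = GiLoadFloat32(weight + 0);
    GI_FLOAT32_t w1 = GiLoadFloat32(weight + 4);
    GI_FLOAT32_t w2 = GiLoadFloat32(weight + 8);
    GI_FLOAT32_t w3 = GiLoadFloat32(weight + 12);
    /* acc[j] += w_lane[j] * s[lane]: the GI form of vfmaq_laneq_f32,
     * matching c[0][step] = GiSimdFmaLane(c[0][step], weight[0][lane], src[...], lane). */
    acc = GiSimdFmaLane(acc, w0, s, 0);
    acc = GiSimdFmaLane(acc, w1, s, 1);
    acc = GiSimdFmaLane(acc, w2, s, 2);
    acc = GiSimdFmaLane(acc, w3, s, 3);
}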
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp index 97f2595c..1d8e676e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_NO_BIAS(2, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp index 3acba768..7b899fb7 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BIAS(2, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp index 1e4c7197..0e031946 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp index 03bc548f..6558f35f 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_NO_BIAS(2, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp index 89bc21d5..ecb576ea 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BIAS(3, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp index fe811030..d45ec28e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp index 88cbe5f8..5719c179 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_NO_BIAS(3, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp index f2f02815..9f155db7 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BIAS(3, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp index 416e9839..8d7eca4d 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp index bf3f792d..aaae519c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_NO_BIAS(3, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp index f4d38c4e..43f9fdda 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BIAS(5, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp index 1dcb60b0..0581c1d4 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp index e32fbccb..fe8720a5 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_NO_BIAS(5, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp index a818401f..711be538 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BIAS(5, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp index be387827..f0401f2d 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp index 64c9db59..0d10f489 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_NO_BIAS(5, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp index 6fb2e117..74e4fe1f 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BIAS(7, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp index 74ad5102..53bb7085 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp index 94af0cbd..f7eacb6c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_NO_BIAS(7, 1); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp index 576213bc..5d0977a6 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BIAS(7, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp similarity index 68% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp index 58890e90..bf2b1500 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp index 4a4e0f35..ed8c9f82 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp @@ -1,6 +1,6 @@ /** * \file - * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp + * dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -10,6 +10,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" +#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" INSTANCE_CONV_NO_BIAS(7, 2); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h similarity index 92% rename from dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h rename to dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h index 07372da1..ade5e44c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h + * \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,20 +11,19 @@ */ #pragma once #include "megdnn/arch.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" -#include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" -#include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h" +#include "src/fallback/conv_bias/gi/intrinsic_helper.h" +#include "src/fallback/conv_bias/opr_impl.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" #if MEGDNN_ARMV7 #include "src/armv7/matrix_mul/asm/common.h" #endif using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { /** @@ -50,15 +49,15 @@ struct ShiftCalHelper { }; #define cb(step) \ - c[0][step] = vfmaq_laneq_f32( \ + c[0][step] = GiSimdFmaLane( \ c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ (step * stride + src_idx) % 4); \ - c[1][step] = vfmaq_laneq_f32( \ + c[1][step] = GiSimdFmaLane( \ c[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 4], \ (step * stride + src_idx) % 4); #define cb2(step) \ - c[0][step] = vfmaq_laneq_f32( \ + c[0][step] = GiSimdFmaLane( \ c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ (step * stride + src_idx) % 4); @@ -127,7 +126,7 @@ public: template < BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, int stride, int ow_block, int tag = CpuTag::DEFAULT_CPU_TAG> -struct KerNeonXXs2NchwNchw44FP32 { +struct KerGiXXs2NchwNchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -136,8 +135,7 @@ struct KerNeonXXs2NchwNchw44FP32 { template < BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, int ow_block> -struct KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, remain_w, 7, oc_block, stride, ow_block> { +struct KerGiXXs2NchwNchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -154,16 +152,16 @@ struct KerNeonXXs2NchwNchw44FP32< const int ld_weight_ic = oc_step * filter_size * 
filter_size; const int ld_src_ic = ih * iw; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][8]; + GI_FLOAT32_t c[c_dim][8]; init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - float32x4_t src[src_reg_size]; - float32x4_t weight[c_dim][filter_size]; + GI_FLOAT32_t src[src_reg_size]; + GI_FLOAT32_t weight[c_dim][filter_size]; #define KERNEL_CB(step) \ - load_helper(src, src_ptr + step * iw, 0); \ - load_helper( \ + load_helper(src, src_ptr + step * iw, 0); \ + load_helper( \ weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ @@ -186,8 +184,7 @@ struct KerNeonXXs2NchwNchw44FP32< template < BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, int ow_block> -struct KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, remain_w, 5, oc_block, stride, ow_block> { +struct KerGiXXs2NchwNchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -204,16 +201,16 @@ struct KerNeonXXs2NchwNchw44FP32< const int ld_weight_ic = oc_step * filter_size * filter_size; const int ld_src_ic = ih * iw; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][8]; + GI_FLOAT32_t c[c_dim][8]; init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - float32x4_t src[src_reg_size]; - float32x4_t weight[c_dim][filter_size]; + GI_FLOAT32_t src[src_reg_size]; + GI_FLOAT32_t weight[c_dim][filter_size]; #define KERNEL_CB(step) \ - load_helper(src, src_ptr + step * iw, 0); \ - load_helper( \ + load_helper(src, src_ptr + step * iw, 0); \ + load_helper( \ weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ @@ -233,8 +230,7 @@ struct KerNeonXXs2NchwNchw44FP32< template < BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, int ow_block> -struct KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, remain_w, 3, oc_block, stride, ow_block> { +struct KerGiXXs2NchwNchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -251,32 +247,32 @@ struct KerNeonXXs2NchwNchw44FP32< const int ld_weight_ic = oc_step * filter_size * filter_size; const int ld_src_ic = ih * iw; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][8]; + GI_FLOAT32_t c[c_dim][8]; init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - float32x4_t src[src_reg_size]; - float32x4_t weight[c_dim][filter_size]; + GI_FLOAT32_t src[src_reg_size]; + GI_FLOAT32_t weight[c_dim][filter_size]; // row 0 - load_helper(src, src_ptr, 0); - load_helper( + load_helper(src, src_ptr, 0); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); // row 1 - load_helper(src, src_ptr + iw, 0); - load_helper( + load_helper(src, src_ptr + iw, 0); + load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); // row 2 - 
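Aside: the KerGiXXs2NchwNchw44FP32 specializations above all follow the same shape: an accumulator tile c[c_dim][8] of GI vectors stays live across the whole input-channel loop, and for each filter row a strip of source vectors plus the per-row weights are loaded once and combined with lane-indexed FMAs. A scalar sketch of that register-tile idea is below; the layout constants are simplified assumptions, not the real NCHW44 packing.

// Illustrative scalar sketch of the accumulator-tile scheme (not the real kernel).
void ker_tile_sketch(const float* src, const float* weight, float* dst,
                     int ic, int ih, int iw, int filter_size, int stride) {
    constexpr int oc_block = 4;        // output channels kept in the tile (assumed)
    constexpr int ow_block = 8;        // output columns computed per call (assumed)
    float c[oc_block][ow_block] = {};  // the "register" accumulators
    for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
        for (int fh = 0; fh < filter_size; ++fh) {
            for (int fw = 0; fw < filter_size; ++fw) {
                for (int oc = 0; oc < oc_block; ++oc) {
                    float w = weight[((ic_idx * filter_size + fh) * filter_size + fw) *
                                             oc_block + oc];
                    for (int ow = 0; ow < ow_block; ++ow)
                        c[oc][ow] += w * src[(ic_idx * ih + fh) * iw + ow * stride + fw];
                }
            }
        }
    }
    for (int oc = 0; oc < oc_block; ++oc)
        for (int ow = 0; ow < ow_block; ++ow)
            dst[oc * ow_block + ow] = c[oc][ow];
}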
load_helper( + load_helper( src, src_ptr + 2 * iw, 0); - load_helper( + load_helper( weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); @@ -292,7 +288,7 @@ struct KerNeonXXs2NchwNchw44FP32< #if MEGDNN_ARMV7 template -struct KerNeonXXs2NchwNchw44FP32 { +struct KerGiXXs2NchwNchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -310,7 +306,7 @@ struct KerNeonXXs2NchwNchw44FP32 { const int ld_src_ic_skip_bytes = iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; constexpr int c_dim = OCHelper::val; - float32x4_t c[1][8]; + GI_FLOAT32_t c[1][8]; init_ocx_ow8(c, bias_ptr, oc_step); const int img_stride = ih * iw; constexpr int filter_stride = filter_size * filter_size * oc_step; @@ -464,8 +460,7 @@ struct KerNeonXXs2NchwNchw44FP32 { }; template -struct KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU_TAG> { +struct KerGiXXs2NchwNchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -483,7 +478,7 @@ struct KerNeonXXs2NchwNchw44FP32< const int ld_src_ic_skip_bytes = iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; constexpr int c_dim = OCHelper::val; - float32x4_t c[1][8]; + GI_FLOAT32_t c[1][8]; init_ocx_ow8(c, bias_ptr, oc_step); /** * c q8-q15 @@ -626,8 +621,7 @@ struct KerNeonXXs2NchwNchw44FP32< template < BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, int ow_block> -struct KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, remain_w, 2, oc_block, stride, ow_block> { +struct KerGiXXs2NchwNchw44FP32 { static void impl( const float32_t* src_ptr, const float32_t* weight_ptr, const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, @@ -644,22 +638,22 @@ struct KerNeonXXs2NchwNchw44FP32< const int ld_weight_ic = oc_step * filter_size * filter_size; const int ld_src_ic = ih * iw; constexpr int c_dim = OCHelper::val; - float32x4_t c[c_dim][8]; + GI_FLOAT32_t c[c_dim][8]; init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - float32x4_t src[src_reg_size]; - float32x4_t weight[c_dim][filter_size]; + GI_FLOAT32_t src[src_reg_size]; + GI_FLOAT32_t weight[c_dim][filter_size]; // row 0 - load_helper(src, src_ptr, 0); - load_helper( + load_helper(src, src_ptr, 0); + load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); // row 1 - load_helper(src, src_ptr + iw, 0); - load_helper( + load_helper(src, src_ptr + iw, 0); + load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); @@ -711,9 +705,9 @@ struct ConvDirectFp32NchwNchw44 { switch (ow_remain) { #define cb(step) \ case step: \ - kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \ + kern_big_oc_remain = KerGiXXs2NchwNchw44FP32< \ bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ - kern_small_oc_remain = KerNeonXXs2NchwNchw44FP32< \ + kern_small_oc_remain = KerGiXXs2NchwNchw44FP32< \ bias_mode, Op, step, filter_size, oc_step, stride, ow_step>::impl; \ break; @@ -731,7 +725,7 @@ struct ConvDirectFp32NchwNchw44 { ic_step * pack_iw_len; const int dst_offset = 
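Aside: the switch over ow_remain in ConvDirectFp32NchwNchw44 picks, at runtime, a kernel instantiation whose remain_w template parameter equals the leftover output-column count, so the hot loop never carries per-column bounds checks. A small self-contained sketch of that dispatch idea, with illustrative names only:

// Remainder dispatch via a function pointer to a compile-time-sized kernel.
template <int remain_w>
void ker_remain_sketch(float* dst) {
    for (int i = 0; i < remain_w; ++i)  // trip count known at compile time
        dst[i] += 1.f;
}

using KernFn = void (*)(float*);

KernFn select_remain_kernel(int ow_remain) {
    switch (ow_remain) {
#define cb(step) \
    case step:   \
        return ker_remain_sketch<step>;
        cb(1) cb(2) cb(3) cb(4) cb(5) cb(6) cb(7)
#undef cb
        default:
            return nullptr;  // ow_remain == 0: no tail kernel needed
    }
}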
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44FP32< + KerGiXXs2NchwNchw44FP32< bias_mode, Op, ow_step, filter_size, big_oc_step, stride, ow_step>:: impl(src + src_offset, filter + weight_offset, @@ -760,7 +754,7 @@ struct ConvDirectFp32NchwNchw44 { ic_step * pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44FP32< + KerGiXXs2NchwNchw44FP32< bias_mode, Op, ow_step, filter_size, oc_step, stride, ow_step>:: impl(src + src_offset, filter + weight_offset, @@ -819,7 +813,7 @@ struct ConvDirectFp32NchwNchw44 { switch (ow_remain) { #define cb(step) \ case step: \ - kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \ + kern_big_oc_remain = KerGiXXs2NchwNchw44FP32< \ bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ break; @@ -849,7 +843,7 @@ struct ConvDirectFp32NchwNchw44 { ic_step * pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44FP32< + KerGiXXs2NchwNchw44FP32< bias_mode, Op, ow_step, filter_size, big_oc_step, stride, ow_step, CpuTag::A7_TAG>:: impl(src + src_offset, filter + weight_offset, @@ -878,7 +872,7 @@ struct ConvDirectFp32NchwNchw44 { ic_step * pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44FP32< + KerGiXXs2NchwNchw44FP32< bias_mode, Op, ow_step, filter_size, big_oc_step, stride, ow_step>:: impl(src + src_offset, filter + weight_offset, diff --git a/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp b/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp new file mode 100644 index 00000000..03ee501e --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp @@ -0,0 +1,723 @@ +/** + * \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include + +#include "src/fallback/conv_bias/gi/fp32/do_conv_stride1.h" +#include "src/fallback/conv_bias/gi/postprocess_helper.h" +#include "src/fallback/conv_bias/opr_impl.h" +#include "src/fallback/general_intrinsic/gi_float.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_fallback_conv_bias_f32_convs1) + +using namespace megdnn; +using namespace fallback; +using namespace fp32; +using namespace conv_stride1; + +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; + +void conv_stride1::do_conv_2x2_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { + const size_t tail_step = IW - OW; + //! 
unroll of 2 + size_t ic = 0; + for (; ic + 1 < IC; ic += 2) { + const float* src_ptr = src + IW * IH * ic; + const float* src_ptr1 = src_ptr + IW * IH; + float* outptr = dst; + + const float* r00 = src_ptr; + const float* r01 = src_ptr + IW; + const float* r10 = src_ptr1; + const float* r11 = src_ptr1 + IW; + + const float* k0 = filter + ic * 4; + const float* k1 = k0 + 4; + + GI_FLOAT32_t _k0 = GiLoadFloat32(k0); + GI_FLOAT32_t _k1 = GiLoadFloat32(k1); + rep(h, OH) { + int width = OW >> 2; + + rep(i, width) { + GI_FLOAT32_t _r000 = GiLoadFloat32(r00); + GI_FLOAT32_t _r010 = GiLoadFloat32(r01); + GI_FLOAT32_t _r001 = GiLoadFloat32(r00 + 1); + GI_FLOAT32_t _r011 = GiLoadFloat32(r01 + 1); + + GI_FLOAT32_t _r100 = GiLoadFloat32(r10); + GI_FLOAT32_t _r110 = GiLoadFloat32(r11); + GI_FLOAT32_t _r101 = GiLoadFloat32(r10 + 1); + GI_FLOAT32_t _r111 = GiLoadFloat32(r11 + 1); + + GI_FLOAT32_t _sum = GiLoadFloat32(outptr); + + _sum = GiVmlaqLaneFloat32LowHalf(_sum, _r000, _k0, 0); + _sum = GiVmlaqLaneFloat32LowHalf(_sum, _r001, _k0, 1); + _sum = GiMlaqLaneFloat32HighHalf(_sum, _r010, _k0, 0); + _sum = GiMlaqLaneFloat32HighHalf(_sum, _r011, _k0, 1); + + _sum = GiVmlaqLaneFloat32LowHalf(_sum, _r100, _k1, 0); + _sum = GiVmlaqLaneFloat32LowHalf(_sum, _r101, _k1, 1); + _sum = GiMlaqLaneFloat32HighHalf(_sum, _r110, _k1, 0); + _sum = GiMlaqLaneFloat32HighHalf(_sum, _r111, _k1, 1); + + GiStoreFloat32(outptr, _sum); + + r00 += 4; + r01 += 4; + r10 += 4; + r11 += 4; + outptr += 4; + } + + r00 += tail_step; + r01 += tail_step; + r10 += tail_step; + r11 += tail_step; + } + } + for (; ic < IC; ic++) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + + const float* k0 = filter + ic * 4; + + GI_FLOAT32_t _k0 = GiBroadcastFloat32(k0[0]); + GI_FLOAT32_t _k1 = GiBroadcastFloat32(k0[1]); + GI_FLOAT32_t _k2 = GiBroadcastFloat32(k0[2]); + GI_FLOAT32_t _k3 = GiBroadcastFloat32(k0[3]); + rep(h, OH) { + int width = OW >> 2; + + rep(i, width) { + GI_FLOAT32_t _r00 = GiLoadFloat32(r0); + GI_FLOAT32_t _r10 = GiLoadFloat32(r1); + GI_FLOAT32_t _r01 = GiLoadFloat32(r0 + 1); + GI_FLOAT32_t _r11 = GiLoadFloat32(r1 + 1); + + GI_FLOAT32_t _sum = GiLoadFloat32(outptr); + GI_FLOAT32_t _sum2; + + _sum = GiMlaqFloat32(_sum, _r00, _k0); + _sum2 = GiMultiplyFloat32(_r01, _k1); + _sum = GiMlaqFloat32(_sum, _r10, _k2); + _sum2 = GiMlaqFloat32(_sum2, _r11, _k3); + + _sum = GiAddFloat32(_sum, _sum2); + + GiStoreFloat32(outptr, _sum); + + r0 += 4; + r1 += 4; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + } + } +} + +void conv_stride1::do_conv_3x3_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { + const size_t tail_step = IW - OW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + float* outptr2 = outptr + OW; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + + const float* k0 = filter; + const float* k1 = filter + 3; + const float* k2 = filter + 5; + + GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); + GI_FLOAT32_t _k3456 = GiLoadFloat32(k1); + GI_FLOAT32_t _k5678 = GiLoadFloat32(k2); + GI_FLOAT32_t _k6789 = GiExtqFloat32(_k5678, _k5678, 1); + + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + rep(i, width) { + GI_FLOAT32_t _sum1 = GiLoadFloat32(outptr); + GI_FLOAT32_t _sum2 = GiBroadcastFloat32(0.f); + GI_FLOAT32_t _sum3 = 
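Aside: the vectorized 2x2/stride-1 loop above (GiLoadFloat32 plus the low/high-half lane MLAs) accumulates into dst in place across input channels; dst is assumed to be pre-filled with bias or zeros by the caller. A scalar reference of the same computation, for illustration only:

// Scalar reference of do_conv_2x2_stride1: each output pixel sums a 2x2
// window of one input channel against that channel's 4 filter taps, and the
// per-channel results are accumulated into dst.
void conv_2x2_stride1_ref(const float* src, const float* filter, float* dst,
                          size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
    for (size_t ic = 0; ic < IC; ++ic) {
        const float* s = src + ic * IH * IW;
        const float* k = filter + ic * 4;  // k[0..3]: 2x2 taps of this channel
        for (size_t oh = 0; oh < OH; ++oh) {
            for (size_t ow = 0; ow < OW; ++ow) {
                dst[oh * OW + ow] += s[oh * IW + ow] * k[0] +
                                     s[oh * IW + ow + 1] * k[1] +
                                     s[(oh + 1) * IW + ow] * k[2] +
                                     s[(oh + 1) * IW + ow + 1] * k[3];
            }
        }
    }
}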
GiLoadFloat32(outptr2); + GI_FLOAT32_t _sum4 = GiBroadcastFloat32(0.f); + + GI_FLOAT32_t _r00 = GiLoadFloat32(r0); + GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 4); + GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r00n, 1); + GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r00n, 2); + + GI_FLOAT32_t _r10 = GiLoadFloat32(r1); + GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 4); + GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r10n, 1); + GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r10n, 2); + + GI_FLOAT32_t _r20 = GiLoadFloat32(r2); + GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 4); + GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r20n, 1); + GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r20n, 2); + + GI_FLOAT32_t _r30 = GiLoadFloat32(r3); + GI_FLOAT32_t _r30n = GiLoadFloat32LowHalf(r3 + 4); + GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r30n, 1); + GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r30n, 2); + + _sum1 = GiSimdFmaLane(_sum1, _r00, _k0123, 0); + _sum2 = GiSimdFmaLane(_sum2, _r01, _k0123, 1); + _sum1 = GiSimdFmaLane(_sum1, _r02, _k0123, 2); + _sum2 = GiSimdFmaLane(_sum2, _r10, _k3456, 0); + _sum1 = GiSimdFmaLane(_sum1, _r11, _k3456, 1); + _sum2 = GiSimdFmaLane(_sum2, _r12, _k3456, 2); + _sum1 = GiSimdFmaLane(_sum1, _r20, _k6789, 0); + _sum2 = GiSimdFmaLane(_sum2, _r21, _k6789, 1); + _sum1 = GiSimdFmaLane(_sum1, _r22, _k6789, 2); + + _sum3 = GiSimdFmaLane(_sum3, _r10, _k0123, 0); + _sum4 = GiSimdFmaLane(_sum4, _r11, _k0123, 1); + _sum3 = GiSimdFmaLane(_sum3, _r12, _k0123, 2); + _sum4 = GiSimdFmaLane(_sum4, _r20, _k3456, 0); + _sum3 = GiSimdFmaLane(_sum3, _r21, _k3456, 1); + _sum4 = GiSimdFmaLane(_sum4, _r22, _k3456, 2); + _sum3 = GiSimdFmaLane(_sum3, _r30, _k6789, 0); + _sum4 = GiSimdFmaLane(_sum4, _r31, _k6789, 1); + _sum3 = GiSimdFmaLane(_sum3, _r32, _k6789, 2); + + _sum1 = GiAddFloat32(_sum1, _sum2); + _sum3 = GiAddFloat32(_sum3, _sum4); + + GiStoreFloat32(outptr, _sum1); + GiStoreFloat32(outptr2, _sum3); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + rep(i, width) { + GI_FLOAT32_t _sum1 = GiLoadFloat32(outptr); + GI_FLOAT32_t _sum2 = GiBroadcastFloat32(0.f); + + GI_FLOAT32_t _r00 = GiLoadFloat32(r0); + GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 4); + GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r00n, 1); + GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r00n, 2); + + GI_FLOAT32_t _r10 = GiLoadFloat32(r1); + GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 4); + GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r10n, 1); + GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r10n, 2); + + GI_FLOAT32_t _r20 = GiLoadFloat32(r2); + GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 4); + GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r20n, 1); + GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r20n, 2); + + _sum1 = GiSimdFmaLane(_sum1, _r00, _k0123, 0); + _sum2 = GiSimdFmaLane(_sum2, _r01, _k0123, 1); + _sum1 = GiSimdFmaLane(_sum1, _r02, _k0123, 2); + _sum2 = GiSimdFmaLane(_sum2, _r10, _k3456, 0); + _sum1 = GiSimdFmaLane(_sum1, _r11, _k3456, 1); + _sum2 = GiSimdFmaLane(_sum2, _r12, _k3456, 2); + _sum1 = GiSimdFmaLane(_sum1, _r20, _k6789, 0); + _sum2 = GiSimdFmaLane(_sum2, _r21, _k6789, 1); + _sum1 = GiSimdFmaLane(_sum1, _r22, _k6789, 2); + + _sum1 = GiAddFloat32(_sum1, _sum2); + + GiStoreFloat32(outptr, _sum1); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr += 4; + } + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +void 
conv_stride1::do_conv_5x5_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { + const size_t tail_step = IW - OW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + float* outptr2 = outptr + OW; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + const float* r5 = src_ptr + IW * 5; + + GI_FLOAT32_t _k0123 = GiLoadFloat32(filter); + GI_FLOAT32_t _k4567 = GiLoadFloat32(filter + 4); + GI_FLOAT32_t _k891011 = GiLoadFloat32(filter + 8); + GI_FLOAT32_t _k12131415 = GiLoadFloat32(filter + 12); + GI_FLOAT32_t _k16171819 = GiLoadFloat32(filter + 16); + GI_FLOAT32_t _k20212223 = GiLoadFloat32(filter + 20); + GI_FLOAT32_t _k24242424 = GiBroadcastFloat32(filter[24]); + + size_t h = 0; + for (; h + 1 < OH; h += 2) { + int width = OW >> 2; + + rep(i, width) { + GI_FLOAT32_t _sum = GiLoadFloat32(outptr); + GI_FLOAT32_t _sum2 = GiLoadFloat32(outptr2); + + GI_FLOAT32_t _r00 = GiLoadFloat32(r0); + GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); + GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); + GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); + GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); + + GI_FLOAT32_t _r10 = GiLoadFloat32(r1); + GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); + GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); + GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); + GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); + + GI_FLOAT32_t _r20 = GiLoadFloat32(r2); + GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); + GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); + GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); + GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); + + GI_FLOAT32_t _r30 = GiLoadFloat32(r3); + GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); + GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); + GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); + GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); + + GI_FLOAT32_t _r40 = GiLoadFloat32(r4); + GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); + GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); + GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); + GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); + + GI_FLOAT32_t _r50 = GiLoadFloat32(r5); + GI_FLOAT32_t _r54 = GiLoadFloat32(r5 + 4); + GI_FLOAT32_t _r51 = GiExtqFloat32(_r50, _r54, 1); + GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r54, 2); + GI_FLOAT32_t _r53 = GiExtqFloat32(_r50, _r54, 3); + + _sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); + _sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); + _sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); + _sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); + _sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); + + _sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); + _sum = GiSimdFmaLane(_sum, _r11, _k4567, 2); + _sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); + _sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); + _sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); + + _sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); + _sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); + _sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); + _sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); + _sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); + + _sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); + _sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); + _sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); + _sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); + _sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); + + _sum = 
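Aside: the GiExtqFloat32(a, b, n) calls in these stride-1 kernels build shifted register windows: conceptually they concatenate two 4-lane vectors and take four consecutive lanes starting at position n, so one pair of loads yields the x, x+1, x+2 (and x+3) input windows without extra unaligned reloads. A scalar emulation of that lane extract, for illustration:

#include <array>

using Vec4 = std::array<float, 4>;

// Emulates the 4-lane extract: out[i] = concat(a, b)[i + n].
Vec4 ext_lanes(const Vec4& a, const Vec4& b, int n) {
    Vec4 out{};
    for (int i = 0; i < 4; ++i) {
        int idx = i + n;
        out[i] = idx < 4 ? a[idx] : b[idx - 4];
    }
    return out;
}

// Mirroring the 3x3 kernel: with _r00 = load(r0) and _r00n = load(r0 + 4),
// ext_lanes(_r00, _r00n, 1) is the window starting at r0 + 1 and
// ext_lanes(_r00, _r00n, 2) the window starting at r0 + 2.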
GiSimdFmaLane(_sum, _r40, _k20212223, 0); + _sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); + _sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); + _sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); + _sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); + + _sum2 = GiSimdFmaLane(_sum2, _r10, _k0123, 0); + _sum2 = GiSimdFmaLane(_sum2, _r11, _k0123, 1); + _sum2 = GiSimdFmaLane(_sum2, _r12, _k0123, 2); + _sum2 = GiSimdFmaLane(_sum2, _r13, _k0123, 3); + _sum2 = GiSimdFmaLane(_sum2, _r14, _k4567, 0); + + _sum2 = GiSimdFmaLane(_sum2, _r20, _k4567, 1); + _sum2 = GiSimdFmaLane(_sum2, _r21, _k4567, 2); + _sum2 = GiSimdFmaLane(_sum2, _r22, _k4567, 3); + _sum2 = GiSimdFmaLane(_sum2, _r23, _k891011, 0); + _sum2 = GiSimdFmaLane(_sum2, _r24, _k891011, 1); + + _sum2 = GiSimdFmaLane(_sum2, _r30, _k891011, 2); + _sum2 = GiSimdFmaLane(_sum2, _r31, _k891011, 3); + _sum2 = GiSimdFmaLane(_sum2, _r32, _k12131415, 0); + _sum2 = GiSimdFmaLane(_sum2, _r33, _k12131415, 1); + _sum2 = GiSimdFmaLane(_sum2, _r34, _k12131415, 2); + + _sum2 = GiSimdFmaLane(_sum2, _r40, _k12131415, 3); + _sum2 = GiSimdFmaLane(_sum2, _r41, _k16171819, 0); + _sum2 = GiSimdFmaLane(_sum2, _r42, _k16171819, 1); + _sum2 = GiSimdFmaLane(_sum2, _r43, _k16171819, 2); + _sum2 = GiSimdFmaLane(_sum2, _r44, _k16171819, 3); + + _sum2 = GiSimdFmaLane(_sum2, _r50, _k20212223, 0); + _sum2 = GiSimdFmaLane(_sum2, _r51, _k20212223, 1); + _sum2 = GiSimdFmaLane(_sum2, _r52, _k20212223, 2); + _sum2 = GiSimdFmaLane(_sum2, _r53, _k20212223, 3); + _sum2 = GiSimdFmaLane(_sum2, _r54, _k24242424, 0); + + GiStoreFloat32(outptr, _sum); + GiStoreFloat32(outptr2, _sum2); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + outptr += 4; + outptr2 += 4; + } + + r0 += tail_step + IW; + r1 += tail_step + IW; + r2 += tail_step + IW; + r3 += tail_step + IW; + r4 += tail_step + IW; + r5 += tail_step + IW; + + outptr += OW; + outptr2 += OW; + } + + for (; h < OH; h++) { + int width = OW >> 2; + + rep(i, width) { + GI_FLOAT32_t _sum = GiLoadFloat32(outptr); + + GI_FLOAT32_t _r00 = GiLoadFloat32(r0); + GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); + GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); + GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); + GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); + + GI_FLOAT32_t _r10 = GiLoadFloat32(r1); + GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); + GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); + GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); + GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); + + GI_FLOAT32_t _r20 = GiLoadFloat32(r2); + GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); + GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); + GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); + GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); + + GI_FLOAT32_t _r30 = GiLoadFloat32(r3); + GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); + GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); + GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); + GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); + + GI_FLOAT32_t _r40 = GiLoadFloat32(r4); + GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); + GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); + GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); + GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); + + _sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); + _sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); + _sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); + _sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); + _sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); + + _sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); + _sum = GiSimdFmaLane(_sum, _r11, 
_k4567, 2); + _sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); + _sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); + _sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); + + _sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); + _sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); + _sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); + _sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); + _sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); + + _sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); + _sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); + _sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); + _sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); + _sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); + + _sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0); + _sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); + _sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); + _sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); + _sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); + + GiStoreFloat32(outptr, _sum); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} + +void conv_stride1::do_conv_7x7_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { + const size_t tail_step = IW - OW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + const float* r5 = src_ptr + IW * 5; + const float* r6 = src_ptr + IW * 6; + + const float* k0 = filter; + const float* k1 = filter + 7; + const float* k2 = filter + 14; + const float* k3 = filter + 21; + const float* k4 = filter + 28; + const float* k5 = filter + 35; + const float* k6 = filter + 42; + + for (size_t i = 0; i < OH; i++) { + int width = OW >> 2; + + rep(i, width) { + GI_FLOAT32_t _sum = GiLoadFloat32(outptr); + + GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); + GI_FLOAT32_t _k4567 = GiLoadFloat32(k0 + 4); + + GI_FLOAT32_t _r00 = GiLoadFloat32(r0); // 0 1 2 3 + GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); // 4 5 6 7 + GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 8); // 8 9 10 11 + GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); // 1 2 3 4 + GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); // 2 3 4 5 + GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); // 3 4 5 6 + GI_FLOAT32_t _r05 = GiExtqFloat32(_r04, _r00n, 1); // 5 6 7 8 + GI_FLOAT32_t _r06 = GiExtqFloat32(_r04, _r00n, 2); // 6 7 8 9 + + _sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); + _sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); + _sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); + _sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); + _sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); + _sum = GiSimdFmaLane(_sum, _r05, _k4567, 1); + _sum = GiSimdFmaLane(_sum, _r06, _k4567, 2); + + GI_FLOAT32_t _k78910 = GiLoadFloat32(k1); + GI_FLOAT32_t _k11121314 = GiLoadFloat32(k1 + 4); + + GI_FLOAT32_t _r10 = GiLoadFloat32(r1); + GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); + GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 8); + GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); + GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); + GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); + GI_FLOAT32_t _r15 = GiExtqFloat32(_r14, _r10n, 1); + GI_FLOAT32_t _r16 = GiExtqFloat32(_r14, _r10n, 2); + + _sum = GiSimdFmaLane(_sum, _r10, _k78910, 0); + _sum = GiSimdFmaLane(_sum, _r11, _k78910, 1); + _sum = 
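Aside: the stride-1 3x3 and 5x5 kernels compute two output rows per iteration (outptr and outptr2 = outptr + OW), so the input rows shared by rows oh and oh + 1 are loaded once and feed both accumulators; that is why the main loop advances h by 2. A scalar sketch of the idea, with simplified indexing:

// Two output rows per pass: shared source rows are read once for both.
void conv3x3_two_rows_ref(const float* src, const float* k, float* out0,
                          float* out1, size_t IW, size_t OW) {
    for (size_t ow = 0; ow < OW; ++ow) {
        float s0 = 0.f, s1 = 0.f;
        for (int fh = 0; fh < 3; ++fh) {
            for (int fw = 0; fw < 3; ++fw) {
                float w = k[fh * 3 + fw];
                s0 += w * src[(0 + fh) * IW + ow + fw];  // output row oh
                s1 += w * src[(1 + fh) * IW + ow + fw];  // output row oh + 1
            }
        }
        out0[ow] += s0;  // src rows 1..3 were reused for both outputs
        out1[ow] += s1;
    }
}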
GiSimdFmaLane(_sum, _r12, _k78910, 2); + _sum = GiSimdFmaLane(_sum, _r13, _k78910, 3); + _sum = GiSimdFmaLane(_sum, _r14, _k11121314, 0); + _sum = GiSimdFmaLane(_sum, _r15, _k11121314, 1); + _sum = GiSimdFmaLane(_sum, _r16, _k11121314, 2); + + GI_FLOAT32_t _k14151617 = GiLoadFloat32(k2); + GI_FLOAT32_t _k18192021 = GiLoadFloat32(k2 + 4); + + GI_FLOAT32_t _r20 = GiLoadFloat32(r2); + GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); + GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 8); + GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); + GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); + GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); + GI_FLOAT32_t _r25 = GiExtqFloat32(_r24, _r20n, 1); + GI_FLOAT32_t _r26 = GiExtqFloat32(_r24, _r20n, 2); + + _sum = GiSimdFmaLane(_sum, _r20, _k14151617, 0); + _sum = GiSimdFmaLane(_sum, _r21, _k14151617, 1); + _sum = GiSimdFmaLane(_sum, _r22, _k14151617, 2); + _sum = GiSimdFmaLane(_sum, _r23, _k14151617, 3); + _sum = GiSimdFmaLane(_sum, _r24, _k18192021, 0); + _sum = GiSimdFmaLane(_sum, _r25, _k18192021, 1); + _sum = GiSimdFmaLane(_sum, _r26, _k18192021, 2); + + GI_FLOAT32_t _k21222324 = GiLoadFloat32(k3); + GI_FLOAT32_t _k25262728 = GiLoadFloat32(k3 + 4); + + GI_FLOAT32_t _r30 = GiLoadFloat32(r3); + GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); + GI_FLOAT32_t _r30n = GiLoadFloat32(r3 + 8); + GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); + GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); + GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); + GI_FLOAT32_t _r35 = GiExtqFloat32(_r34, _r30n, 1); + GI_FLOAT32_t _r36 = GiExtqFloat32(_r34, _r30n, 2); + + _sum = GiSimdFmaLane(_sum, _r30, _k21222324, 0); + _sum = GiSimdFmaLane(_sum, _r31, _k21222324, 1); + _sum = GiSimdFmaLane(_sum, _r32, _k21222324, 2); + _sum = GiSimdFmaLane(_sum, _r33, _k21222324, 3); + _sum = GiSimdFmaLane(_sum, _r34, _k25262728, 0); + _sum = GiSimdFmaLane(_sum, _r35, _k25262728, 1); + _sum = GiSimdFmaLane(_sum, _r36, _k25262728, 2); + + GI_FLOAT32_t _k28293031 = GiLoadFloat32(k4); + GI_FLOAT32_t _k32333435 = GiLoadFloat32(k4 + 4); + + GI_FLOAT32_t _r40 = GiLoadFloat32(r4); + GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); + GI_FLOAT32_t _r40n = GiLoadFloat32(r4 + 8); + GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); + GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); + GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); + GI_FLOAT32_t _r45 = GiExtqFloat32(_r44, _r40n, 1); + GI_FLOAT32_t _r46 = GiExtqFloat32(_r44, _r40n, 2); + + _sum = GiSimdFmaLane(_sum, _r40, _k28293031, 0); + _sum = GiSimdFmaLane(_sum, _r41, _k28293031, 1); + _sum = GiSimdFmaLane(_sum, _r42, _k28293031, 2); + _sum = GiSimdFmaLane(_sum, _r43, _k28293031, 3); + _sum = GiSimdFmaLane(_sum, _r44, _k32333435, 0); + _sum = GiSimdFmaLane(_sum, _r45, _k32333435, 1); + _sum = GiSimdFmaLane(_sum, _r46, _k32333435, 2); + + GI_FLOAT32_t _k35363738 = GiLoadFloat32(k5); + GI_FLOAT32_t _k39404142 = GiLoadFloat32(k5 + 4); + + GI_FLOAT32_t _r50 = GiLoadFloat32(r5); + GI_FLOAT32_t _r54 = GiLoadFloat32(r5 + 4); + GI_FLOAT32_t _r50n = GiLoadFloat32(r5 + 8); + GI_FLOAT32_t _r51 = GiExtqFloat32(_r50, _r54, 1); + GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r54, 2); + GI_FLOAT32_t _r53 = GiExtqFloat32(_r50, _r54, 3); + GI_FLOAT32_t _r55 = GiExtqFloat32(_r54, _r50n, 1); + GI_FLOAT32_t _r56 = GiExtqFloat32(_r54, _r50n, 2); + + _sum = GiSimdFmaLane(_sum, _r50, _k35363738, 0); + _sum = GiSimdFmaLane(_sum, _r51, _k35363738, 1); + _sum = GiSimdFmaLane(_sum, _r52, _k35363738, 2); + _sum = GiSimdFmaLane(_sum, _r53, _k35363738, 3); + _sum = GiSimdFmaLane(_sum, _r54, _k39404142, 
0); + _sum = GiSimdFmaLane(_sum, _r55, _k39404142, 1); + _sum = GiSimdFmaLane(_sum, _r56, _k39404142, 2); + + GI_FLOAT32_t _k42434445 = GiLoadFloat32(k6); + GI_FLOAT32_t _k46474849 = + GiLd1qLaneFloat32(k6 + 4 + 2, GiLoadFloat32LowHalf(k6 + 4), 2); + + GI_FLOAT32_t _r60 = GiLoadFloat32(r6); + GI_FLOAT32_t _r64 = GiLoadFloat32(r6 + 4); + GI_FLOAT32_t _r60n = GiLoadFloat32(r6 + 8); + GI_FLOAT32_t _r61 = GiExtqFloat32(_r60, _r64, 1); + GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r64, 2); + GI_FLOAT32_t _r63 = GiExtqFloat32(_r60, _r64, 3); + GI_FLOAT32_t _r65 = GiExtqFloat32(_r64, _r60n, 1); + GI_FLOAT32_t _r66 = GiExtqFloat32(_r64, _r60n, 2); + + _sum = GiSimdFmaLane(_sum, _r60, _k42434445, 0); + _sum = GiSimdFmaLane(_sum, _r61, _k42434445, 1); + _sum = GiSimdFmaLane(_sum, _r62, _k42434445, 2); + _sum = GiSimdFmaLane(_sum, _r63, _k42434445, 3); + _sum = GiSimdFmaLane(_sum, _r64, _k46474849, 0); + _sum = GiSimdFmaLane(_sum, _r65, _k46474849, 1); + _sum = GiSimdFmaLane(_sum, _r66, _k46474849, 2); + + GiStoreFloat32(outptr, _sum); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + r4 += 4; + r5 += 4; + r6 += 4; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } + filter += 49; + } +} + +#include "src/common/simd_macro/epilogue.h" +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h b/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h similarity index 91% rename from dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h rename to dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h index fadf8c32..5ea20649 100644 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h + * \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -13,7 +13,7 @@ #include namespace megdnn { -namespace arm_common { +namespace fallback { namespace fp32 { namespace conv_stride1 { @@ -31,7 +31,7 @@ void do_conv_7x7_stride1( size_t OH, size_t OW, size_t IC); } // namespace conv_stride1 } // namespace fp32 -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp b/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp new file mode 100644 index 00000000..c109e181 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp @@ -0,0 +1,503 @@ +/** + * \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
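Aside: the two-step load of _k46474849 just above exists because a 7x7 filter holds 49 floats, so after k6 + 4 only three taps (k46, k47, k48) remain and a plain 4-float load could read past the end of the filter; GiLoadFloat32LowHalf plus GiLd1qLaneFloat32 read exactly three floats, and only lanes 0..2 of the resulting vector are consumed by GiSimdFmaLane. A portable scalar sketch of such a guarded tail load:

#include <cstring>

// Reads only the 3 valid taps; the 4th lane is padding and is never used.
void load_tail3(const float* p, float out[4]) {
    std::memcpy(out, p, 3 * sizeof(float));
    out[3] = 0.f;
}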
+ */ + +#include + +#include "./do_conv_stride2.h" +#include "midout.h" +#include "src/fallback/conv_bias/gi/postprocess_helper.h" +#include "src/fallback/general_intrinsic/gi_float.h" + +MIDOUT_DECL(megdnn_fallback_conv_bias_f32_convs2) + +using namespace megdnn; +using namespace fallback; +using namespace fp32; +using namespace conv_stride2; + +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; + +void conv_stride2::do_conv_2x2_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + + const float* k0 = filter; + + GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); + rep(h, OH) { + int nn = OW >> 2; + + rep(i, nn) { + GI_FLOAT32_t _outp = GiLoadFloat32(outptr); + + GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0); + + GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6 + GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7 + + _outp = GiSimdFmaLane(_outp, _r00, _k0123, 0); + _outp = GiSimdFmaLane(_outp, _r01, _k0123, 1); + + GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1); + + GI_FLOAT32_t _r10 = _r1.val[0]; + GI_FLOAT32_t _r11 = _r1.val[1]; + + _outp = GiSimdFmaLane(_outp, _r10, _k0123, 2); + _outp = GiSimdFmaLane(_outp, _r11, _k0123, 3); + + GiStoreFloat32(outptr, _outp); + + r0 += 8; + r1 += 8; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + } + + filter += 4; + } +} + +void conv_stride2::do_conv_3x3_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + + const float* k0 = filter; + const float* k1 = filter + 3; + const float* k2 = filter + 5; + + GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); + GI_FLOAT32_t _k3456 = GiLoadFloat32(k1); + GI_FLOAT32_t _k5678 = GiLoadFloat32(k2); + GI_FLOAT32_t _k6789 = GiExtqFloat32(_k5678, _k5678, 1); + rep(h, OH) { + int nn = OW >> 2; + + rep(i, nn) { + GI_FLOAT32_t _outp = GiLoadFloat32(outptr); + + GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0); + GI_FLOAT32_V2_t _r0n = GiLd2qFloat32(r0 + 8); + + GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6 + GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7 + GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0n.val[0], 1); // 2 4 6 8 + + _outp = GiSimdFmaLane(_outp, _r00, _k0123, 0); + _outp = GiSimdFmaLane(_outp, _r01, _k0123, 1); + _outp = GiSimdFmaLane(_outp, _r02, _k0123, 2); + + GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1); + GI_FLOAT32_V2_t _r1n = GiLd2qFloat32(r1 + 8); + + GI_FLOAT32_t _r10 = _r1.val[0]; + GI_FLOAT32_t _r11 = _r1.val[1]; + GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1n.val[0], 1); + + _outp = GiSimdFmaLane(_outp, _r10, _k3456, 0); + _outp = GiSimdFmaLane(_outp, _r11, _k3456, 1); + _outp = GiSimdFmaLane(_outp, _r12, _k3456, 2); + + GI_FLOAT32_V2_t _r2 = GiLd2qFloat32(r2); + GI_FLOAT32_V2_t _r2n = GiLd2qFloat32(r2 + 8); + + GI_FLOAT32_t _r20 = _r2.val[0]; + GI_FLOAT32_t _r21 = _r2.val[1]; + GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2n.val[0], 1); + + _outp = GiSimdFmaLane(_outp, _r20, _k6789, 0); + _outp = GiSimdFmaLane(_outp, _r21, _k6789, 1); + _outp = GiSimdFmaLane(_outp, _r22, _k6789, 2); + + GiStoreFloat32(outptr, _outp); + + 
r0 += 8; + r1 += 8; + r2 += 8; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + } + + filter += 9; + } +} + +void conv_stride2::do_conv_5x5_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + + GI_FLOAT32_t _k0123 = GiLoadFloat32(filter); + GI_FLOAT32_t _k4567 = GiLoadFloat32(filter + 4); + GI_FLOAT32_t _k891011 = GiLoadFloat32(filter + 8); + GI_FLOAT32_t _k12131415 = GiLoadFloat32(filter + 12); + GI_FLOAT32_t _k16171819 = GiLoadFloat32(filter + 16); + GI_FLOAT32_t _k20212223 = GiLoadFloat32(filter + 20); + GI_FLOAT32_t _k24242424 = GiBroadcastFloat32(filter[24]); + + for (size_t i = 0; i < OH; i++) { + int nn = OW >> 2; + + rep(i, nn) { + GI_FLOAT32_t _sum = GiLoadFloat32(outptr); + + GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0); + GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8); + GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 + GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 + GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6 + GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7 + GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8 + GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9 + GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10 + + GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1); + GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8); + GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0]; + GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1]; + GI_FLOAT32_t _r10 = _r10_02461357.val[0]; + GI_FLOAT32_t _r11 = _r10_02461357.val[1]; + GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1); + GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1); + GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2); + + GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2); + GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8); + GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0]; + GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1]; + GI_FLOAT32_t _r20 = _r20_02461357.val[0]; + GI_FLOAT32_t _r21 = _r20_02461357.val[1]; + GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1); + GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1); + GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2); + + GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3); + GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8); + GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0]; + GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1]; + GI_FLOAT32_t _r30 = _r30_02461357.val[0]; + GI_FLOAT32_t _r31 = _r30_02461357.val[1]; + GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1); + GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1); + GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2); + + GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4); + GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8); + GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0]; + GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1]; + GI_FLOAT32_t _r40 = _r40_02461357.val[0]; + GI_FLOAT32_t _r41 = _r40_02461357.val[1]; + GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1); + GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1); + GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2); + + _sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); + _sum = 
GiSimdFmaLane(_sum, _r01, _k0123, 1); + _sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); + _sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); + _sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); + + _sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); + _sum = GiSimdFmaLane(_sum, _r11, _k4567, 2); + _sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); + _sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); + _sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); + + _sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); + _sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); + _sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); + _sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); + _sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); + + _sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); + _sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); + _sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); + _sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); + _sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); + + _sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0); + _sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); + _sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); + _sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); + _sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); + + GiStoreFloat32(outptr, _sum); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + } + + filter += 25; + } +} + +void conv_stride2::do_conv_7x7_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { + const size_t tail_step = IW - 2 * OW + IW; + + rep(ic, IC) { + const float* src_ptr = src + IW * IH * ic; + float* outptr = dst; + + const float* r0 = src_ptr; + const float* r1 = src_ptr + IW; + const float* r2 = src_ptr + IW * 2; + const float* r3 = src_ptr + IW * 3; + const float* r4 = src_ptr + IW * 4; + const float* r5 = src_ptr + IW * 5; + const float* r6 = src_ptr + IW * 6; + + const float* k0 = filter; + const float* k1 = filter + 7; + const float* k2 = filter + 14; + const float* k3 = filter + 21; + const float* k4 = filter + 28; + const float* k5 = filter + 35; + const float* k6 = filter + 42; + + for (size_t i = 0; i < OH; i++) { + int nn = OW >> 2; + + rep(i, nn) { + GI_FLOAT32_t _sum = GiLoadFloat32(outptr); + + GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); + GI_FLOAT32_t _k4567 = GiLoadFloat32(k0 + 4); + + GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0); + GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8); + GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 + GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 + GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6 + GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7 + GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8 + GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9 + GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10 + GI_FLOAT32_t _r05 = GiExtqFloat32(_r01, _r0_9111315, 2); // 5 7 9 11 + GI_FLOAT32_t _r06 = GiExtqFloat32(_r00, _r0_8101214, 3); // 6 8 10 12 + + _sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); + _sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); + _sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); + _sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); + _sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); + _sum = GiSimdFmaLane(_sum, _r05, _k4567, 1); + _sum = GiSimdFmaLane(_sum, _r06, _k4567, 2); + + GI_FLOAT32_t _k78910 = GiLoadFloat32(k1); + GI_FLOAT32_t _k11121314 = GiLoadFloat32(k1 + 4); + + GI_FLOAT32_V2_t _r10_02461357 = 
GiLd2qFloat32(r1); + GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8); + GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0]; + GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1]; + GI_FLOAT32_t _r10 = _r10_02461357.val[0]; + GI_FLOAT32_t _r11 = _r10_02461357.val[1]; + GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1); + GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1); + GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2); + GI_FLOAT32_t _r15 = GiExtqFloat32(_r11, _r1_9111315, 2); + GI_FLOAT32_t _r16 = GiExtqFloat32(_r10, _r1_8101214, 3); + + _sum = GiSimdFmaLane(_sum, _r10, _k78910, 0); + _sum = GiSimdFmaLane(_sum, _r11, _k78910, 1); + _sum = GiSimdFmaLane(_sum, _r12, _k78910, 2); + _sum = GiSimdFmaLane(_sum, _r13, _k78910, 3); + _sum = GiSimdFmaLane(_sum, _r14, _k11121314, 0); + _sum = GiSimdFmaLane(_sum, _r15, _k11121314, 1); + _sum = GiSimdFmaLane(_sum, _r16, _k11121314, 2); + + GI_FLOAT32_t _k14151617 = GiLoadFloat32(k2); + GI_FLOAT32_t _k18192021 = GiLoadFloat32(k2 + 4); + + GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2); + GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8); + GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0]; + GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1]; + GI_FLOAT32_t _r20 = _r20_02461357.val[0]; + GI_FLOAT32_t _r21 = _r20_02461357.val[1]; + GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1); + GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1); + GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2); + GI_FLOAT32_t _r25 = GiExtqFloat32(_r21, _r2_9111315, 2); + GI_FLOAT32_t _r26 = GiExtqFloat32(_r20, _r2_8101214, 3); + + _sum = GiSimdFmaLane(_sum, _r20, _k14151617, 0); + _sum = GiSimdFmaLane(_sum, _r21, _k14151617, 1); + _sum = GiSimdFmaLane(_sum, _r22, _k14151617, 2); + _sum = GiSimdFmaLane(_sum, _r23, _k14151617, 3); + _sum = GiSimdFmaLane(_sum, _r24, _k18192021, 0); + _sum = GiSimdFmaLane(_sum, _r25, _k18192021, 1); + _sum = GiSimdFmaLane(_sum, _r26, _k18192021, 2); + + GI_FLOAT32_t _k21222324 = GiLoadFloat32(k3); + GI_FLOAT32_t _k25262728 = GiLoadFloat32(k3 + 4); + + GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3); + GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8); + GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0]; + GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1]; + GI_FLOAT32_t _r30 = _r30_02461357.val[0]; + GI_FLOAT32_t _r31 = _r30_02461357.val[1]; + GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1); + GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1); + GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2); + GI_FLOAT32_t _r35 = GiExtqFloat32(_r31, _r3_9111315, 2); + GI_FLOAT32_t _r36 = GiExtqFloat32(_r30, _r3_8101214, 3); + + _sum = GiSimdFmaLane(_sum, _r30, _k21222324, 0); + _sum = GiSimdFmaLane(_sum, _r31, _k21222324, 1); + _sum = GiSimdFmaLane(_sum, _r32, _k21222324, 2); + _sum = GiSimdFmaLane(_sum, _r33, _k21222324, 3); + _sum = GiSimdFmaLane(_sum, _r34, _k25262728, 0); + _sum = GiSimdFmaLane(_sum, _r35, _k25262728, 1); + _sum = GiSimdFmaLane(_sum, _r36, _k25262728, 2); + + GI_FLOAT32_t _k28293031 = GiLoadFloat32(k4); + GI_FLOAT32_t _k32333435 = GiLoadFloat32(k4 + 4); + + GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4); + GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8); + GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0]; + GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1]; + GI_FLOAT32_t _r40 = _r40_02461357.val[0]; + GI_FLOAT32_t _r41 = _r40_02461357.val[1]; + GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1); + GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1); + GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2); + GI_FLOAT32_t _r45 = 
GiExtqFloat32(_r41, _r4_9111315, 2); + GI_FLOAT32_t _r46 = GiExtqFloat32(_r40, _r4_8101214, 3); + + _sum = GiSimdFmaLane(_sum, _r40, _k28293031, 0); + _sum = GiSimdFmaLane(_sum, _r41, _k28293031, 1); + _sum = GiSimdFmaLane(_sum, _r42, _k28293031, 2); + _sum = GiSimdFmaLane(_sum, _r43, _k28293031, 3); + _sum = GiSimdFmaLane(_sum, _r44, _k32333435, 0); + _sum = GiSimdFmaLane(_sum, _r45, _k32333435, 1); + _sum = GiSimdFmaLane(_sum, _r46, _k32333435, 2); + + GI_FLOAT32_t _k35363738 = GiLoadFloat32(k5); + GI_FLOAT32_t _k39404142 = GiLoadFloat32(k5 + 4); + + GI_FLOAT32_V2_t _r50_02461357 = GiLd2qFloat32(r5); + GI_FLOAT32_V2_t _r50nx2 = GiLd2qFloat32(r5 + 8); + GI_FLOAT32_t _r5_8101214 = _r50nx2.val[0]; + GI_FLOAT32_t _r5_9111315 = _r50nx2.val[1]; + GI_FLOAT32_t _r50 = _r50_02461357.val[0]; + GI_FLOAT32_t _r51 = _r50_02461357.val[1]; + GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r5_8101214, 1); + GI_FLOAT32_t _r53 = GiExtqFloat32(_r51, _r5_9111315, 1); + GI_FLOAT32_t _r54 = GiExtqFloat32(_r50, _r5_8101214, 2); + GI_FLOAT32_t _r55 = GiExtqFloat32(_r51, _r5_9111315, 2); + GI_FLOAT32_t _r56 = GiExtqFloat32(_r50, _r5_8101214, 3); + + _sum = GiSimdFmaLane(_sum, _r50, _k35363738, 0); + _sum = GiSimdFmaLane(_sum, _r51, _k35363738, 1); + _sum = GiSimdFmaLane(_sum, _r52, _k35363738, 2); + _sum = GiSimdFmaLane(_sum, _r53, _k35363738, 3); + _sum = GiSimdFmaLane(_sum, _r54, _k39404142, 0); + _sum = GiSimdFmaLane(_sum, _r55, _k39404142, 1); + _sum = GiSimdFmaLane(_sum, _r56, _k39404142, 2); + + GI_FLOAT32_t _k42434445 = GiLoadFloat32(k6); + GI_FLOAT32_t _k45464748 = GiLoadFloat32(k6 + 3); + + GI_FLOAT32_V2_t _r60_02461357 = GiLd2qFloat32(r6); + GI_FLOAT32_V2_t _r60nx2 = GiLd2qFloat32(r6 + 8); + GI_FLOAT32_t _r6_8101214 = _r60nx2.val[0]; + GI_FLOAT32_t _r6_9111315 = _r60nx2.val[1]; + GI_FLOAT32_t _r60 = _r60_02461357.val[0]; + GI_FLOAT32_t _r61 = _r60_02461357.val[1]; + GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r6_8101214, 1); + GI_FLOAT32_t _r63 = GiExtqFloat32(_r61, _r6_9111315, 1); + GI_FLOAT32_t _r64 = GiExtqFloat32(_r60, _r6_8101214, 2); + GI_FLOAT32_t _r65 = GiExtqFloat32(_r61, _r6_9111315, 2); + GI_FLOAT32_t _r66 = GiExtqFloat32(_r60, _r6_8101214, 3); + + _sum = GiSimdFmaLane(_sum, _r60, _k42434445, 0); + _sum = GiSimdFmaLane(_sum, _r61, _k42434445, 1); + _sum = GiSimdFmaLane(_sum, _r62, _k42434445, 2); + _sum = GiSimdFmaLane(_sum, _r63, _k42434445, 3); + _sum = GiSimdFmaLane(_sum, _r64, _k45464748, 1); + _sum = GiSimdFmaLane(_sum, _r65, _k45464748, 2); + _sum = GiSimdFmaLane(_sum, _r66, _k45464748, 3); + + GiStoreFloat32(outptr, _sum); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + r4 += 8; + r5 += 8; + r6 += 8; + outptr += 4; + } + + r0 += tail_step; + r1 += tail_step; + r2 += tail_step; + r3 += tail_step; + r4 += tail_step; + r5 += tail_step; + r6 += tail_step; + } + filter += 49; + } +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h b/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h similarity index 91% rename from dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h rename to dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h index 74c28586..1d824ca8 100644 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h + * \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
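The stride-2 kernels above all build their input windows the same way: GiLd2qFloat32 de-interleaves eight consecutive floats into an even-column vector (the `// 0 2 4 6` comments) and an odd-column vector (`// 1 3 5 7`), and GiExtqFloat32 against the next de-interleaved block supplies the columns shifted by two. A minimal scalar sketch of what one input row of do_conv_3x3_stride2 accumulates, assuming that de-interleaved layout; the helper name is illustrative, not part of the patch:

    // Reference for one row of the 3x3 stride-2 inner loop: outptr covers four
    // neighbouring output pixels, r0 points at the first input column of that
    // block, k holds the three taps of this filter row (lanes 0-2 of _k0123).
    static inline void conv3x3_s2_row_ref(float* outptr, const float* r0, const float* k) {
        for (int o = 0; o < 4; ++o) {
            outptr[o] += r0[2 * o + 0] * k[0];  // lane o of _r00 (columns 0 2 4 6)
            outptr[o] += r0[2 * o + 1] * k[1];  // lane o of _r01 (columns 1 3 5 7)
            outptr[o] += r0[2 * o + 2] * k[2];  // lane o of _r02 (columns 2 4 6 8)
        }
    }

The vector code computes the same twelve products per row with three GiSimdFmaLane calls, one lane per output pixel.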
@@ -13,7 +13,7 @@ #include "src/fallback/conv_bias/opr_impl.h" namespace megdnn { -namespace arm_common { +namespace fallback { namespace fp32 { namespace conv_stride2 { void do_conv_2x2_stride2( @@ -30,7 +30,7 @@ void do_conv_7x7_stride2( size_t OH, size_t OW, size_t IC); } // namespace conv_stride2 } // namespace fp32 -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp b/dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp similarity index 95% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp index 508a0a21..5037b962 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,21 +11,21 @@ */ #include "megdnn/oprs.h" -#include "src/arm_common/conv_bias/block_helper.h" -#include "src/arm_common/conv_bias/fp32/algos.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" +#include "src/fallback/conv_bias/gi/block_helper.h" +#include "src/fallback/conv_bias/gi/fp32/algos.h" +#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" #include "midout.h" using namespace megdnn; -using namespace arm_common; +using namespace fallback; using conv_fun = std::function; -MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1) +MIDOUT_DECL(megdnn_fallback_conv_bias_fp32_nchw44_stride1) namespace { static inline size_t get_perthread_cache_bytes( @@ -156,7 +156,7 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable( size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN( - megdnn_arm_common_conv_bias_fp32_nchw44_stride1, + megdnn_fallback_conv_bias_fp32_nchw44_stride1, midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } @@ -175,7 +175,7 @@ SmallVector ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_k // shape runtime #define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ MIDOUT_BEGIN( \ - megdnn_arm_common_conv_bias_fp32_nchw44_stride1, \ + megdnn_fallback_conv_bias_fp32_nchw44_stride1, \ midout_iv(#filter #bias_mode #stride #op##_hash)) { \ do_conv_fun = do_conv_kern; \ } \ diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h b/dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h similarity index 86% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h rename to dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h index c25fd850..6136b5e1 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h + * \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_stride1_nchw44_kern.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -10,10 +10,10 @@ * implied. */ -#include "src/arm_common/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/opr_impl.h" namespace megdnn { -namespace arm_common { +namespace fallback { namespace conv_bias { template @@ -28,5 +28,5 @@ void pack_src_fp32_nchw44( const int pad_top, const int pad_bottom, const int ic, const int ic_stride); } // namespace conv_bias -} // namespace arm_common +} // namespace fallback } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp b/dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp similarity index 96% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp index 9efeff10..7fc33115 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp @@ -1,6 +1,6 @@ /** * \file - dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp + dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -12,21 +12,21 @@ */ #include "megdnn/oprs.h" -#include "src/arm_common/conv_bias/fp32/algos.h" -#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" #include "src/common/nchw_nchwxx_valid.h" #include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/gi/fp32/algos.h" +#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" #include "midout.h" using namespace megdnn; -using namespace arm_common; +using namespace fallback; using conv_fun = std::function; -MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44) +MIDOUT_DECL(megdnn_fallback_conv_bias_fp32_nchw_nchw44) namespace { static inline int block_helper( const int nthread, const int amount, const int per_unit_bytes) { @@ -195,7 +195,7 @@ bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN( - megdnn_arm_common_conv_bias_fp32_nchw_nchw44, + megdnn_fallback_conv_bias_fp32_nchw_nchw44, midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } @@ -214,7 +214,7 @@ SmallVector ConvBiasImpl::AlgoF32DirectNCHWNCHW44:: // shape runtime #define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ MIDOUT_BEGIN( \ - megdnn_arm_common_conv_bias_fp32_nchw_nchw44, \ + megdnn_fallback_conv_bias_fp32_nchw_nchw44, \ midout_iv(#stride #filter #bias_mode #op##_hash)) { \ do_conv_fun = do_conv_kern; \ } \ diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h b/dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h similarity index 79% rename from dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h rename to dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h index 51477e02..5c94af31 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h @@ -1,5 +1,5 @@ /** - * \file 
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h + * \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,15 +11,14 @@ */ #pragma once #include "megdnn/arch.h" -#include "src/arm_common/conv_bias/intrinsic_helper.h" -#include "src/arm_common/conv_bias/opr_impl.h" -#include "src/arm_common/elemwise_helper/elemwise_op.h" -#include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/gi/intrinsic_helper.h" +#include "src/fallback/conv_bias/opr_impl.h" +#include "src/fallback/elemwise_helper/elemwise_op.h" namespace megdnn { -namespace arm_common { +namespace fallback { namespace fp32_direct_nchw_nchw44 { static inline void pack_weight_fp32_nchw_nchw44( @@ -34,8 +33,8 @@ static inline void pack_weight_fp32_nchw_nchw44( for (int kh_idx = 0; kh_idx < kh; ++kh_idx) { for (int kw_idx = 0; kw_idx < kw; ++kw_idx) { for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { - float32x4_t vsrc = vld1q_f32(in_ptr_oc); - vst1q_f32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc); + GI_FLOAT32_t vsrc = GiLoadFloat32(in_ptr_oc); + GiStoreFloat32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc); in_ptr_oc += oc_step; } dst_ptr_oc += oc_step; @@ -51,6 +50,6 @@ void conv_direct_fp32_nchw_nchw44( const int, const int); } // namespace fp32_direct_nchw_nchw44 -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/filter_transform.h b/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h similarity index 93% rename from dnn/src/arm_common/conv_bias/fp32/filter_transform.h rename to dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h index f263afbb..daf26e10 100644 --- a/dnn/src/arm_common/conv_bias/fp32/filter_transform.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/filter_transform.h + * \file dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -11,14 +11,13 @@ #pragma once #include "megdnn/opr_param_defs.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/conv_bias/gi/utils.h" namespace megdnn { -namespace arm_common { +namespace fallback { template struct FilterTransform6X3 { @@ -65,8 +64,8 @@ struct FilterTransform6X3 { Vector g1 = Vector::load(fptr + 3); Vector g2 = Vector::load(fptr + 6 - 1); - float32x4_t zeros = vdupq_n_f32(0.0f); - g2.value = vextq_f32(g2.value, zeros, 1); + GI_FLOAT32_t zeros = GiZeroFloat32(); + g2.value = GiExtqFloat32(g2.value, zeros, 1); #define cb(i) Vector wd##i; UNROLL_CALL_NOWRAPPER(8, cb); @@ -106,7 +105,6 @@ struct FilterTransform6X3 { } #else - #define cb(i) \ do { \ mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ @@ -128,7 +126,7 @@ struct FilterTransform6X3 { mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ mid_buf1 += 8; \ } while (0); -#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) +#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value) float* mid_buf1 = transform_mid_buf; UNROLL_CALL_NOWRAPPER(8, cb); @@ -154,7 +152,7 @@ struct FilterTransform6X3 { #undef FILTER_TRANSFORM #undef GET_VECTOR_ELEM -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/gi/fp32/helper.h b/dnn/src/fallback/conv_bias/gi/fp32/helper.h new file mode 100644 index 00000000..ebab04f2 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/fp32/helper.h @@ -0,0 +1,196 @@ +/** + * \file dnn/src/fallback/conv_bias/gi/fp32/helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
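The lane-extraction rewrite above swaps NEON's vgetq_lane_f32(v, idx) for a token-pasted GI call, which keeps the lane index a compile-time constant while GI exposes one named function per lane. A small expansion sketch, using the CONCAT helper from this patch (wd3 is only an illustrative operand name):

    // With CONCAT(a, idx) defined as a##idx, the new macro
    //   #define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value)
    // expands, for example, as
    //   GET_VECTOR_ELEM(wd, 3, 2)  ->  GiExtractLane2Float32(wd3.value)
    // i.e. read lane 2 of wd3.value, the GI counterpart of the removed
    //   vgetq_lane_f32(wd3.value, 2)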
+ */ + +#pragma once +#include "src/common/unroll_macro.h" +#include "src/fallback/general_intrinsic/gi_float.h" + +namespace megdnn { +namespace fallback { +inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { + GI_FLOAT32_V2_t a0, a1; + a0.val[0] = GiLoadFloat32(src + 0 * lda); + a0.val[1] = GiLoadFloat32(src + 1 * lda); + a1.val[0] = GiLoadFloat32(src + 2 * lda); + a1.val[1] = GiLoadFloat32(src + 3 * lda); + GI_FLOAT32_V2_t b0 = GiZipqFloat32(a0.val[0], a1.val[0]); + GI_FLOAT32_V2_t b1 = GiZipqFloat32(a0.val[1], a1.val[1]); + GI_FLOAT32_V2_t c0 = GiZipqFloat32(b0.val[0], b1.val[0]); + GI_FLOAT32_V2_t c1 = GiZipqFloat32(b0.val[1], b1.val[1]); + GiStoreFloat32(dst + 0 * ldb, c0.val[0]); + GiStoreFloat32(dst + 1 * ldb, c0.val[1]); + GiStoreFloat32(dst + 2 * ldb, c1.val[0]); + GiStoreFloat32(dst + 3 * ldb, c1.val[1]); +} +} // namespace fallback +} // namespace megdnn + +#define MATRIX_MUL4x4(sum, a, b) \ + sum##0 = GiMlaqLowLaneFloat32(sum##0, b##0, a##0, 0); \ + sum##0 = GiMlaqLowLaneFloat32(sum##0, b##1, a##0, 1); \ + sum##0 = GiMlaqHighLaneFloat32(sum##0, b##2, a##0, 2); \ + sum##0 = GiMlaqHighLaneFloat32(sum##0, b##3, a##0, 3); \ + sum##1 = GiMlaqLowLaneFloat32(sum##1, b##0, a##1, 0); \ + sum##1 = GiMlaqLowLaneFloat32(sum##1, b##1, a##1, 1); \ + sum##1 = GiMlaqHighLaneFloat32(sum##1, b##2, a##1, 2); \ + sum##1 = GiMlaqHighLaneFloat32(sum##1, b##3, a##1, 3); \ + sum##2 = GiMlaqLowLaneFloat32(sum##2, b##0, a##2, 0); \ + sum##2 = GiMlaqLowLaneFloat32(sum##2, b##1, a##2, 1); \ + sum##2 = GiMlaqHighLaneFloat32(sum##2, b##2, a##2, 2); \ + sum##2 = GiMlaqHighLaneFloat32(sum##2, b##3, a##2, 3); \ + sum##3 = GiMlaqLowLaneFloat32(sum##3, b##0, a##3, 0); \ + sum##3 = GiMlaqLowLaneFloat32(sum##3, b##1, a##3, 1); \ + sum##3 = GiMlaqHighLaneFloat32(sum##3, b##2, a##3, 2); \ + sum##3 = GiMlaqHighLaneFloat32(sum##3, b##3, a##3, 3); + +#define CONCAT(a, idx) a##idx + +#if MEGDNN_AARCH64 +//! 
ret and a are type Vector +#define TRANSPOSE_8x8(a, ret) \ + do { \ + auto b0 = GiZipqFloat32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \ + auto b1 = GiZipqFloat32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \ + auto b2 = GiZipqFloat32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \ + auto b3 = GiZipqFloat32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \ + auto b4 = GiZipqFloat32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \ + auto b5 = GiZipqFloat32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \ + auto b6 = GiZipqFloat32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \ + auto b7 = GiZipqFloat32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \ + CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b2.val[0]))); \ + CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b4.val[0]), \ + GiReinterpretqFloat32ToS64(b6.val[0]))); \ + CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b2.val[0]))); \ + CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b4.val[0]), \ + GiReinterpretqFloat32ToS64(b6.val[0]))); \ + CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b2.val[1]))); \ + CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b4.val[1]), \ + GiReinterpretqFloat32ToS64(b6.val[1]))); \ + CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b2.val[1]))); \ + CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b4.val[1]), \ + GiReinterpretqFloat32ToS64(b6.val[1]))); \ + CONCAT(ret, 4).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b1.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 4).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b5.val[0]), \ + GiReinterpretqFloat32ToS64(b7.val[0]))); \ + CONCAT(ret, 5).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b1.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 5).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b5.val[0]), \ + GiReinterpretqFloat32ToS64(b7.val[0]))); \ + CONCAT(ret, 6).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b1.val[1]), \ + GiReinterpretqFloat32ToS64(b3.val[1]))); \ + CONCAT(ret, 6).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b5.val[1]), \ + GiReinterpretqFloat32ToS64(b7.val[1]))); \ + CONCAT(ret, 7).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b1.val[1]), \ + GiReinterpretqFloat32ToS64(b3.val[1]))); \ + CONCAT(ret, 7).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b5.val[1]), \ + GiReinterpretqFloat32ToS64(b7.val[1]))); \ + } while (0); + +#define TRANSPOSE_8x3(a, ret) \ + auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ + auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ + auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 
5).value); \ + auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ + CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b1.val[0]))); \ + CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b2.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b1.val[0]))); \ + CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b2.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b1.val[1]))); \ + CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b2.val[1]), \ + GiReinterpretqFloat32ToS64(b3.val[1]))); + +#define TRANSPOSE_8x4(a, ret) \ + auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ + auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ + auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ + auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ + CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b1.val[0]))); \ + CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b2.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[0]), \ + GiReinterpretqFloat32ToS64(b1.val[0]))); \ + CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b2.val[0]), \ + GiReinterpretqFloat32ToS64(b3.val[0]))); \ + CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b1.val[1]))); \ + CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ + GiReinterpretqFloat32ToS64(b2.val[1]), \ + GiReinterpretqFloat32ToS64(b3.val[1]))); \ + CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b0.val[1]), \ + GiReinterpretqFloat32ToS64(b1.val[1]))); \ + CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ + GiReinterpretqFloat32ToS64(b2.val[1]), \ + GiReinterpretqFloat32ToS64(b3.val[1]))); + +#else +#define TRANSPOSE_8x4(a, ret) \ + auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ + auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ + auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ + auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ + CONCAT(ret, 0).value.val[0] = \ + GiCombineFloat32(GiGetLowFloat32(b0.val[0]), GiGetLowFloat32(b1.val[0])); \ + CONCAT(ret, 1).value.val[0] = GiCombineFloat32( \ + GiGetHighFloat32(b0.val[0]), GiGetHighFloat32(b1.val[0])); \ + CONCAT(ret, 2).value.val[0] = \ + GiCombineFloat32(GiGetLowFloat32(b0.val[1]), GiGetLowFloat32(b1.val[1])); \ + CONCAT(ret, 3).value.val[0] = GiCombineFloat32( \ + GiGetHighFloat32(b0.val[1]), GiGetHighFloat32(b1.val[1])); \ + CONCAT(ret, 0).value.val[1] = \ + GiCombineFloat32(GiGetLowFloat32(b2.val[0]), GiGetLowFloat32(b3.val[0])); \ + CONCAT(ret, 
1).value.val[1] = GiCombineFloat32( \ + GiGetHighFloat32(b2.val[0]), GiGetHighFloat32(b3.val[0])); \ + CONCAT(ret, 2).value.val[1] = \ + GiCombineFloat32(GiGetLowFloat32(b2.val[1]), GiGetLowFloat32(b3.val[1])); \ + CONCAT(ret, 3).value.val[1] = GiCombineFloat32( \ + GiGetHighFloat32(b2.val[1]), GiGetHighFloat32(b3.val[1])); + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy.h b/dnn/src/fallback/conv_bias/gi/fp32/strategy.h similarity index 82% rename from dnn/src/arm_common/conv_bias/fp32/strategy.h rename to dnn/src/fallback/conv_bias/gi/fp32/strategy.h index e942074b..b754c4c6 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy.h +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy.h + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -11,14 +11,15 @@ #pragma once -#include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/fallback/conv_bias/gi/postprocess_helper.h" #include "src/fallback/conv_bias/winograd/winograd.h" namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, winograd_2x3_4x4_f) +MEGDNN_REG_WINOGRAD_STRATEGY( + float, float, float, float, 2, 3, 4, 4, winograd_gi_2x3_4x4_f) MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f) @@ -37,7 +38,7 @@ MEGDNN_REG_WINOGRAD_STRATEGY( MEGDNN_REG_WINOGRAD_STRATEGY( float, float, float, float, 7, 3, 4, 4, winograd_F73_mk4_f_nchw44) } // namespace winograd -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp similarity index 81% rename from dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp index 4b211e30..7f656129 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,22 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
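The transposes in helper.h are built from interleaving zips: GiZipqFloat32 merges two rows lane by lane, and a second level (64-bit reinterpret zips on AArch64, GiCombineFloat32 of low/high halves otherwise) turns the interleaved pairs into columns. A scalar reference for the 4x4 case handled by transpose_4x4, with the lane walk-through in the comments; the _ref name is illustrative:

    // What transpose_4x4(src, dst, lda, ldb) computes, element by element.
    // Vector path: zip(row0, row2) = {r00 r20 r01 r21 | r02 r22 r03 r23},
    //              zip(row1, row3) = {r10 r30 r11 r31 | r12 r32 r13 r33},
    // and zipping those halves again yields the columns {r00 r10 r20 r30}, ...
    static inline void transpose_4x4_ref(const float* src, float* dst, int lda, int ldb) {
        for (int r = 0; r < 4; ++r)
            for (int c = 0; c < 4; ++c)
                dst[c * ldb + r] = src[r * lda + c];
    }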
*/ -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/elemwise_helper/op_unary.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23) +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F23) using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { struct InputTransform2X3 { @@ -40,15 +39,15 @@ struct InputTransform2X3 { const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; for (size_t ico = 0; ico < 4; ++ico) { if (ic + ico < IC) { - auto v0 = vld1q_f32(input_ptr); - auto v1 = vld1q_f32(input_ptr + IW); - auto v2 = vld1q_f32(input_ptr + IW * 2); - auto v3 = vld1q_f32(input_ptr + IW * 3); - - vst1q_f32(patch + ico * 4 * alpha + 0 * 4, v0); - vst1q_f32(patch + ico * 4 * alpha + 1 * 4, v1); - vst1q_f32(patch + ico * 4 * alpha + 2 * 4, v2); - vst1q_f32(patch + ico * 4 * alpha + 3 * 4, v3); + auto v0 = GiLoadFloat32(input_ptr); + auto v1 = GiLoadFloat32(input_ptr + IW); + auto v2 = GiLoadFloat32(input_ptr + IW * 2); + auto v3 = GiLoadFloat32(input_ptr + IW * 3); + + GiStoreFloat32(patch + ico * 4 * alpha + 0 * 4, v0); + GiStoreFloat32(patch + ico * 4 * alpha + 1 * 4, v1); + GiStoreFloat32(patch + ico * 4 * alpha + 2 * 4, v2); + GiStoreFloat32(patch + ico * 4 * alpha + 3 * 4, v3); input_ptr += IH * IW; } } @@ -197,18 +196,18 @@ struct OutputTransform2X3 { } // namespace namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { -MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) -void winograd_2x3_4x4_f::filter( +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_gi_2x3_4x4_f) +void winograd_gi_2x3_4x4_f::filter( const float* filter, float* filter_transform_buf, float* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, size_t oc_end) { constexpr int alpha = 2 + 3 - 1; //! G * g * GT - float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, + GI_FLOAT32_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, g3{0, 0, 1, 0}; - float32x4_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1}, + GI_FLOAT32_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1}, gt3{0, 0, 0, 0}; size_t OCB = OC / 4; size_t ICB = IC / 4; @@ -225,33 +224,33 @@ void winograd_2x3_4x4_f::filter( //! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 //! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 //! 
0 0 1 0 0 0 0 0 0 0 0 0 - float32x4_t vf0 = vld1q_f32(filter_ptr); - float32x4_t vf1 = vld1q_f32(filter_ptr + 4); - float32x4_t vf2 = vdupq_n_f32(filter_ptr[8]); - - float32x4_t v3(vdupq_n_f32(0)); - auto vtmp = vextq_f32(vf1, vf2, 2); - vtmp = vsetq_lane_f32(0, vtmp, 3); - float32x4_t v2(vtmp); - vtmp = vextq_f32(vf0, vf1, 3); - vtmp = vsetq_lane_f32(0, vtmp, 3); - float32x4_t v1(vtmp); - vtmp = vsetq_lane_f32(0, vf0, 3); - float32x4_t v0(vtmp); - - float32x4_t vsum0 = vdupq_n_f32(0), vsum1 = vdupq_n_f32(0), - vsum2 = vdupq_n_f32(0), vsum3 = vdupq_n_f32(0); + GI_FLOAT32_t vf0 = GiLoadFloat32(filter_ptr); + GI_FLOAT32_t vf1 = GiLoadFloat32(filter_ptr + 4); + GI_FLOAT32_t vf2 = GiBroadcastFloat32(filter_ptr[8]); + + GI_FLOAT32_t v3(GiBroadcastFloat32(0)); + auto vtmp = GiExtqFloat32(vf1, vf2, 2); + vtmp = GiSetqLaneFloat32(0, vtmp, 3); + GI_FLOAT32_t v2(vtmp); + vtmp = GiExtqFloat32(vf0, vf1, 3); + vtmp = GiSetqLaneFloat32(0, vtmp, 3); + GI_FLOAT32_t v1(vtmp); + vtmp = GiSetqLaneFloat32(0, vf0, 3); + GI_FLOAT32_t v0(vtmp); + + GI_FLOAT32_t vsum0 = GiBroadcastFloat32(0), vsum1 = GiBroadcastFloat32(0), + vsum2 = GiBroadcastFloat32(0), vsum3 = GiBroadcastFloat32(0); MATRIX_MUL4x4(vsum, g, v); - float32x4_t vres0 = vdupq_n_f32(0), vres1 = vdupq_n_f32(0), - vres2 = vdupq_n_f32(0), vres3 = vdupq_n_f32(0); + GI_FLOAT32_t vres0 = GiBroadcastFloat32(0), vres1 = GiBroadcastFloat32(0), + vres2 = GiBroadcastFloat32(0), vres3 = GiBroadcastFloat32(0); MATRIX_MUL4x4(vres, vsum, gt); - vst1q_f32(transform_mid_buf, vres0); - vst1q_f32(transform_mid_buf + 4, vres1); - vst1q_f32(transform_mid_buf + 8, vres2); - vst1q_f32(transform_mid_buf + 12, vres3); + GiStoreFloat32(transform_mid_buf, vres0); + GiStoreFloat32(transform_mid_buf + 4, vres1); + GiStoreFloat32(transform_mid_buf + 8, vres2); + GiStoreFloat32(transform_mid_buf + 12, vres3); size_t ocb = oc / 4; size_t oc4 = oc % 4; @@ -266,7 +265,7 @@ void winograd_2x3_4x4_f::filter( } } -void winograd_2x3_4x4_f::input( +void winograd_gi_2x3_4x4_f::input( const float* input, float* input_transform_buf, float* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { @@ -304,7 +303,7 @@ void winograd_2x3_4x4_f::input( } } -void winograd_2x3_4x4_f::output( +void winograd_gi_2x3_4x4_f::output( const float* output_transform_buf, const float* bias, float* output, float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, @@ -322,8 +321,8 @@ void winograd_2x3_4x4_f::output( auto nw = index % units_w; size_t oh_start = nh * OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp32_F23, cb, float, float, bmode, + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F23, cb, float, float, bmode, nonline_mode, output_transform_buf, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); @@ -333,7 +332,7 @@ void winograd_2x3_4x4_f::output( } } // namespace winograd -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp similarity index 94% rename from dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp index 12093fd0..4556218a 100644 --- 
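The F(2,3) filter transform ported above is the dense product G * g * G^T with the 4x3 matrix G spelled out in the comment block; the vectors are padded to 4x4 so both multiplications can go through MATRIX_MUL4x4's lane FMAs. A scalar reference under that reading (function name illustrative):

    // u = G * g * G^T for Winograd F(2,3): g is the 3x3 filter, u the 4x4 transform.
    // G = { {1, 0, 0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0, 0, 1} }, matching the
    // g0..g3 and gt0..gt3 vectors above (each padded with a trailing zero).
    static void winograd_f23_filter_ref(const float g[3][3], float u[4][4]) {
        const float G[4][3] = {
                {1.f, 0.f, 0.f}, {0.5f, 0.5f, 0.5f}, {0.5f, -0.5f, 0.5f}, {0.f, 0.f, 1.f}};
        float Gg[4][3] = {};  // G * g
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 3; ++j)
                for (int k = 0; k < 3; ++k)
                    Gg[i][j] += G[i][k] * g[k][j];
        for (int i = 0; i < 4; ++i)  // (G * g) * G^T
            for (int j = 0; j < 4; ++j) {
                u[i][j] = 0.f;
                for (int k = 0; k < 3; ++k)
                    u[i][j] += Gg[i][k] * G[j][k];
            }
    }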
a/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,22 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/elemwise_helper/op_unary.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F45) +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F45) using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { struct FilterTransform4X5 { @@ -126,9 +125,9 @@ struct FilterTransform4X5 { #undef cb FILTER_TRANSFORM(g, Gg) - float32x4x2_t vgr; - float32x4_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3}; - float32x4_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7}; + GI_FLOAT32_V2_t vgr; + GI_FLOAT32_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3}; + GI_FLOAT32_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7}; vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3}; vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7}; Vector Ggt4(vgr); @@ -167,8 +166,10 @@ struct InputTransform4X5 { wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ } while (0) -#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) -#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) +#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ + GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) +#define GET_VECTOR_LOW_ELEM(s, i, idx) \ + GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) template static void transform( @@ -345,22 +346,22 @@ struct OutputTransform4X5 { #undef cb if (oh_start + 4 <= OH && ow_start + 4 <= OW) { - float32x4_t bias0; + GI_FLOAT32_t bias0; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - bias0 = vdupq_n_f32(bias[oc]); + bias0 = GiBroadcastFloat32(bias[oc]); } rep(i, 4) { size_t oh = oh_start + i; - float32x4_t item0 = vld1q_f32(mid_buf1); + GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - item0 = vaddq_f32(item0, bias0); + item0 = GiAddFloat32(item0, bias0); } else if (bmode == BiasMode::BIAS) { - bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); - item0 = vaddq_f32(item0, bias0); + bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); + item0 = GiAddFloat32(item0, bias0); } item0 = op(item0); - vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); + GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); mid_buf1 += 4; } } else { @@ -388,7 +389,7 @@ struct OutputTransform4X5 { } // namespace namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f) @@ -448,8 +449,8 @@ void winograd_4x5_1x1_f::output( auto nw = index % units_w; size_t oh_start = nh 
* OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp32_F45, cb, float, float, bmode, + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F45, cb, float, float, bmode, nonline_mode, output_transform_buf, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); @@ -459,7 +460,7 @@ void winograd_4x5_1x1_f::output( } } // namespace winograd -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp similarity index 94% rename from dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp index dac78b2e..77bd62c8 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,22 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/elemwise_helper/op_unary.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F54) +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F54) using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { struct FilterTransform5X4 { @@ -94,7 +93,6 @@ struct FilterTransform5X4 { transform_mid_buf[j * alpha + i]; } #else - #define cb(i) \ do { \ mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ @@ -117,7 +115,7 @@ struct FilterTransform5X4 { mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ mid_buf1 += 8; \ } while (0); -#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) +#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value) float* mid_buf1 = transform_mid_buf; UNROLL_CALL_NOWRAPPER(8, cb); @@ -154,8 +152,10 @@ struct InputTransform5X4 { wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ } while (0) -#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) -#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) +#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ + GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) +#define GET_VECTOR_LOW_ELEM(s, i, idx) \ + GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) template static void transform( @@ -348,29 +348,29 @@ struct OutputTransform5X4 { #undef cb if (oh_start + 5 <= OH && ow_start + 5 <= OW) { - float32x4_t bias0; + GI_FLOAT32_t bias0; float32_t bias1; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - bias0 = vdupq_n_f32(bias[oc]); + bias0 = 
GiBroadcastFloat32(bias[oc]); bias1 = bias[oc]; } rep(i, 5) { size_t oh = oh_start + i; - float32x4_t item0 = vld1q_f32(mid_buf1); + GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); float32_t item1 = mid_buf1[4]; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - item0 = vaddq_f32(item0, bias0); + item0 = GiAddFloat32(item0, bias0); item1 = item1 + bias1; } else if (bmode == BiasMode::BIAS) { - bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); + bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4]; - item0 = vaddq_f32(item0, bias0); + item0 = GiAddFloat32(item0, bias0); item1 = item1 + bias1; } item0 = op(item0); item1 = op(item1); - vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); + GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); output[oc * OH * OW + oh * OW + ow_start + 4] = item1; mid_buf1 += 5; @@ -400,7 +400,7 @@ struct OutputTransform5X4 { } // namespace namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f) @@ -461,8 +461,8 @@ void winograd_5x4_1x1_f::output( auto nw = index % units_w; size_t oh_start = nh * OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp32_F54, cb, float, float, bmode, + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F54, cb, float, float, bmode, nonline_mode, output_transform_buf, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); @@ -472,7 +472,7 @@ void winograd_5x4_1x1_f::output( } } // namespace winograd -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp similarity index 90% rename from dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp index f0d958d5..387794ed 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,22 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
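The output transforms in these strategies share one epilogue: add either a per-channel broadcast bias or a per-element bias, apply the activation functor op, then store the row. A scalar sketch of that tail for one 4-wide row, assuming the BiasMode enum from src/fallback/conv_bias/common.h; the function name is illustrative:

    // tile_row holds one inverse-transformed output row (what mid_buf1 points at above).
    template <typename Op>
    void output_row_epilogue_ref(
            float* output, const float* tile_row, const float* bias, BiasMode bmode,
            size_t oc, size_t oh, size_t ow_start, size_t OH, size_t OW, const Op& op) {
        for (size_t j = 0; j < 4; ++j) {
            float v = tile_row[j];
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS)
                v += bias[oc];                                     // GiBroadcastFloat32 path
            else if (bmode == BiasMode::BIAS)
                v += bias[oc * OH * OW + oh * OW + ow_start + j];  // GiLoadFloat32 path
            output[oc * OH * OW + oh * OW + ow_start + j] = op(v);
        }
    }

The vector versions process the same four elements at once and fall back to scalar or float32x2 handling for tile columns past a multiple of four.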
*/ -#include "src/arm_common/conv_bias/fp32/filter_transform.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/fallback/elemwise_helper/op_unary.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63) +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63) using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { /** @@ -57,8 +56,10 @@ namespace { wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ } while (0); -#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) -#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) +#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ + GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) +#define GET_VECTOR_LOW_ELEM(s, i, idx) \ + GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) struct InputTransform6X3 { template static void transform( @@ -271,31 +272,31 @@ struct OutputTransform6X3 { #undef cb if (oh_start + 6 <= OH && ow_start + 6 <= OW) { - float32x4_t bias0; + GI_FLOAT32_t bias0; float32x2_t bias1; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - bias0 = vdupq_n_f32(bias[oc]); - bias1 = vdup_n_f32(bias[oc]); + bias0 = GiBroadcastFloat32(bias[oc]); + bias1 = GiDupFloat32(bias[oc]); } rep(i, 6) { size_t oh = oh_start + i; - float32x4_t item0 = vld1q_f32(mid_buf1); - float32x2_t item1 = vld1_f32(mid_buf1 + 4); + GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); + float32x2_t item1 = GiLdFloat32(mid_buf1 + 4); if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - item0 = vaddq_f32(item0, bias0); - item1 = vadd_f32(item1, bias1); + item0 = GiAddFloat32(item0, bias0); + item1 = GiAddDFloat32(item1, bias1); } else if (bmode == BiasMode::BIAS) { - bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); - bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start + 4); - item0 = vaddq_f32(item0, bias0); - item1 = vadd_f32(item1, bias1); + bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); + bias1 = GiLdFloat32(bias + oc * OH * OW + oh * OW + ow_start + 4); + item0 = GiAddFloat32(item0, bias0); + item1 = GiAddDFloat32(item1, bias1); } item0 = op(item0); - item1 = vset_lane_f32(op(vget_lane_f32(item1, 0)), item1, 0); - item1 = vset_lane_f32(op(vget_lane_f32(item1, 1)), item1, 1); - vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); - vst1_f32(output + oc * OH * OW + oh * OW + ow_start + 4, item1); + item1 = GiSetLaneFloat32(op(GiGetLaneFloat32(item1, 0)), item1, 0); + item1 = GiSetLaneFloat32(op(GiGetLaneFloat32(item1, 1)), item1, 1); + GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); + GiSt1Float32(output + oc * OH * OW + oh * OW + ow_start + 4, item1); mid_buf1 += 6; } @@ -325,7 +326,7 @@ struct OutputTransform6X3 { } // namespace namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f) @@ -385,8 +386,8 @@ void winograd_6x3_1x1_f::output( auto nw 
= index % units_w; size_t oh_start = nh * OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp32_F63, cb, float, float, bmode, + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F63, cb, float, float, bmode, nonline_mode, output_transform_buf, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); @@ -396,7 +397,7 @@ void winograd_6x3_1x1_f::output( } } // namespace winograd -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp similarity index 93% rename from dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp index 15bba172..9352235e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,22 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/arm_common/conv_bias/fp32/filter_transform.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/common/winograd/winograd_helper.h" +#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/fallback/elemwise_helper/op_unary.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_4x4) +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63_4x4) using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { @@ -41,16 +40,16 @@ struct InputTransform6X3 { const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; for (size_t ico = 0; ico < 4; ++ico) { if (ic + ico < IC) { -#define cb(i) \ - auto v##i##0 = vld1q_f32(input_ptr + IW * i); \ - auto v##i##1 = vld1q_f32(input_ptr + IW * i + 4); +#define cb(i) \ + auto v##i##0 = GiLoadFloat32(input_ptr + IW * i); \ + auto v##i##1 = GiLoadFloat32(input_ptr + IW * i + 4); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(i) \ - vst1q_f32(patch + ico * 8 * alpha + i * 8, v##i##0); \ - vst1q_f32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1); +#define cb(i) \ + GiStoreFloat32(patch + ico * 8 * alpha + i * 8, v##i##0); \ + GiStoreFloat32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -255,7 +254,7 @@ struct OutputTransform6X3 { } // namespace namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f) @@ -323,8 +322,8 @@ void winograd_6x3_4x4_f::output( auto nw = index % units_w; size_t oh_start = nh * OUTPUT_BLOCK_SIZE; 
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp32_F63_4x4, cb, float, float, bmode, + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F63_4x4, cb, float, float, bmode, nonline_mode, output_transform_buf, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); @@ -334,7 +333,7 @@ void winograd_6x3_4x4_f::output( } } // namespace winograd -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp similarity index 94% rename from dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp index f2c13301..fb165dce 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,22 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/elemwise_helper/op_unary.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4) +MIDOUT_DECL(megdnn_fallback_winograd_nchw44_fp32_F23_mk4) using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { constexpr size_t alpha = 2 + 3 - 1; @@ -72,8 +71,9 @@ struct InputTransformF23_NCHW44 { for (int ih = ih0_act; ih < ih1_act; ++ih) { for (int iw = iw0_act; iw < iw1_act; ++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; - auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); - vst1q_f32(patchT + iho * alpha * pack_size + iwo * pack_size, src); + auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); + GiStoreFloat32( + patchT + iho * alpha * pack_size + iwo * pack_size, src); } } #define cb(m, n) \ @@ -190,7 +190,7 @@ struct OutputTransformF23_NCHW44 { } // namespace namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F23_mk4_f_nchw44) @@ -313,14 +313,14 @@ void winograd_F23_mk4_f_nchw44::output( OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, "NCHW44 Winograd filter transform requires OC is times of 4"); - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode, + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode, nonline_mode); #undef cb } } // namespace winograd -} 
// namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp similarity index 75% rename from dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp index dd0e1e9c..b8a46c35 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,22 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/arm_common/conv_bias/fp32/filter_transform.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/common/winograd/winograd_helper.h" +#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/fallback/elemwise_helper/op_unary.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_mk4) +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63_mk4) using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { @@ -49,11 +48,11 @@ struct InputTransformF63_NCHW44 { const float* input_ptr = input + icb * IH * IW4 + ih_start * IW4 + iw4_start; for (size_t ih = 0; ih < alpha; ih++) { -#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); +#define cb(i) auto v##i = GiLoadFloat32(input_ptr + pack_size * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i); +#define cb(i) GiStoreFloat32(patchT + ih * pack_size * alpha + i * pack_size, v##i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb input_ptr += IW4; @@ -68,8 +67,9 @@ struct InputTransformF63_NCHW44 { for (int ih = ih0_act; ih < ih1_act; ++ih) { for (int iw = iw0_act; iw < iw1_act; ++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; - auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); - vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src); + auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); + GiStoreFloat32( + patchT + iho * pack_size * alpha + iwo * pack_size, src); } } } @@ -83,10 +83,10 @@ struct InputTransformF63_NCHW44 { size_t ICB = IC / pack_size; size_t icb = ic / pack_size; - float32x4_t d0, d1, d2, d3, d4, d5, d6, d7; - float32x4_t v0 = vld1q_f32(input_parameters + 0); - float32x4_t v1 = vld1q_f32(input_parameters + 4); - float32x4_t v2 = vld1q_f32(input_parameters + 8); + GI_FLOAT32_t d0, d1, d2, d3, d4, d5, d6, d7; + GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters + 0); + GI_FLOAT32_t v1 = GiLoadFloat32(input_parameters + 4); + GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8); //! B //! 
1 0 0 0 0 0 0 0 @@ -98,57 +98,57 @@ struct InputTransformF63_NCHW44 { //! -1 1 1 1 1 1 1 0 //! 0 0 0 0 0 0 0 1 -#define cb(i) \ - d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ - d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ - d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ - d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ - d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ - d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ - auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ - auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ - auto t##i##1 = d6; \ - auto t##i##2 = d6; \ - auto t##i##3 = d6; \ - auto t##i##4 = d6; \ - auto t##i##5 = d6; \ - auto t##i##6 = d6; \ - t##i##0 = t##i##0 - d6; \ - t##i##1 = t##i##1 + d1; \ - t##i##2 = t##i##2 - d1; \ - t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \ - t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \ - t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \ - t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \ - t##i##7 = t##i##7 - d1; \ - t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \ - t##i##1 = t##i##1 + d2; \ - t##i##2 = t##i##2 + d2; \ - t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \ - t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \ - t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \ - t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \ - t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \ - t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \ - t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \ - t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \ - t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \ - t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \ - t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \ - t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \ - t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \ - t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \ - t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \ - t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \ - t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \ - t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \ - t##i##1 = t##i##1 + d5; \ - t##i##2 = t##i##2 - d5; \ - t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \ - t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \ - t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \ - t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \ - t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0); +#define cb(i) \ + d1 = GiLoadFloat32(patchT + i * alpha * pack_size + 1 * pack_size); \ + d2 = GiLoadFloat32(patchT + i * alpha * pack_size + 2 * pack_size); \ + d3 = GiLoadFloat32(patchT + i * alpha * pack_size + 3 * pack_size); \ + d4 = GiLoadFloat32(patchT + i * alpha * pack_size + 4 * pack_size); \ + d5 = GiLoadFloat32(patchT + i * alpha * pack_size + 5 * pack_size); \ + d6 = GiLoadFloat32(patchT + i * alpha * pack_size + 6 * pack_size); \ + auto t##i##0 = GiLoadFloat32(patchT + i * alpha * pack_size + 0 * pack_size); \ + auto t##i##7 = GiLoadFloat32(patchT + i * alpha * pack_size + 7 * pack_size); \ + auto t##i##1 = d6; \ + auto t##i##2 = d6; \ + auto t##i##3 = d6; \ + auto t##i##4 = d6; \ + auto t##i##5 = d6; \ + auto t##i##6 = d6; \ + t##i##0 = t##i##0 - d6; \ + t##i##1 = t##i##1 + d1; \ + t##i##2 = t##i##2 - d1; \ + t##i##3 = GiSimdFmaLane(t##i##3, d1, v0, 2); \ + t##i##4 = GiFmsqLaneQFloat32(t##i##4, d1, v0, 2); \ + t##i##5 = GiSimdFmaLane(t##i##5, d1, v1, 2); \ + t##i##6 = GiFmsqLaneQFloat32(t##i##6, d1, v1, 2); \ + t##i##7 = t##i##7 - d1; \ + 
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 0); \ + t##i##1 = t##i##1 + d2; \ + t##i##2 = t##i##2 + d2; \ + t##i##3 = GiSimdFmaLane(t##i##3, d2, v0, 3); \ + t##i##4 = GiSimdFmaLane(t##i##4, d2, v0, 3); \ + t##i##5 = GiSimdFmaLane(t##i##5, d2, v1, 3); \ + t##i##6 = GiSimdFmaLane(t##i##6, d2, v1, 3); \ + t##i##1 = GiFmsqLaneQFloat32(t##i##1, d3, v0, 1); \ + t##i##2 = GiSimdFmaLane(t##i##2, d3, v0, 1); \ + t##i##3 = GiFmsqLaneQFloat32(t##i##3, d3, v1, 0); \ + t##i##4 = GiSimdFmaLane(t##i##4, d3, v1, 0); \ + t##i##5 = GiFmsqLaneQFloat32(t##i##5, d3, v1, 0); \ + t##i##6 = GiSimdFmaLane(t##i##6, d3, v1, 0); \ + t##i##7 = GiSimdFmaLane(t##i##7, d3, v0, 0); \ + t##i##0 = GiSimdFmaLane(t##i##0, d4, v0, 0); \ + t##i##1 = GiFmsqLaneQFloat32(t##i##1, d4, v0, 1); \ + t##i##2 = GiFmsqLaneQFloat32(t##i##2, d4, v0, 1); \ + t##i##3 = GiFmsqLaneQFloat32(t##i##3, d4, v1, 1); \ + t##i##4 = GiFmsqLaneQFloat32(t##i##4, d4, v1, 1); \ + t##i##5 = GiFmsqLaneQFloat32(t##i##5, d4, v2, 0); \ + t##i##6 = GiFmsqLaneQFloat32(t##i##6, d4, v2, 0); \ + t##i##1 = t##i##1 + d5; \ + t##i##2 = t##i##2 - d5; \ + t##i##3 = GiSimdFmaLane(t##i##3, d5, v1, 2); \ + t##i##4 = GiFmsqLaneQFloat32(t##i##4, d5, v1, 2); \ + t##i##5 = GiSimdFmaLane(t##i##5, d5, v0, 2); \ + t##i##6 = GiFmsqLaneQFloat32(t##i##6, d5, v0, 2); \ + t##i##7 = GiFmsqLaneQFloat32(t##i##7, d5, v0, 0); UNROLL_CALL_RAW(8, cb); #undef cb @@ -164,75 +164,75 @@ struct InputTransformF63_NCHW44 { d0 = d0 - t6##i; \ d1 = d1 + t1##i; \ d2 = d2 - t1##i; \ - d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \ - d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \ - d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \ - d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \ + d3 = GiSimdFmaLane(d3, t1##i, v0, 2); \ + d4 = GiFmsqLaneQFloat32(d4, t1##i, v0, 2); \ + d5 = GiSimdFmaLane(d5, t1##i, v1, 2); \ + d6 = GiFmsqLaneQFloat32(d6, t1##i, v1, 2); \ d7 = d7 - t1##i; \ - d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \ + d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 0); \ d1 = d1 + t2##i; \ d2 = d2 + t2##i; \ - d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \ - d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \ - d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \ - d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \ - d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \ - d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \ - d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \ - d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \ - d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \ - d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \ - d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \ - d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \ - d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \ - d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \ - d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \ - d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \ - d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \ - d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \ + d3 = GiSimdFmaLane(d3, t2##i, v0, 3); \ + d4 = GiSimdFmaLane(d4, t2##i, v0, 3); \ + d5 = GiSimdFmaLane(d5, t2##i, v1, 3); \ + d6 = GiSimdFmaLane(d6, t2##i, v1, 3); \ + d1 = GiFmsqLaneQFloat32(d1, t3##i, v0, 1); \ + d2 = GiSimdFmaLane(d2, t3##i, v0, 1); \ + d3 = GiFmsqLaneQFloat32(d3, t3##i, v1, 0); \ + d4 = GiSimdFmaLane(d4, t3##i, v1, 0); \ + d5 = GiFmsqLaneQFloat32(d5, t3##i, v1, 0); \ + d6 = GiSimdFmaLane(d6, t3##i, v1, 0); \ + d7 = GiSimdFmaLane(d7, t3##i, v0, 0); \ + d0 = GiSimdFmaLane(d0, t4##i, v0, 0); \ + d1 = GiFmsqLaneQFloat32(d1, t4##i, v0, 1); \ + d2 = GiFmsqLaneQFloat32(d2, t4##i, v0, 1); \ + d3 = GiFmsqLaneQFloat32(d3, t4##i, v1, 1); \ + d4 = GiFmsqLaneQFloat32(d4, t4##i, v1, 1); \ + d5 = GiFmsqLaneQFloat32(d5, t4##i, v2, 0); \ + d6 = 
GiFmsqLaneQFloat32(d6, t4##i, v2, 0); \ d1 = d1 + t5##i; \ d2 = d2 - t5##i; \ - d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \ - d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \ - d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \ - d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \ - d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \ - vst1q_f32( \ + d3 = GiSimdFmaLane(d3, t5##i, v1, 2); \ + d4 = GiFmsqLaneQFloat32(d4, t5##i, v1, 2); \ + d5 = GiSimdFmaLane(d5, t5##i, v0, 2); \ + d6 = GiFmsqLaneQFloat32(d6, t5##i, v0, 2); \ + d7 = GiFmsqLaneQFloat32(d7, t5##i, v0, 0); \ + GiStoreFloat32( \ input_transform_buf + \ (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d0); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d1); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d2); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d3); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d4); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d5); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d6); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ @@ -347,7 +347,7 @@ struct OutputTransformF63_NCHW44 { } // namespace namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk4_f_nchw44) @@ -488,14 +488,14 @@ void winograd_F63_mk4_f_nchw44::output( OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, "NCHW44 Winograd filter transform requires OC is times of 4"); - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, bmode, + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F63_mk4, cb, float, float, bmode, nonline_mode); #undef cb } } // namespace winograd -} // namespace arm_common +} // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp similarity index 69% rename from dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp rename to dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp index d009a90d..b814794a 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp +++ b/dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp + * \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
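
The GiStoreFloat32 calls in the column pass above all scatter one transformed 4-float block into input_transform_buf with the same index expression, i.e. the buffer is laid out as [alpha * alpha][ICB][nr_units_in_tile][pack_size]. Spelled out as a helper (the function name is hypothetical; the arithmetic is copied from the stores, with pack_size = 4 for these NCHW44 strategies):

    #include <cstddef>

    constexpr std::size_t pack_size = 4;

    // Offset of the 4-float block for transformed position (row, col), input-channel
    // block icb of ICB, and tile unit unit_idx, matching the GiStoreFloat32 addressing.
    constexpr std::size_t transform_buf_offset(
            std::size_t row, std::size_t col, std::size_t alpha, std::size_t icb,
            std::size_t ICB, std::size_t unit_idx, std::size_t nr_units_in_tile) {
        return (row * alpha + col) * ICB * nr_units_in_tile * pack_size +
               icb * nr_units_in_tile * pack_size + unit_idx * pack_size;
    }
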
@@ -9,22 +9,21 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/arm_common/conv_bias/fp32/filter_transform.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "src/arm_common/conv_bias/fp32/strategy.h" -#include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/common/winograd/winograd_helper.h" +#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" +#include "src/fallback/conv_bias/gi/fp32/helper.h" +#include "src/fallback/conv_bias/gi/fp32/strategy.h" +#include "src/fallback/conv_bias/gi/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/fallback/elemwise_helper/op_unary.h" #include "midout.h" -MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F73_mk4) +MIDOUT_DECL(megdnn_fallback_winograd_fp32_F73_mk4) using namespace megdnn; -using namespace arm_common; +using namespace fallback; namespace { @@ -51,11 +50,11 @@ struct InputTransformF73_NCHW44 { const float* input_ptr = input + icb * IH * IW4 + ih_start * IW4 + iw4_start; for (size_t ih = 0; ih < alpha; ih++) { -#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); +#define cb(i) auto v##i = GiLoadFloat32(input_ptr + pack_size * i); UNROLL_CALL_NOWRAPPER(9, cb); #undef cb -#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i); +#define cb(i) GiStoreFloat32(patchT + ih * pack_size * alpha + i * pack_size, v##i); UNROLL_CALL_NOWRAPPER(9, cb); #undef cb input_ptr += IW4; @@ -70,8 +69,9 @@ struct InputTransformF73_NCHW44 { for (int ih = ih0_act; ih < ih1_act; ++ih) { for (int iw = iw0_act; iw < iw1_act; ++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; - auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); - vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src); + auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); + GiStoreFloat32( + patchT + iho * pack_size * alpha + iwo * pack_size, src); } } } @@ -85,14 +85,14 @@ struct InputTransformF73_NCHW44 { size_t ICB = IC / pack_size; size_t icb = ic / pack_size; - float32x4_t d0, d1, d2, d3, d4, d5, d6, d7, d8; - float32x4_t v0 = vld1q_f32(input_parameters + 0); - float32x4_t v1 = vld1q_f32(input_parameters + 4); - float32x4_t v2 = vld1q_f32(input_parameters + 8); - float32x4_t v3 = vld1q_f32(input_parameters + 12); - float32x4_t v4 = vld1q_f32(input_parameters + 16); - float32x4_t v5 = vld1q_f32(input_parameters + 20); - float32x4_t v6 = vld1q_f32(input_parameters + 24); + GI_FLOAT32_t d0, d1, d2, d3, d4, d5, d6, d7, d8; + GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters + 0); + GI_FLOAT32_t v1 = GiLoadFloat32(input_parameters + 4); + GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8); + GI_FLOAT32_t v3 = GiLoadFloat32(input_parameters + 12); + GI_FLOAT32_t v4 = GiLoadFloat32(input_parameters + 16); + GI_FLOAT32_t v5 = GiLoadFloat32(input_parameters + 20); + GI_FLOAT32_t v6 = GiLoadFloat32(input_parameters + 24); //! B //! 
1.5 0 0 0 0 0 0 0 0 @@ -113,77 +113,77 @@ struct InputTransformF73_NCHW44 { // 5.0f, 10.0f, 5.75f, 2.75f, v5 // 4.25f, 1.75f, 2.0f, 0.0f, v6 -#define cb(i) \ - d0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ - d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ - d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ - d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ - d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ - d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ - d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ - d7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ - auto t##i##8 = vld1q_f32(patchT + i * alpha * pack_size + 8 * pack_size); \ - auto t##i##0 = d7; \ - auto t##i##1 = d7; \ - auto t##i##2 = d7; \ - auto t##i##3 = d7; \ - auto t##i##4 = d7; \ - auto t##i##5 = d7; \ - auto t##i##6 = d7; \ - auto t##i##7 = d7; \ - t##i##8 = vfmsq_laneq_f32(t##i##8, d7, v0, 0); \ - t##i##0 = t##i##0 - d1; \ - t##i##1 = vfmsq_laneq_f32(t##i##1, d1, v0, 0); \ - t##i##2 = vfmaq_laneq_f32(t##i##2, d1, v0, 0); \ - t##i##3 = vfmsq_laneq_f32(t##i##3, d1, v0, 1); \ - t##i##4 = vfmaq_laneq_f32(t##i##4, d1, v0, 1); \ - t##i##5 = vfmsq_laneq_f32(t##i##5, d1, v0, 2); \ - t##i##6 = vfmaq_laneq_f32(t##i##6, d1, v0, 2); \ - t##i##7 = t##i##7 - d1; \ - t##i##8 = vfmaq_laneq_f32(t##i##8, d1, v0, 0); \ - t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 3); \ - t##i##1 = vfmsq_laneq_f32(t##i##1, d2, v1, 0); \ - t##i##2 = vfmsq_laneq_f32(t##i##2, d2, v1, 1); \ - t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v1, 2); \ - t##i##4 = vfmsq_laneq_f32(t##i##4, d2, v1, 3); \ - t##i##5 = vfmsq_laneq_f32(t##i##5, d2, v2, 0); \ - t##i##6 = vfmsq_laneq_f32(t##i##6, d2, v2, 1); \ - t##i##8 = t##i##8 - d2; \ - t##i##0 = vfmaq_laneq_f32(t##i##0, d3, v2, 2); \ - t##i##1 = vfmaq_laneq_f32(t##i##1, d3, v2, 3); \ - t##i##2 = vfmsq_laneq_f32(t##i##2, d3, v3, 0); \ - t##i##3 = vfmaq_laneq_f32(t##i##3, d3, v2, 0); \ - t##i##4 = vfmsq_laneq_f32(t##i##4, d3, v3, 1); \ - t##i##5 = vfmaq_laneq_f32(t##i##5, d3, v3, 2); \ - t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v3, 3); \ - t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v2, 2); \ - t##i##8 = vfmsq_laneq_f32(t##i##8, d3, v0, 3); \ - t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 3); \ - t##i##1 = vfmaq_laneq_f32(t##i##1, d4, v4, 0); \ - t##i##2 = vfmaq_laneq_f32(t##i##2, d4, v4, 1); \ - t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v4, 2); \ - t##i##4 = vfmaq_laneq_f32(t##i##4, d4, v4, 3); \ - t##i##5 = vfmaq_laneq_f32(t##i##5, d4, v5, 0); \ - t##i##6 = vfmaq_laneq_f32(t##i##6, d4, v5, 1); \ - t##i##8 = vfmaq_laneq_f32(t##i##8, d4, v2, 2); \ - t##i##0 = vfmsq_laneq_f32(t##i##0, d5, v2, 2); \ - t##i##1 = vfmsq_laneq_f32(t##i##1, d5, v5, 2); \ - t##i##2 = vfmsq_laneq_f32(t##i##2, d5, v5, 3); \ - t##i##3 = vfmsq_laneq_f32(t##i##3, d5, v6, 0); \ - t##i##4 = vfmaq_laneq_f32(t##i##4, d5, v6, 1); \ - t##i##5 = vfmsq_laneq_f32(t##i##5, d5, v5, 2); \ - t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v6, 0); \ - t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v2, 2); \ - t##i##8 = vfmaq_laneq_f32(t##i##8, d5, v0, 3); \ - t##i##0 = vfmsq_laneq_f32(t##i##0, d6, v0, 0); \ - t##i##1 = vfmsq_laneq_f32(t##i##1, d6, v1, 0); \ - t##i##2 = vfmsq_laneq_f32(t##i##2, d6, v1, 1); \ - t##i##3 = vfmaq_laneq_f32(t##i##3, d6, v1, 0); \ - t##i##4 = vfmsq_laneq_f32(t##i##4, d6, v3, 1); \ - t##i##5 = t##i##5 - d6; \ - t##i##6 = vfmsq_laneq_f32(t##i##6, d6, v6, 2); \ - t##i##8 = vfmsq_laneq_f32(t##i##8, d6, v2, 2); \ - t##i##0 = 
vfmaq_laneq_f32(t##i##0, d0, v0, 0); +#define cb(i) \ + d0 = GiLoadFloat32(patchT + i * alpha * pack_size + 0 * pack_size); \ + d1 = GiLoadFloat32(patchT + i * alpha * pack_size + 1 * pack_size); \ + d2 = GiLoadFloat32(patchT + i * alpha * pack_size + 2 * pack_size); \ + d3 = GiLoadFloat32(patchT + i * alpha * pack_size + 3 * pack_size); \ + d4 = GiLoadFloat32(patchT + i * alpha * pack_size + 4 * pack_size); \ + d5 = GiLoadFloat32(patchT + i * alpha * pack_size + 5 * pack_size); \ + d6 = GiLoadFloat32(patchT + i * alpha * pack_size + 6 * pack_size); \ + d7 = GiLoadFloat32(patchT + i * alpha * pack_size + 7 * pack_size); \ + auto t##i##8 = GiLoadFloat32(patchT + i * alpha * pack_size + 8 * pack_size); \ + auto t##i##0 = d7; \ + auto t##i##1 = d7; \ + auto t##i##2 = d7; \ + auto t##i##3 = d7; \ + auto t##i##4 = d7; \ + auto t##i##5 = d7; \ + auto t##i##6 = d7; \ + auto t##i##7 = d7; \ + t##i##8 = GiFmsqLaneQFloat32(t##i##8, d7, v0, 0); \ + t##i##0 = t##i##0 - d1; \ + t##i##1 = GiFmsqLaneQFloat32(t##i##1, d1, v0, 0); \ + t##i##2 = GiSimdFmaLane(t##i##2, d1, v0, 0); \ + t##i##3 = GiFmsqLaneQFloat32(t##i##3, d1, v0, 1); \ + t##i##4 = GiSimdFmaLane(t##i##4, d1, v0, 1); \ + t##i##5 = GiFmsqLaneQFloat32(t##i##5, d1, v0, 2); \ + t##i##6 = GiSimdFmaLane(t##i##6, d1, v0, 2); \ + t##i##7 = t##i##7 - d1; \ + t##i##8 = GiSimdFmaLane(t##i##8, d1, v0, 0); \ + t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 3); \ + t##i##1 = GiFmsqLaneQFloat32(t##i##1, d2, v1, 0); \ + t##i##2 = GiFmsqLaneQFloat32(t##i##2, d2, v1, 1); \ + t##i##3 = GiSimdFmaLane(t##i##3, d2, v1, 2); \ + t##i##4 = GiFmsqLaneQFloat32(t##i##4, d2, v1, 3); \ + t##i##5 = GiFmsqLaneQFloat32(t##i##5, d2, v2, 0); \ + t##i##6 = GiFmsqLaneQFloat32(t##i##6, d2, v2, 1); \ + t##i##8 = t##i##8 - d2; \ + t##i##0 = GiSimdFmaLane(t##i##0, d3, v2, 2); \ + t##i##1 = GiSimdFmaLane(t##i##1, d3, v2, 3); \ + t##i##2 = GiFmsqLaneQFloat32(t##i##2, d3, v3, 0); \ + t##i##3 = GiSimdFmaLane(t##i##3, d3, v2, 0); \ + t##i##4 = GiFmsqLaneQFloat32(t##i##4, d3, v3, 1); \ + t##i##5 = GiSimdFmaLane(t##i##5, d3, v3, 2); \ + t##i##6 = GiSimdFmaLane(t##i##6, d3, v3, 3); \ + t##i##7 = GiSimdFmaLane(t##i##7, d3, v2, 2); \ + t##i##8 = GiFmsqLaneQFloat32(t##i##8, d3, v0, 3); \ + t##i##0 = GiSimdFmaLane(t##i##0, d4, v0, 3); \ + t##i##1 = GiSimdFmaLane(t##i##1, d4, v4, 0); \ + t##i##2 = GiSimdFmaLane(t##i##2, d4, v4, 1); \ + t##i##3 = GiFmsqLaneQFloat32(t##i##3, d4, v4, 2); \ + t##i##4 = GiSimdFmaLane(t##i##4, d4, v4, 3); \ + t##i##5 = GiSimdFmaLane(t##i##5, d4, v5, 0); \ + t##i##6 = GiSimdFmaLane(t##i##6, d4, v5, 1); \ + t##i##8 = GiSimdFmaLane(t##i##8, d4, v2, 2); \ + t##i##0 = GiFmsqLaneQFloat32(t##i##0, d5, v2, 2); \ + t##i##1 = GiFmsqLaneQFloat32(t##i##1, d5, v5, 2); \ + t##i##2 = GiFmsqLaneQFloat32(t##i##2, d5, v5, 3); \ + t##i##3 = GiFmsqLaneQFloat32(t##i##3, d5, v6, 0); \ + t##i##4 = GiSimdFmaLane(t##i##4, d5, v6, 1); \ + t##i##5 = GiFmsqLaneQFloat32(t##i##5, d5, v5, 2); \ + t##i##6 = GiFmsqLaneQFloat32(t##i##6, d5, v6, 0); \ + t##i##7 = GiFmsqLaneQFloat32(t##i##7, d5, v2, 2); \ + t##i##8 = GiSimdFmaLane(t##i##8, d5, v0, 3); \ + t##i##0 = GiFmsqLaneQFloat32(t##i##0, d6, v0, 0); \ + t##i##1 = GiFmsqLaneQFloat32(t##i##1, d6, v1, 0); \ + t##i##2 = GiFmsqLaneQFloat32(t##i##2, d6, v1, 1); \ + t##i##3 = GiSimdFmaLane(t##i##3, d6, v1, 0); \ + t##i##4 = GiFmsqLaneQFloat32(t##i##4, d6, v3, 1); \ + t##i##5 = t##i##5 - d6; \ + t##i##6 = GiFmsqLaneQFloat32(t##i##6, d6, v6, 2); \ + t##i##8 = GiFmsqLaneQFloat32(t##i##8, d6, v2, 2); \ + t##i##0 = GiSimdFmaLane(t##i##0, d0, v0, 0); 
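
Both the F63 and the F73 transforms here are written almost entirely in terms of two lane-indexed fused operations. Judging from their one-for-one substitution for vfmaq_laneq_f32 / vfmsq_laneq_f32 in this diff, GiSimdFmaLane and GiFmsqLaneQFloat32 are assumed to keep the NEON semantics: multiply every lane of the second operand by a single scalar lane of the coefficient vector and accumulate into (or subtract from) the first operand. A scalar model for checking the transforms:

    #include <array>

    using Vec4 = std::array<float, 4>;

    // r = a + b * v[lane]  -- scalar model of GiSimdFmaLane(a, b, v, lane)
    inline Vec4 fma_lane_ref(Vec4 a, const Vec4& b, const Vec4& v, int lane) {
        for (int j = 0; j < 4; ++j)
            a[j] += b[j] * v[lane];
        return a;
    }

    // r = a - b * v[lane]  -- scalar model of GiFmsqLaneQFloat32(a, b, v, lane)
    inline Vec4 fms_lane_ref(Vec4 a, const Vec4& b, const Vec4& v, int lane) {
        for (int j = 0; j < 4; ++j)
            a[j] -= b[j] * v[lane];
        return a;
    }
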
UNROLL_CALL_RAW(9, cb); #undef cb @@ -198,100 +198,100 @@ struct InputTransformF73_NCHW44 { d5 = t7##i; \ d6 = t7##i; \ d7 = t7##i; \ - d8 = vfmsq_laneq_f32(d8, t7##i, v0, 0); \ + d8 = GiFmsqLaneQFloat32(d8, t7##i, v0, 0); \ d0 = d0 - t1##i; \ - d1 = vfmsq_laneq_f32(d1, t1##i, v0, 0); \ - d2 = vfmaq_laneq_f32(d2, t1##i, v0, 0); \ - d3 = vfmsq_laneq_f32(d3, t1##i, v0, 1); \ - d4 = vfmaq_laneq_f32(d4, t1##i, v0, 1); \ - d5 = vfmsq_laneq_f32(d5, t1##i, v0, 2); \ - d6 = vfmaq_laneq_f32(d6, t1##i, v0, 2); \ + d1 = GiFmsqLaneQFloat32(d1, t1##i, v0, 0); \ + d2 = GiSimdFmaLane(d2, t1##i, v0, 0); \ + d3 = GiFmsqLaneQFloat32(d3, t1##i, v0, 1); \ + d4 = GiSimdFmaLane(d4, t1##i, v0, 1); \ + d5 = GiFmsqLaneQFloat32(d5, t1##i, v0, 2); \ + d6 = GiSimdFmaLane(d6, t1##i, v0, 2); \ d7 = d7 - t1##i; \ - d8 = vfmaq_laneq_f32(d8, t1##i, v0, 0); \ - d0 = vfmsq_laneq_f32(d0, t2##i, v0, 3); \ - d1 = vfmsq_laneq_f32(d1, t2##i, v1, 0); \ - d2 = vfmsq_laneq_f32(d2, t2##i, v1, 1); \ - d3 = vfmaq_laneq_f32(d3, t2##i, v1, 2); \ - d4 = vfmsq_laneq_f32(d4, t2##i, v1, 3); \ - d5 = vfmsq_laneq_f32(d5, t2##i, v2, 0); \ - d6 = vfmsq_laneq_f32(d6, t2##i, v2, 1); \ + d8 = GiSimdFmaLane(d8, t1##i, v0, 0); \ + d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 3); \ + d1 = GiFmsqLaneQFloat32(d1, t2##i, v1, 0); \ + d2 = GiFmsqLaneQFloat32(d2, t2##i, v1, 1); \ + d3 = GiSimdFmaLane(d3, t2##i, v1, 2); \ + d4 = GiFmsqLaneQFloat32(d4, t2##i, v1, 3); \ + d5 = GiFmsqLaneQFloat32(d5, t2##i, v2, 0); \ + d6 = GiFmsqLaneQFloat32(d6, t2##i, v2, 1); \ d8 = d8 - t2##i; \ - d0 = vfmaq_laneq_f32(d0, t3##i, v2, 2); \ - d1 = vfmaq_laneq_f32(d1, t3##i, v2, 3); \ - d2 = vfmsq_laneq_f32(d2, t3##i, v3, 0); \ - d3 = vfmaq_laneq_f32(d3, t3##i, v2, 0); \ - d4 = vfmsq_laneq_f32(d4, t3##i, v3, 1); \ - d5 = vfmaq_laneq_f32(d5, t3##i, v3, 2); \ - d6 = vfmaq_laneq_f32(d6, t3##i, v3, 3); \ - d7 = vfmaq_laneq_f32(d7, t3##i, v2, 2); \ - d8 = vfmsq_laneq_f32(d8, t3##i, v0, 3); \ - d0 = vfmaq_laneq_f32(d0, t4##i, v0, 3); \ - d1 = vfmaq_laneq_f32(d1, t4##i, v4, 0); \ - d2 = vfmaq_laneq_f32(d2, t4##i, v4, 1); \ - d3 = vfmsq_laneq_f32(d3, t4##i, v4, 2); \ - d4 = vfmaq_laneq_f32(d4, t4##i, v4, 3); \ - d5 = vfmaq_laneq_f32(d5, t4##i, v5, 0); \ - d6 = vfmaq_laneq_f32(d6, t4##i, v5, 1); \ - d8 = vfmaq_laneq_f32(d8, t4##i, v2, 2); \ - d0 = vfmsq_laneq_f32(d0, t5##i, v2, 2); \ - d1 = vfmsq_laneq_f32(d1, t5##i, v5, 2); \ - d2 = vfmsq_laneq_f32(d2, t5##i, v5, 3); \ - d3 = vfmsq_laneq_f32(d3, t5##i, v6, 0); \ - d4 = vfmaq_laneq_f32(d4, t5##i, v6, 1); \ - d5 = vfmsq_laneq_f32(d5, t5##i, v5, 2); \ - d6 = vfmsq_laneq_f32(d6, t5##i, v6, 0); \ - d7 = vfmsq_laneq_f32(d7, t5##i, v2, 2); \ - d8 = vfmaq_laneq_f32(d8, t5##i, v0, 3); \ - d0 = vfmsq_laneq_f32(d0, t6##i, v0, 0); \ - d1 = vfmsq_laneq_f32(d1, t6##i, v1, 0); \ - d2 = vfmsq_laneq_f32(d2, t6##i, v1, 1); \ - d3 = vfmaq_laneq_f32(d3, t6##i, v1, 0); \ - d4 = vfmsq_laneq_f32(d4, t6##i, v3, 1); \ + d0 = GiSimdFmaLane(d0, t3##i, v2, 2); \ + d1 = GiSimdFmaLane(d1, t3##i, v2, 3); \ + d2 = GiFmsqLaneQFloat32(d2, t3##i, v3, 0); \ + d3 = GiSimdFmaLane(d3, t3##i, v2, 0); \ + d4 = GiFmsqLaneQFloat32(d4, t3##i, v3, 1); \ + d5 = GiSimdFmaLane(d5, t3##i, v3, 2); \ + d6 = GiSimdFmaLane(d6, t3##i, v3, 3); \ + d7 = GiSimdFmaLane(d7, t3##i, v2, 2); \ + d8 = GiFmsqLaneQFloat32(d8, t3##i, v0, 3); \ + d0 = GiSimdFmaLane(d0, t4##i, v0, 3); \ + d1 = GiSimdFmaLane(d1, t4##i, v4, 0); \ + d2 = GiSimdFmaLane(d2, t4##i, v4, 1); \ + d3 = GiFmsqLaneQFloat32(d3, t4##i, v4, 2); \ + d4 = GiSimdFmaLane(d4, t4##i, v4, 3); \ + d5 = GiSimdFmaLane(d5, t4##i, v5, 0); \ + d6 = 
GiSimdFmaLane(d6, t4##i, v5, 1); \ + d8 = GiSimdFmaLane(d8, t4##i, v2, 2); \ + d0 = GiFmsqLaneQFloat32(d0, t5##i, v2, 2); \ + d1 = GiFmsqLaneQFloat32(d1, t5##i, v5, 2); \ + d2 = GiFmsqLaneQFloat32(d2, t5##i, v5, 3); \ + d3 = GiFmsqLaneQFloat32(d3, t5##i, v6, 0); \ + d4 = GiSimdFmaLane(d4, t5##i, v6, 1); \ + d5 = GiFmsqLaneQFloat32(d5, t5##i, v5, 2); \ + d6 = GiFmsqLaneQFloat32(d6, t5##i, v6, 0); \ + d7 = GiFmsqLaneQFloat32(d7, t5##i, v2, 2); \ + d8 = GiSimdFmaLane(d8, t5##i, v0, 3); \ + d0 = GiFmsqLaneQFloat32(d0, t6##i, v0, 0); \ + d1 = GiFmsqLaneQFloat32(d1, t6##i, v1, 0); \ + d2 = GiFmsqLaneQFloat32(d2, t6##i, v1, 1); \ + d3 = GiSimdFmaLane(d3, t6##i, v1, 0); \ + d4 = GiFmsqLaneQFloat32(d4, t6##i, v3, 1); \ d5 = d5 - t6##i; \ - d6 = vfmsq_laneq_f32(d6, t6##i, v6, 2); \ - d8 = vfmsq_laneq_f32(d8, t6##i, v2, 2); \ - d0 = vfmaq_laneq_f32(d0, t0##i, v0, 0); \ - vst1q_f32( \ + d6 = GiFmsqLaneQFloat32(d6, t6##i, v6, 2); \ + d8 = GiFmsqLaneQFloat32(d8, t6##i, v2, 2); \ + d0 = GiSimdFmaLane(d0, t0##i, v0, 0); \ + GiStoreFloat32( \ input_transform_buf + \ (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d0); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d1); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d2); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d3); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d4); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d5); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d6); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ d7); \ - vst1q_f32( \ + GiStoreFloat32( \ input_transform_buf + \ (8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ @@ -413,7 +413,7 @@ struct OutputTransformF73_NCHW44 { } // namespace namespace megdnn { -namespace arm_common { +namespace fallback { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F73_mk4_f_nchw44) @@ -554,14 +554,14 @@ void winograd_F73_mk4_f_nchw44::output( OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, "NCHW44 Winograd filter transform requires OC is times of 4"); - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp32_F73_mk4, cb, float, float, bmode, + GI_DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_fallback_winograd_fp32_F73_mk4, cb, float, float, bmode, nonline_mode); #undef cb } } // namespace winograd -} // namespace arm_common +} // namespace fallback } // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/gi/intrinsic_helper.h 
b/dnn/src/fallback/conv_bias/gi/intrinsic_helper.h new file mode 100644 index 00000000..0b1e9e83 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/intrinsic_helper.h @@ -0,0 +1,413 @@ +/** + * \file dnn/src/fallback/conv_bias/gi/intrinsic_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/common/unroll_macro.h" +#include "src/fallback/conv_bias/common.h" +#include "src/fallback/general_intrinsic/gi_float.h" +#include "src/fallback/general_intrinsic/gi_int.h" + +namespace megdnn { +namespace { + +struct Vld1qF32S { + static GI_FORCEINLINE GI_FLOAT32_t impl(const float32_t* ptr) { + return GiLoadFloat32(ptr); + } +}; + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wuninitialized" + +#ifdef __GNUC__ +#ifndef __has_warning +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#else +#if __has_warning("-Wmaybe-uninitialized") +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif +#endif +#endif + +template < + int weight_number, int base_offset, int ptr_step, int oc_block, typename Func, + typename T, typename T2, typename... XT> +struct LoadHelper { + static GI_FORCEINLINE void impl(T& weight, T2 ptr, int oc_offset, XT... args); +}; + +#define WEIGHT_CB(step) \ + src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); + +#define LOAD_HELPER(step) \ + template < \ + int base_offset, int ptr_step, typename Func, typename T, typename T2, \ + typename... XT> \ + struct LoadHelper { \ + static GI_FORCEINLINE void impl(T& src, T2 ptr, int, XT... 
args) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ + } + +LOAD_HELPER(1); +LOAD_HELPER(2); +LOAD_HELPER(3); +LOAD_HELPER(4); +LOAD_HELPER(5); +LOAD_HELPER(6); +LOAD_HELPER(7); +LOAD_HELPER(8); +LOAD_HELPER(9); +LOAD_HELPER(10); +LOAD_HELPER(11); +LOAD_HELPER(12); +LOAD_HELPER(13); +LOAD_HELPER(14); +LOAD_HELPER(15); +LOAD_HELPER(16); + +#undef LOAD_HELPER +#undef WEIGHT_CB + +///////////////////////////c_dim = 1///////////////////////// +#define WEIGHT_CB(step) src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); + +#define LOAD_HELPER(step) \ + template \ + struct LoadHelper { \ + static GI_FORCEINLINE void impl(T& src, T2 ptr, int) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ + } + +LOAD_HELPER(1); +LOAD_HELPER(2); +LOAD_HELPER(3); +LOAD_HELPER(4); +LOAD_HELPER(5); +LOAD_HELPER(6); +LOAD_HELPER(7); +LOAD_HELPER(8); +LOAD_HELPER(9); + +#undef LOAD_HELPER +#undef WEIGHT_CB + +/////////////////////////c_dim = 2/////////////////////////////// +#define WEIGHT_CB(step) \ + src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ + src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); + +#define LOAD_HELPER(step) \ + template \ + struct LoadHelper { \ + static GI_FORCEINLINE void impl(T& src, T2 ptr, int oc_offset) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ + } + +LOAD_HELPER(1); +LOAD_HELPER(2); +LOAD_HELPER(3); +LOAD_HELPER(4); +LOAD_HELPER(5); +LOAD_HELPER(6); +LOAD_HELPER(7); +LOAD_HELPER(8); + +#undef LOAD_HELPER +#undef WEIGHT_CB + +template < + int weight_number, int base_offset, int ptr_step, int c_dim, typename Func, + typename T, typename T2> +GI_FORCEINLINE void load_helper(T& weight, T2 ptr, int oc_offset) { + LoadHelper::impl( + weight, ptr, oc_offset); +} + +////////////////////Store_OCX_OW8_Remain///////////////////////// +template +struct StoreOcxOw8Remain { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc); +}; + +template +struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); + + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][6], c[1][7]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + } +}; +template +struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); + + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][6], c[1][7]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + } +}; +template +struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], 
c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op(c[0][6], reinterpret_cast(dst_ptr + 24)); + + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op(c[1][6], reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + } +}; +template +struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + } +}; +template +struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op(c[0][4], reinterpret_cast(dst_ptr + 16)); + + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op(c[1][4], reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + } +}; +template +struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + } +}; +template +struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op(c[0][2], reinterpret_cast(dst_ptr + 8)); + + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op(c[1][2], reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + } +}; +template +struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + } +}; +template +struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + op(c[0][0], reinterpret_cast(dst_ptr)); + op(c[1][0], reinterpret_cast(dst_ptr + ld_dst_oc)); + } +}; + +template +struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); + } +}; +template +struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); + } +}; + +template +struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> { + static GI_FORCEINLINE 
void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + op(c[0][6], reinterpret_cast(dst_ptr + 24)); + } +}; +template +struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); + } +}; +template +struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + op(c[0][4], reinterpret_cast(dst_ptr + 16)); + } +}; +template +struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); + } +}; +template +struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + op(c[0][2], reinterpret_cast(dst_ptr + 8)); + } +}; +template +struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { + op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); + } +}; +template +struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> { + static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { + op(c[0][0], reinterpret_cast(dst_ptr)); + } +}; + +template +GI_FORCEINLINE void store_ocx_ow8_remain_static( + T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + StoreOcxOw8Remain::impl(c, op, dst_ptr, ld_dst_oc); +} + +#undef cb +#undef cb2 +#undef cb_case +#undef cb_case2 + +#pragma GCC diagnostic pop + +/////////////////////////init_ocx_ow8//////////////////// + +template +struct GiLdqSimd; +template <> +struct GiLdqSimd { + static constexpr int simd_len = 4; +}; +template +struct InitOcxOw8 { + static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step); +}; +template +struct InitOcxOw8 { + static GI_FORCEINLINE void impl(T&, const T2*, int) {} +}; + +#define BAIS_INIT_NO_BIAS_C2(step) \ + c[0][step] = GiBroadcastFloat32(static_cast(0)); \ + c[1][step] = GiBroadcastFloat32(static_cast(0)); +#define BAIS_INIT_NO_BIAS_C1(step) c[0][step] = GiBroadcastFloat32(static_cast(0)); + +#define BAIS_INIT_BROADCAST_C2(step) \ + c[0][step] = GiLoadFloat32(bias_ptr); \ + c[1][step] = GiLoadFloat32(bias_ptr + oc_step); +#define BAIS_INIT_BROADCAST_C1(step) c[0][step] = GiLoadFloat32(bias_ptr); + +#define BAIS_INIT_BIAS_C2(step) \ + c[0][step] = GiLoadFloat32(bias_ptr + step * simd_len); \ + c[1][step] = GiLoadFloat32(bias_ptr + oc_step + step * simd_len); + +#define BAIS_INIT_BIAS_C1(step) c[0][step] = GiLoadFloat32(bias_ptr + step * simd_len); + +#define INSTANCE_InitOcxOw8(ow_remain, cdim) \ + template \ + struct InitOcxOw8 { \ + static GI_FORCEINLINE void impl(T& c, const T2*, int) { \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_NO_BIAS_C##cdim); \ + } \ + }; \ + template \ + struct InitOcxOw8 { \ + static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step) { \ + (void)oc_step; \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BROADCAST_C##cdim); \ + } \ + }; \ + template \ + struct InitOcxOw8 { \ + 
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step) { \ + constexpr int simd_len = GiLdqSimd::simd_len; \ + (void)oc_step; \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BIAS_C##cdim); \ + } \ + }; +#define INSTANCE_InitOcxOw8_C(ow_remain) \ + INSTANCE_InitOcxOw8(ow_remain, 2); \ + INSTANCE_InitOcxOw8(ow_remain, 1); + +INSTANCE_InitOcxOw8_C(1); +INSTANCE_InitOcxOw8_C(2); +INSTANCE_InitOcxOw8_C(3); +INSTANCE_InitOcxOw8_C(4); +INSTANCE_InitOcxOw8_C(5); +INSTANCE_InitOcxOw8_C(6); +INSTANCE_InitOcxOw8_C(7); +INSTANCE_InitOcxOw8_C(8); + +#undef INSTANCE_InitOcxOw8 +#undef INSTANCE_InitOcxOw8_C +#undef BAIS_INIT_BIAS_C1 +#undef BAIS_INIT_BIAS_C2 +#undef BAIS_INIT_BROADCAST_C1 +#undef BAIS_INIT_BROADCAST_C2 +#undef BAIS_INIT_NO_BIAS_C1 +#undef BAIS_INIT_NO_BIAS_C2 + +template +GI_FORCEINLINE void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) { + InitOcxOw8::impl(c, bias_ptr, oc_step); +} + +} // namespace +} // namespace megdnn +#undef GI_FORCEINLINE +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/gi/postprocess_helper.h b/dnn/src/fallback/conv_bias/gi/postprocess_helper.h new file mode 100644 index 00000000..645adb27 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/postprocess_helper.h @@ -0,0 +1,86 @@ +/** + * \file dnn/src/fallback/conv_bias/gi/postprocess_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2022 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megdnn/basic_types.h" +#include "src/fallback/conv_bias/opr_impl.h" + +#include "midout.h" + +MIDOUT_DECL(fallback_gi_conv_bias_postprocess_helper) + +namespace { + +#define GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ + _midout_tag, cb, _bias_id, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ + switch (_nonline_mode) { \ + case param::ConvBias::NonlineMode::IDENTITY: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ + cb(_bmode, NoneOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::RELU: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ + cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::SIGMOID: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 2) { \ + cb(_bmode, SigmoidOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::H_SWISH: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 3) { \ + cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ + } + +#define GI_DISPATCH_CONV_WINOGRAD_BIAS( \ + _midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) 
\ + switch (_bmode) { \ + case BiasMode::BIAS: { \ + GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ + _midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::NO_BIAS: { \ + GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ + _midout_tag, cb, 1, _src_type, _dst_type, BiasMode::NO_BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::BROADCAST_CHANNEL_BIAS: { \ + GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ + _midout_tag, cb, 2, _src_type, _dst_type, \ + BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ + } + +} // namespace diff --git a/dnn/src/fallback/conv_bias/gi/utils.h b/dnn/src/fallback/conv_bias/gi/utils.h new file mode 100644 index 00000000..504b5697 --- /dev/null +++ b/dnn/src/fallback/conv_bias/gi/utils.h @@ -0,0 +1,193 @@ +/** + * \file dnn/src/fallback/conv_bias/gi/utils.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include "src/common/utils.h" +#include "src/fallback/general_intrinsic/gi_float.h" + +namespace megdnn { +namespace fallback { + +template +struct Vector; + +template <> +struct Vector { + GI_FLOAT32_t value; + Vector() {} + Vector(const float v) { value = GiBroadcastFloat32(v); } + Vector(const Vector& lr) { value = lr.value; } + Vector(const Vector&& lr) { value = std::move(lr.value); } + Vector(const GI_FLOAT32_t& v) { value = v; } + static Vector load(const float* addr) { + Vector v; + v.value = GiLoadFloat32(addr); + return v; + } + static void save(float* addr, const Vector& v) { GiStoreFloat32(addr, v.value); } + void save(float* addr) { save(addr, *this); } + Vector operator+(const Vector& lr) { + Vector dst; + dst.value = GiAddFloat32(value, lr.value); + return dst; + } + Vector& operator+=(const Vector& lr) { + value = GiAddFloat32(value, lr.value); + return *this; + } + Vector operator-(const Vector& lr) { + Vector dst; + dst.value = GiSubtractFloat32(value, lr.value); + return dst; + } + Vector& operator-=(const Vector& lr) { + value = GiSubtractFloat32(value, lr.value); + return *this; + } + Vector operator*(float lr) { + Vector dst; + dst.value = GiMultiplyScalerFloat32(value, lr); + return dst; + } + Vector operator*(const Vector& lr) { + Vector dst; + dst.value = GiMultiplyFloat32(value, lr.value); + return dst; + } + Vector& operator*=(const Vector& lr) { + value = GiMultiplyFloat32(value, lr.value); + return *this; + } + Vector& operator=(const Vector& lr) { + value = lr.value; + return *this; + } + Vector& operator=(const Vector&& lr) { + value = std::move(lr.value); + return *this; + } + Vector operator-() { + Vector dst; + dst.value = -value; + return dst; + } +}; + +template <> +struct Vector { + GI_FLOAT32_V2_t value; + Vector() {} + Vector(const float v) { + value.val[0] = GiBroadcastFloat32(v); + value.val[1] = GiBroadcastFloat32(v); + } + Vector(const Vector& lr) { value = lr.value; } + Vector(const Vector&& lr) { value = std::move(lr.value); } + Vector(const GI_FLOAT32_V2_t& v) { value = v; } + static Vector load(const float* addr) { + Vector v; +#if defined(GI_TEST_NAIVE) + v.value.val[0] = GiLoadFloat32(addr); + v.value.val[1] = GiLoadFloat32(addr + 4); 
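    // Note on the branches in this load()/save() pair: the GI_TEST_NAIVE build and the
    // generic fallback both issue two 4-lane GiLoadFloat32/GiStoreFloat32 calls, while
    // the ARM branch below keeps the paired vld1q_f32_x2/vst1q_f32_x2 intrinsics.
    // A hedged usage sketch of this eight-float wrapper (names taken from the
    // surrounding struct; its exact template arguments are partly elided in this
    // listing, assumed here to be Vector<float, 8>):
    //     float a[8], b[8], out[8];
    //     auto va = Vector<float, 8>::load(a);   // two GI_FLOAT32_t halves
    //     auto vb = Vector<float, 8>::load(b);
    //     va += vb;             // lane-wise GiAddFloat32 on both halves
    //     va.mla(vb, 0.5f);     // val + lr * n, as documented on mla() further down
    //     va.save(out);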
+#elif defined(__arm__) || defined(__aarch64__) + v.value = vld1q_f32_x2(addr); +#else + v.value.val[0] = GiLoadFloat32(addr); + v.value.val[1] = GiLoadFloat32(addr + 4); +#endif + return v; + } + static void save(float* addr, const Vector& v) { +#if defined(GI_TEST_NAIVE) + GiStoreFloat32(addr, v.value.val[0]); + GiStoreFloat32(addr + 4, v.value.val[1]); +#elif defined(__arm__) || defined(__aarch64__) + vst1q_f32_x2(addr, v.value); +#else + GiStoreFloat32(addr, v.value.val[0]); + GiStoreFloat32(addr + 4, v.value.val[1]); +#endif + } + + void save(float* addr) { save(addr, *this); } + Vector operator+(const Vector& lr) { + Vector dst; + dst.value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); + dst.value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator+=(const Vector& lr) { + value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); + value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); + return *this; + } + Vector& add(const Vector& lr) { + value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); + value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); + return *this; + } + Vector operator-(const Vector& lr) { + Vector dst; + dst.value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]); + dst.value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator-=(const Vector& lr) { + value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]); + value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]); + return *this; + } + Vector operator*(float lr) { + Vector dst; + dst.value.val[0] = GiMultiplyScalerFloat32(value.val[0], lr); + dst.value.val[1] = GiMultiplyScalerFloat32(value.val[1], lr); + return dst; + } + //! val + lr * n + Vector& mla(const Vector& lr, float n) { + value.val[0] = GiMultiplyAddScalarFloat32(value.val[0], lr.value.val[0], n); + value.val[1] = GiMultiplyAddScalarFloat32(value.val[1], lr.value.val[1], n); + return *this; + } + + Vector operator*(const Vector& lr) { + Vector dst; + dst.value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]); + dst.value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]); + return dst; + } + Vector& operator*=(const Vector& lr) { + value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]); + value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]); + return *this; + } + Vector& operator=(const Vector& lr) { + value = lr.value; + return *this; + } + Vector& operator=(const Vector&& lr) { + value = std::move(lr.value); + return *this; + } + Vector operator-() { + Vector dst; + dst.value.val[0] = -value.val[0]; + dst.value.val[1] = -value.val[1]; + return dst; + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 25869115..19606065 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -16,6 +16,7 @@ #include "src/fallback/conv_bias/algos.h" #include "src/fallback/conv_bias/conv1x1/algos.h" #include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" +#include "src/fallback/conv_bias/gi/fp32/algos.h" #include "src/fallback/conv_bias/im2col/algos.h" #include "src/fallback/convolution/opr_impl.h" #include "src/naive/convolution/algorithms.h" @@ -34,6 +35,14 @@ using namespace megdnn; using namespace fallback; +namespace { + +//! 
TODO: imp is_fallback_exclude_gi_or_naive +bool is_naive(const detail::Algorithm* algo) { + return algo->handle_type() == Handle::HandleType::NAIVE; +} +} // anonymous namespace + size_t megdnn::fallback::pack_size(param::ConvBias::Format format) { switch (format) { case param::ConvBias::Format::NCHW44: @@ -73,16 +82,95 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { SmallVector> refhold; SmallVector m_all_algos; AlgoBase::Mapper m_all_algos_map; + SmallVector m_gi_winograd_algos; + + AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; + AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; + AlgoF32DirectNCHW44 f32_direct_nchw44; + + AlgoF32Direct f32_direct; + AlgoF32DirectStride2 f32_direct_stride2; + AlgoF32DirectStride1 f32_direct_stride1; public: AlgoPack() { + // fallback gi fp32 algo + m_all_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); + m_all_algos.emplace_back(&f32_chanel_wise_nchw44); + m_all_algos.emplace_back(&f32_direct_nchw44); + m_all_algos.emplace_back(&f32_direct_stride1); + m_all_algos.emplace_back(&f32_direct_stride2); + m_all_algos.emplace_back(&f32_direct); + + static CpuOprDelegationStorage<2> storage; + auto matmul_opr = storage.get(); + using MatmulFormat = param::MatrixMul::Format; + auto&& matmul_algos = + static_cast(matmul_opr) + ->select_algo_type({AlgoDataType::FLOAT32, MatmulFormat::MK4}); + for (auto&& algo : matmul_algos) { + if (is_naive(algo)) + continue; + for (uint32_t tile_size : {16, 8, 24, 32}) { + refhold.emplace_back(new AlgoFP32WinogradF23_4x4( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF63_4x4( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); +//! uncomment this when low precision mode is done +#if 0 + refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); +#endif + } + } + + //! TODO: move arm_v7 MatrixMulImpl::AlgoF32 matmul to gi fallback, for nchw + //! 
prefetch algo, also need update dnn/test/common/conv_bias.cpp:check_winograd + matmul_algos = static_cast(matmul_opr) + ->select_algo_type( + {AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); + for (auto&& algo : matmul_algos) { + if (is_naive(algo)) + continue; + for (uint32_t tile_size : {16, 8, 24, 32}) { + refhold.emplace_back(new AlgoFP32WinogradF63( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF54( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF45( + static_cast(algo), + tile_size)); + m_gi_winograd_algos.emplace_back(refhold.back().get()); + } + } + for (auto&& algo : m_gi_winograd_algos) { + m_all_algos.emplace_back(algo); + } + // end fallback gi fp32 algo + refhold.emplace_back(new AlgoConv1x1Gemv()); m_all_algos.emplace_back(refhold.back().get()); - static CpuOprDelegationStorage<> storage; - auto matmul_opr = storage.get(); - auto&& matmul_algos = static_cast(matmul_opr) - ->get_all_packed_algo(); + matmul_algos = static_cast(matmul_opr) + ->get_all_packed_algo(); for (auto&& algo : matmul_algos) { #if MEGDNN_X86 //! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index cf887d59..cac32b33 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -226,6 +226,20 @@ public: FB_CONV1x1, FB_CONV1x1_GEMV, FB_IM2COL, + GI_COMMON_WINOGRAD_F23_4X4_FP32, + GI_COMMON_WINOGRAD_F63_FP32, + GI_COMMON_WINOGRAD_F63_4X4_FP32, + GI_COMMON_WINOGRAD_F54_FP32, + GI_COMMON_WINOGRAD_F45_FP32, + GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, + GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, + GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, + GI_COMMON_DIRECT_FP32, + GI_COMMON_DIRECT_STRD1_FP32, + GI_COMMON_DIRECT_STRD2_FP32, + GI_COMMON_DIRECT_NCHW44_FP32, + GI_COMMON_DIRECT_NCHW_NCHW44_FP32, + GI_COMMON_CHWNWISE_NCHW44_F32, #if MEGDNN_X86 X86_DIRECT = 1 << 8, @@ -248,20 +262,6 @@ public: ARM_COMMON_DIRECT_STRD1_FP16, ARM_COMMON_CHWNWISE_NCHW88_F16, ARM_COMMON_DIRECT_NCHW88_FP16, - ARM_COMMON_WINOGRAD_F23_4X4_FP32, - ARM_COMMON_WINOGRAD_F63_FP32, - ARM_COMMON_WINOGRAD_F63_4X4_FP32, - ARM_COMMON_WINOGRAD_F54_FP32, - ARM_COMMON_WINOGRAD_F45_FP32, - ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, - ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, - ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, - ARM_COMMON_DIRECT_FP32, - ARM_COMMON_DIRECT_STRD1_FP32, - ARM_COMMON_DIRECT_STRD2_FP32, - ARM_COMMON_DIRECT_NCHW44_FP32, - ARM_COMMON_DIRECT_NCHW_NCHW44_FP32, - ARM_COMMON_CHWNWISE_NCHW44_F32, ARM_COMMON_DIRECT_STRD1_S8, ARM_COMMON_DIRECT_STRD2_S8, ARM_COMMON_DIRECT_NCHW44, @@ -383,6 +383,23 @@ private: class AlgoWinogradF32_4x4; class AlgoWinogradQS8; class AlgoWinogradQS8_8x8; + + class AlgoFP32WinogradF23_4x4; + class AlgoFP32WinogradF63; + class AlgoFP32WinogradF63_4x4; + class AlgoFP32WinogradF54; + class AlgoFP32WinogradF45; + class AlgoFP32WinogradF23_4x4_NCHW44; + class AlgoFP32WinogradF63_4x4_NCHW44; + class AlgoFP32WinogradF73_4x4_NCHW44; + + class AlgoF32Direct; + class AlgoF32DirectStride1; + class AlgoF32DirectStride2; + class AlgoF32DirectNCHWNCHW44; + class AlgoF32ChannelWiseNCHW44; + class AlgoF32DirectNCHW44; + class AlgoPack; NCBKernSizeParam m_prev_selected_algo_sizep; diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index aced356c..80bcac22 100644 --- 
a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -81,23 +81,6 @@ TEST_F(ARM_COMMON, CONV_BIAS_RECORD) { } } -TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD_F63_4) { - using namespace conv_bias; - std::vector args = get_winograd_mk_packed_args(); - Checker checker(handle()); - - check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); -} - -TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD_F63_4_WEIGHT_PREPROCESS) { - using namespace conv_bias; - std::vector args = get_winograd_mk_packed_args(); - Checker> checker( - handle()); - - check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); -} - #define CONV_BIAS_MATMUL_QU8_MODE(MODE) \ using namespace conv_bias; \ std::vector args = get_quantized_args_with_nlmode(MODE); \ @@ -1015,14 +998,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { #endif } -TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_4x4) { -#if MEGDNN_AARCH64 - benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4); -#else - benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4); -#endif -} - TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { #if MEGDNN_AARCH64 benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); @@ -1031,14 +1006,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { #endif } -TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_4x4) { -#if MEGDNN_AARCH64 - benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4); -#else - benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4); -#endif -} - TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) { #if MEGDNN_AARCH64 benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:5", handle(), 4); @@ -1212,30 +1179,10 @@ void benchmark_winograd_nchw_vs_nchw44( } } -TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) { -#if MEGDNN_AARCH64 - benchmark_winograd_nchw_vs_nchw44( - "AARCH64_F32_MK4_4x16:4:2", "AARCH64_F32_MK4_4x16:4:2", handle()); -#else - benchmark_winograd_nchw_vs_nchw44( - "ARMV7_F32_MK4_4x8:4:2", "ARMV7_F32_MK4_4x8:4:2", handle()); -#endif -} - -TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { -#if MEGDNN_AARCH64 - benchmark_winograd_nchw_vs_nchw44( - "AARCH64_F32_MK4_4x16:4:6", "AARCH64_F32_MK4_4x16:4:6", handle()); -#else - benchmark_winograd_nchw_vs_nchw44( - "ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:6", handle()); -#endif -} - TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F73_MK4_NCHW_VS_NCHW44) { #if MEGDNN_AARCH64 benchmark_winograd_nchw_vs_nchw44( - "AARCH64_F32_MK4_4x16:4:6", "ARM_COMMON_F32_GEMV_MK4:4:7", handle()); + "AARCH64_F32_MK4_4x16:4:6", "FB_GI_F32_GEMV_MK4:4:7", handle()); #else benchmark_winograd_nchw_vs_nchw44( "ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:7", handle()); @@ -1609,156 +1556,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2) { computations / used0, used1, computations / used1, used1 / used0); } } -TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) { - // have to remove preferred restrict in usable func before run the benchmark - using namespace conv_bias; - param::ConvBias param; - param.stride_h = 1; - param.stride_w = 1; - param.pad_h = 1; - param.pad_w = 1; - param.nonlineMode = NonlineMode::RELU; - param.sparse = param::ConvBias::Sparse::GROUP; - - constexpr size_t RUN = 50; - Benchmarker benchmark0(handle()); - benchmark0.set_display(false); - benchmark0.set_param(param); - benchmark0.set_times(RUN); - benchmark0.set_before_exec_callback( - 
conv_bias::ConvBiasAlgoChecker("F32STRD1")); - - auto opr = handle()->create_operator(); - opr->param() = param; - - param.format = param::ConvBias::Format::NCHW44; - Benchmarker benchmark1(handle()); - benchmark1.set_display(false); - benchmark1.set_param(param); - benchmark1.set_times(RUN); - benchmark1.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker("F32_CHANNEL_WISE_NCHW44")); - auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { - TensorLayout dst_layout; - opr->deduce_layout( - {{1, group * 4, h, w}, dtype::Int8()}, - {{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, - {{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); - //! dst.nr_elems * IC * FH * FW * 2 - float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / - (1024 * 1024 * 1024) * 1e3; - - auto used0 = benchmark0.exec( - {{1, group * 4, h, w}, - {group * 4, 1, 1, kernel, kernel}, - {1, group * 4, 1, 1}, - {}, - {}}) / - RUN; - auto used1 = benchmark1.exec( - {{1, group, h, w, 4}, - {group, 1, 1, kernel, kernel, 4}, - {1, group, 1, 1, 4}, - {}, - {}}) / - RUN; - printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " - "nchw44: " - "%f ms %f GFlops " - "speedup: %f\n", - group, h, w, kernel, used0, computations / used0, used1, - computations / used1, used0 / used1); - }; - for (size_t group : {8, 16, 32, 64}) { - for (size_t kerenl : {2, 3, 5}) { - run(group, 112, 112, kerenl); - run(group, 56, 56, kerenl); - run(group, 48, 48, kerenl); - run(group, 28, 28, kerenl); - run(group, 14, 14, kerenl); - } - } - run(8, 112, 112, 3); - run(32, 56, 56, 3); - run(64, 28, 28, 3); - run(128, 14, 14, 3); -} - -TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) { - // have to remove preferred restrict in usable func before run the benchmark - using namespace conv_bias; - param::ConvBias param; - param.stride_h = 2; - param.stride_w = 2; - param.pad_h = 1; - param.pad_w = 1; - param.nonlineMode = NonlineMode::RELU; - param.sparse = param::ConvBias::Sparse::GROUP; - - constexpr size_t RUN = 50; - Benchmarker benchmark0(handle()); - benchmark0.set_display(false); - benchmark0.set_param(param); - benchmark0.set_times(RUN); - benchmark0.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker("F32STRD2")); - - auto opr = handle()->create_operator(); - opr->param() = param; - - param.format = param::ConvBias::Format::NCHW44; - Benchmarker benchmark1(handle()); - benchmark1.set_display(false); - benchmark1.set_param(param); - benchmark1.set_times(RUN); - benchmark1.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker("F32_CHANNEL_WISE_NCHW44")); - auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { - TensorLayout dst_layout; - opr->deduce_layout( - {{1, group * 4, h, w}, dtype::Int8()}, - {{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, - {{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); - //! 
dst.nr_elems * IC * FH * FW * 2 - float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / - (1024 * 1024 * 1024) * 1e3; - - auto used0 = benchmark0.exec( - {{1, group * 4, h, w}, - {group * 4, 1, 1, kernel, kernel}, - {1, group * 4, 1, 1}, - {}, - {}}) / - RUN; - auto used1 = benchmark1.exec( - {{1, group, h, w, 4}, - {group, 1, 1, kernel, kernel, 4}, - {1, group, 1, 1, 4}, - {}, - {}}) / - RUN; - printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " - "nchw44: " - "%f ms %f GFlops " - "speedup: %f\n", - group, h, w, kernel, used0, computations / used0, used1, - computations / used1, used0 / used1); - }; - for (size_t group : {8, 16, 32, 64}) { - for (size_t kerenl : {2, 3, 5}) { - run(group, 112, 112, kerenl); - run(group, 56, 56, kerenl); - run(group, 48, 48, kerenl); - run(group, 28, 28, kerenl); - run(group, 14, 14, kerenl); - } - } - run(8, 112, 112, 3); - run(32, 56, 56, 3); - run(64, 28, 28, 3); - run(128, 14, 14, 3); -} - TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { // have to remove preferred restrict in usable func before run the benchmark using namespace conv_bias; diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 5e4068b9..96cd0814 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -303,84 +303,6 @@ void checker_conv_bias_int8x8x32_multi( } } -/**********************************F32 direct************************/ -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) { - check_conv_bias( - get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(), - "F32DIRECT"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { - //! 
k=7 s=1 - check_conv_bias( - get_nchw44_conv_bias_args({7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1), - handle(), "F32_CONV_NCHW44_DIRECT"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { - check_conv_bias( - get_nchw44_conv_bias_args({2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), - handle(), "F32_CONV_NCHW44_DIRECT"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { - check_conv_bias( - get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), handle(), - "F32_CONV_NCHW44_DIRECT"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { - check_conv_bias( - get_nchw44_conv_bias_args({2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2), - handle(), "F32_CONV_NCHW44_DIRECT"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) { - check_conv_bias( - get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), handle(), - "F32STRD1"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { - check_conv_bias( - get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), handle(), - "F32STRD2"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) { - check_conv_bias( - get_nchw44_conv_bias_args( - {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false, - true), - handle(), "F32_CONV_NCHW_NCHW44"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) { - check_conv_bias( - get_nchw44_conv_bias_args( - {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false, - true), - handle(), "F32_CONV_NCHW_NCHW44"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) { - check_conv_bias( - get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(), - "F32_CHANNEL_WISE_NCHW44"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) { - check_conv_bias( - get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(), - "F32_CHANNEL_WISE_NCHW44"); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { - check_conv_bias( - get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(), - "F32_CHANNEL_WISE_NCHW44"); -} - /**********************************F16 direct************************/ #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) { @@ -787,50 +709,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) { #endif } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) { - using namespace conv_bias; - std::vector args = get_winograd_mk_packed_args(); - Checker checker(handle()); - - check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) { - using namespace conv_bias; - std::vector args = - get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); - Checker checker(handle()); - check_winograd( - "4:2:32", checker, args, param::MatrixMul::Format::MK4, - param::ConvBias::Format::NCHW44); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) { - using namespace conv_bias; - std::vector args = get_winograd_args(3); - Checker checker(handle()); - - check_winograd("1:6:32", checker, args); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { - using namespace conv_bias; - std::vector args = get_winograd_mk_packed_args(); - Checker checker(handle()); - - check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, 
CONV_BIAS_WINOGRAD_F63_4_NCHW44) { - using namespace conv_bias; - std::vector args = - get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); - Checker checker(handle()); - check_winograd( - "4:6:16", checker, args, param::MatrixMul::Format::MK4, - param::ConvBias::Format::NCHW44); -} - //! uncomment it when low precision mode is ok #if 0 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { @@ -853,22 +731,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCE } #endif -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) { - using namespace conv_bias; - std::vector args = get_winograd_args(4); - Checker checker(handle()); - - check_winograd("1:5:32", checker, args); -} - -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) { - using namespace conv_bias; - std::vector args = get_winograd_args(5); - Checker checker(handle()); - - check_winograd("1:4:32", checker, args); -} - TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) { using namespace conv_bias; diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index 771102ed..c49473ba 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -81,207 +81,6 @@ void benchmark_impl( } } // namespace -TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { - constexpr size_t RUNS = 50; - - param::ConvBias param; - param.nonlineMode = param::ConvBias::NonlineMode::RELU; - param.pad_h = 1; - param.pad_w = 1; - param.stride_h = 1; - param.stride_w = 1; - param.sparse = param::ConvBias::Sparse::GROUP; - - std::vector, float>> shapes_and_computation; - auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, - size_t group) { - SmallVector shapes{ - {N, IC, H, W}, - {group, OC / group, IC / group, FS, FS}, - {1, OC, 1, 1}, - {}, - {N, OC, H, W}}; - TensorShape dst{N, OC, H, W}; - float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + - dst.total_nr_elems()) * - 1e-6; - shapes_and_computation.push_back(std::make_pair(shapes, computations)); - }; - - bench_case(1, 32, 32, 200, 200, 3, 4); - bench_case(1, 32, 32, 200, 200, 3, 32); - bench_case(1, 32, 32, 128, 128, 3, 4); - bench_case(1, 32, 32, 128, 128, 3, 32); - bench_case(1, 32, 32, 100, 100, 3, 4); - bench_case(1, 32, 32, 100, 100, 3, 32); - bench_case(1, 32, 32, 80, 80, 3, 4); - bench_case(1, 32, 32, 80, 80, 3, 32); - - std::string algo_name = "F32DIRECT"; - printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); - std::vector data_type = { - dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, - data_type); - shapes_and_computation.clear(); - - algo_name = "F32DIRECT"; - printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); - bench_case(1, 32, 32, 200, 200, 3, 1); - bench_case(1, 32, 32, 128, 128, 3, 1); - bench_case(1, 32, 32, 100, 100, 3, 1); - bench_case(1, 32, 32, 80, 80, 3, 1); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, 
{7}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, - data_type); -} -TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { - constexpr size_t RUNS = 50; - param::ConvBias param; - param.nonlineMode = param::ConvBias::NonlineMode::RELU; - param.pad_h = 1; - param.pad_w = 1; - param.stride_h = 1; - param.stride_w = 1; - param.sparse = param::ConvBias::Sparse::GROUP; - - std::vector, float>> shapes_and_computation; - auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, - size_t group) { - SmallVector shapes{ - {N, IC, H, W}, - {group, OC / group, IC / group, FS, FS}, - {1, OC, 1, 1}, - {}, - {N, OC, H, W}}; - TensorShape dst{N, OC, H, W}; - float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + - dst.total_nr_elems()) * - 1e-6; - shapes_and_computation.push_back(std::make_pair(shapes, computations)); - }; - - bench_case(1, 32, 32, 200, 200, 3, 4); - bench_case(1, 32, 32, 200, 200, 3, 32); - bench_case(1, 32, 32, 128, 128, 3, 4); - bench_case(1, 32, 32, 128, 128, 3, 32); - bench_case(1, 32, 32, 100, 100, 3, 4); - bench_case(1, 32, 32, 100, 100, 3, 32); - bench_case(1, 32, 32, 80, 80, 3, 4); - bench_case(1, 32, 32, 80, 80, 3, 32); - - std::string algo_name = "F32STRD1"; - printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); - std::vector data_type = { - dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, - data_type); - shapes_and_computation.clear(); - - algo_name = "F32STRD1"; - printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); - bench_case(1, 32, 32, 200, 200, 3, 1); - bench_case(1, 32, 32, 128, 128, 3, 1); - bench_case(1, 32, 32, 100, 100, 3, 1); - bench_case(1, 32, 32, 80, 80, 3, 1); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, - data_type); -} -TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { - constexpr size_t RUNS = 50; - - param::ConvBias param; - param.nonlineMode = param::ConvBias::NonlineMode::RELU; - param.pad_h = 1; - param.pad_w = 1; - param.stride_h = 2; - param.stride_w = 2; - param.sparse = param::ConvBias::Sparse::GROUP; - - std::vector, float>> shapes_and_computation; - auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, - size_t group, size_t P, size_t S) { - SmallVector shapes{ - {N, IC, H, W}, - {group, OC / group, IC / group, FS, FS}, - {1, OC, 1, 1}, - {}, - {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; - TensorShape dst{N, OC, H, W}; - float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + - dst.total_nr_elems()) * - 1e-6; - shapes_and_computation.push_back(std::make_pair(shapes, computations)); - }; - - bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); - bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); - bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); - bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); - bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); - bench_case(1, 
32, 32, 100, 100, 3, 32, 1, 2); - bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); - bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); - - std::string algo_name = "F32STRD2"; - printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); - std::vector data_type = { - dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, - data_type); - shapes_and_computation.clear(); - - algo_name = "F32STRD2"; - printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); - bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); - bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); - bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); - bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, - data_type); - benchmark_impl( - param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, - data_type); -} - #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { constexpr size_t RUNS = 50; diff --git a/dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp b/dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp index 1e13ef74..c8091d20 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp @@ -20,91 +20,7 @@ using namespace megdnn; using namespace test; using namespace conv_bias; -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_WEIGHT_PREPROCESS) { - using namespace conv_bias; - std::vector args = get_winograd_mk_packed_args(); - Checker> checker( - handle()); - check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44_WEIGHT_PREPROCESS) { - using namespace conv_bias; - std::vector args = - get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); - Checker> checker( - handle()); - check_winograd( - "4:2:32", checker, args, param::MatrixMul::Format::MK4, - param::ConvBias::Format::NCHW44); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_WEIGHT_PREPROCESS) { - using namespace conv_bias; - std::vector args = get_winograd_args(3); - Checker> checker( - handle()); - check_winograd("1:6:32", checker, args); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_WEIGHT_PREPROCESS) { - using namespace conv_bias; - std::vector args = get_winograd_mk_packed_args(); - Checker> checker( - handle()); - - check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { - using namespace conv_bias; - std::vector args = - get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); - Checker> checker( - handle()); - check_winograd( - "4:6:16", checker, args, param::MatrixMul::Format::MK4, - param::ConvBias::Format::NCHW44); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54_WEIGHT_PREPROCESS) { - using namespace conv_bias; - std::vector args = get_winograd_args(4); - Checker> checker( - handle()); - check_winograd("1:5:32", checker, args); -} 
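
For reference, the winograd weight-preprocess tests removed above each build two flavours of checker; the template arguments do not survive in this plain-text listing, so the following short sketch shows the assumed declarations (a hypothetical reconstruction, mirroring the removed test bodies; only the template arguments are added, everything else is taken from the calls shown in the patch):

    // Assumed expansion of the checker declarations used by these winograd
    // tests; the exact template arguments are an assumption, not part of the patch.
    std::vector<conv_bias::TestArg> args = conv_bias::get_winograd_args(4);
    Checker<ConvBias> checker(handle());                     // plain run
    Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>>
            preprocess_checker(handle());                    // run with weight preprocess
    check_winograd("1:5:32", checker, args);
    check_winograd("1:5:32", preprocess_checker, args);
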
-TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45_WEIGHT_PREPROCESS) { - using namespace conv_bias; - std::vector args = get_winograd_args(5); - Checker> checker( - handle()); - check_winograd("1:4:32", checker, args); -} -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) { - using namespace conv_bias; - std::vector nchw44_args = - get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); - - Checker checker(handle()); - auto run = [&checker]( - const std::vector& args, DType A_dtype, DType B_dtype, - DType C_dtype, DType D_dtype, const float eps) { - for (auto&& arg : args) { - checker.set_dtype(0, A_dtype) - .set_dtype(1, B_dtype) - .set_dtype(2, C_dtype) - .set_dtype(4, D_dtype) - .set_epsilon(eps) - .set_param(arg.param) - .execs({arg.src, arg.filter, arg.bias, {}, {}}); - } - }; - - //! uncomment this when low precision mode is ok - // run(handle(), nchw44_args, {2, 6, 7}, dtype::Float32(), dtype::Float32(), - // dtype::Float32(), dtype::Float32(), 1e-2f); - - //! remove this when low precision mode is ok - run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(), - dtype::Float32(), 1e-3f); -} TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1_WEIGHT_PREPROCESS) { using namespace conv_bias; diff --git a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp index 8d4d2d0a..eab8e0cf 100644 --- a/dnn/test/arm_common/matrix_mul.cpp +++ b/dnn/test/arm_common/matrix_mul.cpp @@ -286,30 +286,6 @@ TEST_F(ARM_COMMON, FP32_GEVM) { run(M, K, N); } -TEST_F(ARM_COMMON, FP32_GEMV_MK4) { - Checker checker(handle()); - using Param = MatrixMul::Param; - - checker.set_before_exec_callback(AlgoChecker("ARM_COMMON_F32_GEMV_MK4")); - - checker.set_epsilon(1e-2); - auto run = [&](size_t M, size_t K) { - Param param; - param.format = param::MatrixMul::Format::MK4; - param.transposeA = false; - param.transposeB = false; - TensorShape A, B; - A = TensorShape{M / 4, K / 4, 4, 4}; - B = TensorShape{K / 4, 1, 4}; - checker.set_param(param).execs({A, B, {}}); - }; - - // N = 1 - for (size_t M : {4, 16, 128, 1024}) - for (size_t K : {4, 8, 12, 128, 256, 4096}) - run(M, K); -} - TEST_F(ARM_COMMON, MATRIX_MUL_RECORD) { TaskRecordChecker checker(0); checker.set_epsilon(1e-2); diff --git a/dnn/test/fallback/conv_bias.cpp b/dnn/test/fallback/conv_bias.cpp index 7b20108a..b28f287b 100644 --- a/dnn/test/fallback/conv_bias.cpp +++ b/dnn/test/fallback/conv_bias.cpp @@ -117,6 +117,30 @@ TEST_F(FALLBACK, CONV_BIAS_FORWARD_RECORD) { } } +TEST_F(FALLBACK, FP32_GEMV_MK4_GI) { + Checker checker(handle()); + using Param = MatrixMul::Param; + + checker.set_before_exec_callback(AlgoChecker("FB_GI_F32_GEMV_MK4")); + + checker.set_epsilon(1e-2); + auto run = [&](size_t M, size_t K) { + Param param; + param.format = param::MatrixMul::Format::MK4; + param.transposeA = false; + param.transposeB = false; + TensorShape A, B; + A = TensorShape{M / 4, K / 4, 4, 4}; + B = TensorShape{K / 4, 1, 4}; + checker.set_param(param).execs({A, B, {}}); + }; + + // N = 1 + for (size_t M : {4, 16, 128, 1024}) + for (size_t K : {4, 8, 12, 128, 256, 4096}) + run(M, K); +} + std::vector get_conv_bias_args( std::vector kernel, std::vector padv, std::vector nlmodev, std::vector stridev, @@ -257,6 +281,189 @@ TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { dtype::Float32{}, dtype::Float32{}, "FALLBACK_NAIVE"); } +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_NCHW_NCHW44_F32_S2) { + check_conv_bias( + conv_bias::get_nchw44_conv_bias_args( + {2, 3, 5, 7}, 
ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false, + true), + handle(), "F32_CONV_NCHW_NCHW44"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_NCHW_NCHW44_F32_S1) { + check_conv_bias( + conv_bias::get_nchw44_conv_bias_args( + {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false, + true), + handle(), "F32_CONV_NCHW_NCHW44"); +} + +std::vector get_nchw44_channel_wise_args( + std::vector kernel, size_t stride, bool no_bias, bool no_nonlinemode, + bool no_full_bias) { + using namespace conv_bias; + using Param = param::ConvBias; + using NLMode = param::ConvBias::NonlineMode; + std::vector args; + + auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel, + size_t stride, NLMode nlmode, bool pad) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + if (pad) { + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + } else { + param.pad_h = 0; + param.pad_w = 0; + } + param.nonlineMode = nlmode; + param.format = param::ConvBias::Format::NCHW44; + param.sparse = param::ConvBias::Sparse::GROUP; + + args.emplace_back( + param, TensorShape{n, group, h, w, 4}, + TensorShape{group, 1, 1, kernel, kernel, 4}, TensorShape{}); + if (!no_bias) { + args.emplace_back( + param, TensorShape{n, group, h, w, 4}, + TensorShape{group, 1, 1, kernel, kernel, 4}, + TensorShape{1, group, 1, 1, 4}); + } + if (!no_full_bias) { + args.emplace_back( + param, TensorShape{n, group, h, w, 4}, + TensorShape{group, 1, 1, kernel, kernel, 4}, + TensorShape{ + n, group, (h + 2 * param.pad_w - kernel) / stride + 1, + (w + 2 * param.pad_w - kernel) / stride + 1, 4}); + } + }; + + std::vector nonlinemode = {NLMode::IDENTITY}; + if (!no_nonlinemode) { + nonlinemode.emplace_back(NLMode::RELU); + nonlinemode.emplace_back(NLMode::H_SWISH); + } + for (size_t n : {1, 2}) { + for (auto nlmode : nonlinemode) { + for (bool pad : {true}) { + for (size_t group : {1, 2, 4, 7, 16}) { + for (size_t size : {4, 6, 7, 9, 20}) { + for (size_t kern : kernel) { + pack(n, group, size, size, kern, stride, nlmode, pad); + } + } + } + } + for (bool pad : {false}) { + for (size_t group : {1, 2, 7, 16}) { + for (size_t size : {7, 9, 20}) { + for (size_t kern : kernel) { + pack(n, group, size, size, kern, stride, nlmode, pad); + } + } + } + } + } + } + return args; +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) { + check_conv_bias( + get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(), + "F32_CHANNEL_WISE_NCHW44"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) { + check_conv_bias( + get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(), + "F32_CHANNEL_WISE_NCHW44"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { + check_conv_bias( + get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(), + "F32_CHANNEL_WISE_NCHW44"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K7) { + //! 
k=7 s=1 + check_conv_bias( + conv_bias::get_nchw44_conv_bias_args( + {7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1), + handle(), "F32_CONV_NCHW44_DIRECT"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K2K3) { + check_conv_bias( + conv_bias::get_nchw44_conv_bias_args( + {2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), + handle(), "F32_CONV_NCHW44_DIRECT"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K5) { + check_conv_bias( + conv_bias::get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), + handle(), "F32_CONV_NCHW44_DIRECT"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S2) { + check_conv_bias( + conv_bias::get_nchw44_conv_bias_args( + {2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2), + handle(), "F32_CONV_NCHW44_DIRECT"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32) { + check_conv_bias( + conv_bias::get_conv_bias_args( + {1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), + handle(), "F32DIRECT"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_STR2) { + check_conv_bias( + conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), + handle(), "F32STRD2"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_STR1) { + check_conv_bias( + conv_bias::get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), + handle(), "F32STRD1"); +} + +TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_PREPROCESS_NCHW44) { + using namespace conv_bias; + std::vector nchw44_args = conv_bias::get_nchw44_conv_bias_args( + {3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); + + Checker checker(handle()); + + auto run = [&checker]( + const std::vector& args, DType A_dtype, DType B_dtype, + DType C_dtype, DType D_dtype, const float eps) { + for (auto&& arg : args) { + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } + }; + + //! uncomment this when low precision mode is ok + // run(handle(), nchw44_args, {2, 6, 7}, dtype::Float32(), dtype::Float32(), + // dtype::Float32(), dtype::Float32(), 1e-2f); + + //! 
remove this when low precision mode is ok + run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(), + dtype::Float32(), 1e-3f); +} TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) { using namespace conv_bias; param::ConvBias cur_param; @@ -273,6 +480,422 @@ TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) { } #if MEGDNN_WITH_BENCHMARK +namespace { +void benchmark_impl( + const param::ConvBias param, + std::vector, float>>& shapes_and_computation, + const std::string algo_name, size_t RUNS, + TaskExecutorConfig&& multi_thread_config, + TaskExecutorConfig&& single_thread_config, std::vector& data_type) { + std::vector multi_thread_times, single_thread_times; + { + auto multi_thread_hanle = create_cpu_handle(0, true, &multi_thread_config); + auto benchmarker = Benchmarker(multi_thread_hanle.get()); + benchmarker.set_times(RUNS) + .set_display(false) + .set_param(param) + .set_dtype(0, data_type[0]) + .set_dtype(1, data_type[1]) + .set_dtype(2, data_type[2]) + .set_dtype(4, data_type[3]) + .set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name.c_str())); + for (auto shape : shapes_and_computation) { + multi_thread_times.push_back(benchmarker.exec(shape.first) / RUNS); + } + } + { + auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config); + auto benchmarker = Benchmarker(single_thread_handle.get()); + benchmarker.set_times(RUNS) + .set_display(false) + .set_param(param) + .set_dtype(0, data_type[0]) + .set_dtype(1, data_type[1]) + .set_dtype(2, data_type[2]) + .set_dtype(4, data_type[3]) + .set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name.c_str())); + for (auto shape : shapes_and_computation) { + single_thread_times.push_back(benchmarker.exec(shape.first) / RUNS); + } + } + printf("Benchmark : Multi threads %zu, ", multi_thread_config.nr_thread); + printf("core_ids:"); + for (size_t i = 0; i < multi_thread_config.affinity_core_set.size(); i++) { + printf("%zu ", multi_thread_config.affinity_core_set[i]); + } + printf(", Single thread core_id %zu\n", single_thread_config.affinity_core_set[0]); + for (size_t i = 0; i < shapes_and_computation.size(); i++) { + auto shapes = shapes_and_computation[i]; + printf("Bench case: "); + for (auto&& shape : shapes.first) { + printf("%s ", shape.to_string().c_str()); + } + float computations = shapes.second; + printf("%zu threads gflops: %f,\n single thread gflops: " + "%f. 
spead up = %f, speedup/cores=%f\n", + multi_thread_config.nr_thread, computations / multi_thread_times[i], + computations / single_thread_times[i], + single_thread_times[i] / multi_thread_times[i], + single_thread_times[i] / multi_thread_times[i] / + multi_thread_config.nr_thread); + } +} +} // namespace + +TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, + size_t group) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4); + bench_case(1, 32, 32, 200, 200, 3, 32); + bench_case(1, 32, 32, 128, 128, 3, 4); + bench_case(1, 32, 32, 128, 128, 3, 32); + bench_case(1, 32, 32, 100, 100, 3, 4); + bench_case(1, 32, 32, 100, 100, 3, 32); + bench_case(1, 32, 32, 80, 80, 3, 4); + bench_case(1, 32, 32, 80, 80, 3, 32); + + std::string algo_name = "F32DIRECT"; + printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); + std::vector data_type = { + dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, + data_type); + shapes_and_computation.clear(); + + algo_name = "F32DIRECT"; + printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1); + bench_case(1, 32, 32, 128, 128, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, + data_type); +} + +TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32_STR1) { + constexpr size_t RUNS = 50; + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 1; + param.stride_w = 1; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, + size_t group) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, H, W}}; + TensorShape dst{N, OC, H, W}; + float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4); + bench_case(1, 32, 32, 200, 200, 3, 32); + bench_case(1, 32, 32, 128, 128, 3, 4); + bench_case(1, 32, 32, 128, 128, 
3, 32); + bench_case(1, 32, 32, 100, 100, 3, 4); + bench_case(1, 32, 32, 100, 100, 3, 32); + bench_case(1, 32, 32, 80, 80, 3, 4); + bench_case(1, 32, 32, 80, 80, 3, 32); + + std::string algo_name = "F32STRD1"; + printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); + std::vector data_type = { + dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, + data_type); + shapes_and_computation.clear(); + + algo_name = "F32STRD1"; + printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1); + bench_case(1, 32, 32, 128, 128, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, + data_type); +} + +TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32_STR2) { + constexpr size_t RUNS = 50; + + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = 1; + param.pad_w = 1; + param.stride_h = 2; + param.stride_w = 2; + param.sparse = param::ConvBias::Sparse::GROUP; + + std::vector, float>> shapes_and_computation; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, + size_t group, size_t P, size_t S) { + SmallVector shapes{ + {N, IC, H, W}, + {group, OC / group, IC / group, FS, FS}, + {1, OC, 1, 1}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; + TensorShape dst{N, OC, H, W}; + float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + shapes_and_computation.push_back(std::make_pair(shapes, computations)); + }; + + bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); + bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); + + std::string algo_name = "F32STRD2"; + printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); + std::vector data_type = { + dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, + data_type); + shapes_and_computation.clear(); + + algo_name = "F32STRD2"; + printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); + bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); + bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); + bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); + bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, 
RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, + data_type); + benchmark_impl( + param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, + data_type); +} +TEST_F(FALLBACK, BENCHMARK_GI_CHANNEL_WISE_F32_STRIDE1_NCHW44) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + param::ConvBias param; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = 1; + param.pad_w = 1; + param.nonlineMode = NonlineMode::RELU; + param.sparse = param::ConvBias::Sparse::GROUP; + + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0.set_display(false); + benchmark0.set_param(param); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("F32STRD1")); + + auto opr = handle()->create_operator(); + opr->param() = param; + + param.format = param::ConvBias::Format::NCHW44; + Benchmarker benchmark1(handle()); + benchmark1.set_display(false); + benchmark1.set_param(param); + benchmark1.set_times(RUN); + benchmark1.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("F32_CHANNEL_WISE_NCHW44")); + auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { + TensorLayout dst_layout; + opr->deduce_layout( + {{1, group * 4, h, w}, dtype::Int8()}, + {{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, + {{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); + //! dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.exec( + {{1, group * 4, h, w}, + {group * 4, 1, 1, kernel, kernel}, + {1, group * 4, 1, 1}, + {}, + {}}) / + RUN; + auto used1 = benchmark1.exec( + {{1, group, h, w, 4}, + {group, 1, 1, kernel, kernel, 4}, + {1, group, 1, 1, 4}, + {}, + {}}) / + RUN; + printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " + "nchw44: " + "%f ms %f GFlops " + "speedup: %f\n", + group, h, w, kernel, used0, computations / used0, used1, + computations / used1, used0 / used1); + }; + for (size_t group : {8, 16, 32, 64}) { + for (size_t kerenl : {2, 3, 5}) { + run(group, 112, 112, kerenl); + run(group, 56, 56, kerenl); + run(group, 48, 48, kerenl); + run(group, 28, 28, kerenl); + run(group, 14, 14, kerenl); + } + } + run(8, 112, 112, 3); + run(32, 56, 56, 3); + run(64, 28, 28, 3); + run(128, 14, 14, 3); +} + +TEST_F(FALLBACK, BENCHMARK_GI_CHANNEL_WISE_F32_STRIDE2_NCHW44) { + // have to remove preferred restrict in usable func before run the benchmark + using namespace conv_bias; + param::ConvBias param; + param.stride_h = 2; + param.stride_w = 2; + param.pad_h = 1; + param.pad_w = 1; + param.nonlineMode = NonlineMode::RELU; + param.sparse = param::ConvBias::Sparse::GROUP; + + constexpr size_t RUN = 50; + Benchmarker benchmark0(handle()); + benchmark0.set_display(false); + benchmark0.set_param(param); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("F32STRD2")); + + auto opr = handle()->create_operator(); + opr->param() = param; + + param.format = param::ConvBias::Format::NCHW44; + Benchmarker benchmark1(handle()); + benchmark1.set_display(false); + benchmark1.set_param(param); + benchmark1.set_times(RUN); + benchmark1.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker("F32_CHANNEL_WISE_NCHW44")); + auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { + TensorLayout dst_layout; + opr->deduce_layout( + {{1, group * 4, h, w}, dtype::Int8()}, + {{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, 
+ {{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); + //! dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + auto used0 = benchmark0.exec( + {{1, group * 4, h, w}, + {group * 4, 1, 1, kernel, kernel}, + {1, group * 4, 1, 1}, + {}, + {}}) / + RUN; + auto used1 = benchmark1.exec( + {{1, group, h, w, 4}, + {group, 1, 1, kernel, kernel, 4}, + {1, group, 1, 1, 4}, + {}, + {}}) / + RUN; + printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " + "nchw44: " + "%f ms %f GFlops " + "speedup: %f\n", + group, h, w, kernel, used0, computations / used0, used1, + computations / used1, used0 / used1); + }; + for (size_t group : {8, 16, 32, 64}) { + for (size_t kerenl : {2, 3, 5}) { + run(group, 112, 112, kerenl); + run(group, 56, 56, kerenl); + run(group, 48, 48, kerenl); + run(group, 28, 28, kerenl); + run(group, 14, 14, kerenl); + } + } + run(8, 112, 112, 3); + run(32, 56, 56, 3); + run(64, 28, 28, 3); + run(128, 14, 14, 3); +} + TEST_F(FALLBACK, BENCHMARK_CONVBIAS) { constexpr size_t RUNS = 10; param::ConvBias param; @@ -320,6 +943,164 @@ TEST_F(FALLBACK, BENCHMARK_CONVBIAS) { } } } + +TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F23_4x4) { +#if MEGDNN_AARCH64 + conv_bias::benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4); +#elif MEGDNN_ARMV7 + conv_bias::benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4); +#else + conv_bias::benchmark_winograd("WINOGRAD:FB_GI_F32_MK4_4x8:4:2", handle(), 3, 4); +#endif +} + +void benchmark_winograd_nchw_vs_nchw44( + const char* algo_name0, const char* algo_name1, Handle* handle) { + using namespace conv_bias; + using NLMode = param::ConvBias::NonlineMode; + std::vector args_nchw44; + std::vector args_nchw; + + auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t group, + NLMode nlmode) { + param::ConvBias param; + param.format = param::ConvBias::Format::NCHW44; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = 1; + param.pad_w = 1; + param.nonlineMode = nlmode; + + if (group == 1) { + param.sparse = param::ConvBias::Sparse::DENSE; + args_nchw44.emplace_back( + param, TensorShape{n, ic / 4, h, w, 4}, + TensorShape{oc / 4, ic / 4, 3, 3, 4, 4}, TensorShape{}); + param.format = param::ConvBias::Format::NCHW; + args_nchw.emplace_back( + param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, 3, 3}, + TensorShape{}); + } else { + auto oc_per_group = oc / group; + auto ic_per_group = ic / group; + param.sparse = param::ConvBias::Sparse::GROUP; + args_nchw44.emplace_back( + param, TensorShape{n, ic_per_group / 4, h, w, 4}, + TensorShape{group, oc_per_group / 4, ic_per_group / 4, 3, 3, 4, 4}, + TensorShape{}); + param.format = param::ConvBias::Format::NCHW; + args_nchw.emplace_back( + param, TensorShape{n, ic, h, w}, + TensorShape{group, oc_per_group, ic_per_group, 3, 3}, + TensorShape{}); + } + }; + + std::vector nonlinemode = {NLMode::IDENTITY}; + for (auto nlmode : nonlinemode) + for (size_t n : {1}) + for (size_t group = 1; group <= 1; ++group) { + pack(n, 512, 512, 15, 15, group, nlmode); + pack(n, 512, 256, 15, 15, group, nlmode); + pack(n, 256, 256, 29, 29, group, nlmode); + pack(n, 256, 128, 29, 29, group, nlmode); + pack(n, 128, 128, 57, 57, group, nlmode); + pack(n, 128, 64, 57, 57, group, nlmode); + pack(n, 24, 24, 224, 224, group, nlmode); + pack(n, 64, 24, 123, 123, group, nlmode); + pack(n, 64, 64, 56, 56, group, nlmode); + pack(n, 128, 128, 28, 28, group, nlmode); + pack(n, 256, 256, 
14, 14, group, nlmode); + pack(n, 512, 512, 7, 7, group, nlmode); + } + + using namespace conv_bias; + constexpr size_t RUN = 10; + Benchmarker benchmark_winograd_nchw(handle); + benchmark_winograd_nchw.set_display(false); + benchmark_winograd_nchw.set_times(RUN); + + Benchmarker benchmark_winograd_nchw44(handle); + benchmark_winograd_nchw44.set_display(false); + benchmark_winograd_nchw44.set_times(RUN); + + std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name0); + std::string winograd_nchw44_algo_name = ssprintf("WINOGRAD_NCHW44:%s", algo_name1); + + for (size_t i = 0; i < args_nchw.size(); ++i) { + auto arg_nchw = args_nchw[i]; + auto arg_nchw44 = args_nchw44[i]; + + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg_nchw.param; + opr->deduce_layout( + {arg_nchw.src, dtype::Float32()}, {arg_nchw.filter, dtype::Float32()}, + {arg_nchw.bias, dtype::Float32()}, {}, dst_layout); + //! dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg_nchw.filter[1] * + arg_nchw.filter[2] * arg_nchw.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + benchmark_winograd_nchw.set_param(arg_nchw.param); + auto nchw_used = algo_benchmark( + benchmark_winograd_nchw, + {arg_nchw.src, arg_nchw.filter, {}, {}, {}}, + winograd_nchw_algo_name.c_str()) / + RUN; + + benchmark_winograd_nchw44.set_param(arg_nchw44.param); + auto nchw44_used = algo_benchmark( + benchmark_winograd_nchw44, + {arg_nchw44.src, arg_nchw44.filter, {}, {}, {}}, + winograd_nchw44_algo_name.c_str()) / + RUN; + + printf("%s %s: nchw: %f ms %f Gflops nchw44: %f ms %f GFlops " + "speedup: " + "%f\n", + arg_nchw.src.to_string().c_str(), arg_nchw.filter.to_string().c_str(), + nchw_used, computations / nchw_used, nchw44_used, + computations / nchw44_used, nchw_used / nchw44_used); + } +} + +TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) { +#if MEGDNN_AARCH64 + benchmark_winograd_nchw_vs_nchw44( + "AARCH64_F32_MK4_4x16:4:2", "AARCH64_F32_MK4_4x16:4:2", handle()); +#elif MEGDNN_ARMV7 + benchmark_winograd_nchw_vs_nchw44( + "ARMV7_F32_MK4_4x8:4:2", "ARMV7_F32_MK4_4x8:4:2", handle()); +#else + benchmark_winograd_nchw_vs_nchw44( + "FB_GI_F32_MK4_4x8:4:2", "FB_GI_F32_MK4_4x8:4:2", handle()); +#endif +} + +TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_4x4) { +#if MEGDNN_AARCH64 + conv_bias::benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4); +#elif MEGDNN_ARMV7 + conv_bias::benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4); +#else + conv_bias::benchmark_winograd("WINOGRAD:FB_GI_F32_MK4_4x8:4:6", handle(), 3, 4); +#endif +} + +TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { +#if MEGDNN_AARCH64 + benchmark_winograd_nchw_vs_nchw44( + "AARCH64_F32_MK4_4x16:4:6", "AARCH64_F32_MK4_4x16:4:6", handle()); +#elif MEGDNN_ARMV7 + benchmark_winograd_nchw_vs_nchw44( + "ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:6", handle()); +#else + benchmark_winograd_nchw_vs_nchw44( + "FB_GI_F32_MK4_4x8:4:6", "FB_GI_F32_MK4_4x8:4:6", handle()); +#endif +} + #endif } // namespace test
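
As a usage reference for the new dnn/src/fallback/conv_bias/gi/utils.h wrapper introduced earlier in this patch, the sketch below shows how the Vector helper is typically driven. It is an illustration only: the template arguments (float, 4 and float, 8) and the axpy-style loop are assumptions, while the member functions used (load, save, mla) are exactly the ones defined in the header above.

    // Minimal sketch (assumption, not part of the patch): accumulating
    // dst[i] += src[i] * scale eight floats at a time with the GI Vector wrapper.
    #include <cstddef>
    #include "src/fallback/conv_bias/gi/utils.h"

    using megdnn::fallback::Vector;

    void scaled_accumulate(const float* src, float* dst, float scale, size_t n) {
        size_t i = 0;
        for (; i + 8 <= n; i += 8) {
            auto d = Vector<float, 8>::load(dst + i);   // GI_FLOAT32_V2_t lane pair
            auto s = Vector<float, 8>::load(src + i);
            d.mla(s, scale);                            // d = d + s * scale, see mla() above
            d.save(dst + i);
        }
        for (; i < n; ++i)
            dst[i] += src[i] * scale;                   // scalar tail
    }

The same pattern (load, fused multiply-accumulate, store) is what the fallback winograd input/output transforms build on, which is why the 8-wide specialization carries the extra mla() and add() helpers.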