GitOrigin-RevId: ccf8b589be
release-1.10
@@ -1,725 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include <algorithm> | |||||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/arm_common/simd_macro/neon_helper.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs1) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using namespace fp32; | |||||
using namespace conv_stride1; | |||||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
void conv_stride1::do_conv_2x2_stride1( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
//! unroll of 2 | |||||
size_t ic = 0; | |||||
for (; ic + 1 < IC; ic += 2) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
const float* src_ptr1 = src_ptr + IW * IH; | |||||
float* outptr = dst; | |||||
const float* r00 = src_ptr; | |||||
const float* r01 = src_ptr + IW; | |||||
const float* r10 = src_ptr1; | |||||
const float* r11 = src_ptr1 + IW; | |||||
const float* k0 = filter + ic * 4; | |||||
const float* k1 = k0 + 4; | |||||
MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_LOADU(k0); | |||||
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_LOADU(k1); | |||||
rep(h, OH) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
MEGDNN_SIMD_TYPE _r000 = MEGDNN_SIMD_LOADU(r00); | |||||
MEGDNN_SIMD_TYPE _r010 = MEGDNN_SIMD_LOADU(r01); | |||||
MEGDNN_SIMD_TYPE _r001 = MEGDNN_SIMD_LOADU(r00 + 1); | |||||
MEGDNN_SIMD_TYPE _r011 = MEGDNN_SIMD_LOADU(r01 + 1); | |||||
MEGDNN_SIMD_TYPE _r100 = MEGDNN_SIMD_LOADU(r10); | |||||
MEGDNN_SIMD_TYPE _r110 = MEGDNN_SIMD_LOADU(r11); | |||||
MEGDNN_SIMD_TYPE _r101 = MEGDNN_SIMD_LOADU(r10 + 1); | |||||
MEGDNN_SIMD_TYPE _r111 = MEGDNN_SIMD_LOADU(r11 + 1); | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000, MEGDNN_SIMD_GET_LOW(_k0), 0); | |||||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001, MEGDNN_SIMD_GET_LOW(_k0), 1); | |||||
_sum = MEGDNN_SIMD_VMLAQ_LANE( | |||||
_sum, _r010, MEGDNN_SIMD_GET_HIGH(_k0), 0); | |||||
_sum = MEGDNN_SIMD_VMLAQ_LANE( | |||||
_sum, _r011, MEGDNN_SIMD_GET_HIGH(_k0), 1); | |||||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100, MEGDNN_SIMD_GET_LOW(_k1), 0); | |||||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101, MEGDNN_SIMD_GET_LOW(_k1), 1); | |||||
_sum = MEGDNN_SIMD_VMLAQ_LANE( | |||||
_sum, _r110, MEGDNN_SIMD_GET_HIGH(_k1), 0); | |||||
_sum = MEGDNN_SIMD_VMLAQ_LANE( | |||||
_sum, _r111, MEGDNN_SIMD_GET_HIGH(_k1), 1); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r00 += 4; | |||||
r01 += 4; | |||||
r10 += 4; | |||||
r11 += 4; | |||||
outptr += 4; | |||||
} | |||||
r00 += tail_step; | |||||
r01 += tail_step; | |||||
r10 += tail_step; | |||||
r11 += tail_step; | |||||
} | |||||
} | |||||
for (; ic < IC; ic++) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* k0 = filter + ic * 4; | |||||
MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_SET1(k0[0]); | |||||
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_SET1(k0[1]); | |||||
MEGDNN_SIMD_TYPE _k2 = MEGDNN_SIMD_SET1(k0[2]); | |||||
MEGDNN_SIMD_TYPE _k3 = MEGDNN_SIMD_SET1(k0[3]); | |||||
rep(h, OH) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_LOADU(r0 + 1); | |||||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_LOADU(r1 + 1); | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _sum2; | |||||
_sum = MEGDNN_SIMD_FMADD(_r00, _k0, _sum); | |||||
_sum2 = MEGDNN_SIMD_MUL(_r01, _k1); | |||||
_sum = MEGDNN_SIMD_FMADD(_r10, _k2, _sum); | |||||
_sum2 = MEGDNN_SIMD_FMADD(_r11, _k3, _sum2); | |||||
_sum = MEGDNN_SIMD_ADD(_sum, _sum2); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
} | |||||
} | |||||
} | |||||
void conv_stride1::do_conv_3x3_stride1( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
float* outptr2 = outptr + OW; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 3; | |||||
const float* k2 = filter + 5; | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||||
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); | |||||
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); | |||||
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); | |||||
size_t h = 0; | |||||
for (; h + 1 < OH; h += 2) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); | |||||
MEGDNN_SIMD_TYPE _sum3 = MEGDNN_SIMD_LOADU(outptr2); | |||||
MEGDNN_SIMD_TYPE _sum4 = MEGDNN_SIMD_SET1(0.f); | |||||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||||
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); | |||||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); | |||||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); | |||||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); | |||||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); | |||||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); | |||||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); | |||||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU_2(r3 + 4); | |||||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1); | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); | |||||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r10, _k0123, 0); | |||||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r11, _k0123, 1); | |||||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r12, _k0123, 2); | |||||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r20, _k3456, 0); | |||||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r21, _k3456, 1); | |||||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r22, _k3456, 2); | |||||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r30, _k6789, 0); | |||||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r31, _k6789, 1); | |||||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r32, _k6789, 2); | |||||
_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); | |||||
_sum3 = MEGDNN_SIMD_ADD(_sum3, _sum4); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum1); | |||||
MEGDNN_SIMD_STOREU(outptr2, _sum3); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
outptr += 4; | |||||
outptr2 += 4; | |||||
} | |||||
r0 += tail_step + IW; | |||||
r1 += tail_step + IW; | |||||
r2 += tail_step + IW; | |||||
r3 += tail_step + IW; | |||||
outptr += OW; | |||||
outptr2 += OW; | |||||
} | |||||
for (; h < OH; h++) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); | |||||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||||
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); | |||||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); | |||||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); | |||||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); | |||||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); | |||||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); | |||||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); | |||||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); | |||||
_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum1); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
} | |||||
filter += 9; | |||||
} | |||||
} | |||||
void conv_stride1::do_conv_5x5_stride1( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
float* outptr2 = outptr + OW; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
const float* r5 = src_ptr + IW * 5; | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); | |||||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); | |||||
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); | |||||
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); | |||||
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); | |||||
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); | |||||
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); | |||||
size_t h = 0; | |||||
for (; h + 1 < OH; h += 2) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_LOADU(outptr2); | |||||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); | |||||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); | |||||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); | |||||
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); | |||||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||||
MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); | |||||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); | |||||
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); | |||||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); | |||||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k0123, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r11, _k0123, 1); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k0123, 2); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r13, _k0123, 3); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r14, _k4567, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r20, _k4567, 1); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k4567, 2); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r22, _k4567, 3); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r23, _k891011, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r24, _k891011, 1); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r30, _k891011, 2); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r31, _k891011, 3); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r32, _k12131415, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r33, _k12131415, 1); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r34, _k12131415, 2); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r40, _k12131415, 3); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r41, _k16171819, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r42, _k16171819, 1); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r43, _k16171819, 2); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r44, _k16171819, 3); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r50, _k20212223, 0); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r51, _k20212223, 1); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r52, _k20212223, 2); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r53, _k20212223, 3); | |||||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r54, _k24242424, 0); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
MEGDNN_SIMD_STOREU(outptr2, _sum2); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
r4 += 4; | |||||
r5 += 4; | |||||
outptr += 4; | |||||
outptr2 += 4; | |||||
} | |||||
r0 += tail_step + IW; | |||||
r1 += tail_step + IW; | |||||
r2 += tail_step + IW; | |||||
r3 += tail_step + IW; | |||||
r4 += tail_step + IW; | |||||
r5 += tail_step + IW; | |||||
outptr += OW; | |||||
outptr2 += OW; | |||||
} | |||||
for (; h < OH; h++) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); | |||||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); | |||||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); | |||||
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); | |||||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
r4 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
} | |||||
filter += 25; | |||||
} | |||||
} | |||||
void conv_stride1::do_conv_7x7_stride1( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
const float* r5 = src_ptr + IW * 5; | |||||
const float* r6 = src_ptr + IW * 6; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 7; | |||||
const float* k2 = filter + 14; | |||||
const float* k3 = filter + 21; | |||||
const float* k4 = filter + 28; | |||||
const float* k5 = filter + 35; | |||||
const float* k6 = filter + 42; | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); | |||||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3 | |||||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7 | |||||
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11 | |||||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4 | |||||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5 | |||||
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6 | |||||
MEGDNN_SIMD_TYPE _r05 = MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8 | |||||
MEGDNN_SIMD_TYPE _r06 = MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9 | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); | |||||
MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); | |||||
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); | |||||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 8); | |||||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||||
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r14, _r10n, 1); | |||||
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r14, _r10n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); | |||||
MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); | |||||
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); | |||||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 8); | |||||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||||
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r24, _r20n, 1); | |||||
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r24, _r20n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); | |||||
MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); | |||||
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); | |||||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 8); | |||||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||||
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r34, _r30n, 1); | |||||
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r34, _r30n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); | |||||
MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); | |||||
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); | |||||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||||
MEGDNN_SIMD_TYPE _r40n = MEGDNN_SIMD_LOADU(r4 + 8); | |||||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||||
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r44, _r40n, 1); | |||||
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r44, _r40n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); | |||||
MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); | |||||
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); | |||||
MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); | |||||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); | |||||
MEGDNN_SIMD_TYPE _r50n = MEGDNN_SIMD_LOADU(r5 + 8); | |||||
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); | |||||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); | |||||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); | |||||
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r54, _r50n, 1); | |||||
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r54, _r50n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | |||||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | |||||
MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU_3(k6 + 4); | |||||
MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); | |||||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); | |||||
MEGDNN_SIMD_TYPE _r60n = MEGDNN_SIMD_LOADU(r6 + 8); | |||||
MEGDNN_SIMD_TYPE _r61 = MEGDNN_SIMD_EXT(_r60, _r64, 1); | |||||
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r64, 2); | |||||
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r60, _r64, 3); | |||||
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r64, _r60n, 1); | |||||
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r64, _r60n, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k46474849, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k46474849, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k46474849, 2); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
r4 += 4; | |||||
r5 += 4; | |||||
r6 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
r5 += tail_step; | |||||
r6 += tail_step; | |||||
} | |||||
filter += 49; | |||||
} | |||||
} | |||||
#include "src/common/simd_macro/epilogue.h" | |||||
// vim: syntax=cpp.doxygen |
@@ -1,512 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include <algorithm> | |||||
#include "./do_conv_stride2.h" | |||||
#include "midout.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/arm_common/simd_macro/neon_helper.h" | |||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
using namespace fp32; | |||||
using namespace conv_stride2; | |||||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
void conv_stride2::do_conv_2x2_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* k0 = filter; | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||||
rep(h, OH) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); | |||||
MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 | |||||
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); | |||||
MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); | |||||
MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; | |||||
MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k0123, 2); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k0123, 3); | |||||
MEGDNN_SIMD_STOREU(outptr, _outp); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
} | |||||
filter += 4; | |||||
} | |||||
} | |||||
void conv_stride2::do_conv_3x3_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 3; | |||||
const float* k2 = filter + 5; | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||||
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); | |||||
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); | |||||
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); | |||||
rep(h, OH) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); | |||||
MEGDNN_SIMD_TYPE2 _r0n = MEGDNN_SIMD_LOAD2(r0 + 8); | |||||
MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 | |||||
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 | |||||
MEGDNN_SIMD_TYPE _r02 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0n.val[0], 1); // 2 4 6 8 | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r02, _k0123, 2); | |||||
MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); | |||||
MEGDNN_SIMD_TYPE2 _r1n = MEGDNN_SIMD_LOAD2(r1 + 8); | |||||
MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; | |||||
MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1n.val[0], 1); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k3456, 0); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k3456, 1); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r12, _k3456, 2); | |||||
MEGDNN_SIMD_TYPE2 _r2 = MEGDNN_SIMD_LOAD2(r2); | |||||
MEGDNN_SIMD_TYPE2 _r2n = MEGDNN_SIMD_LOAD2(r2 + 8); | |||||
MEGDNN_SIMD_TYPE _r20 = _r2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r21 = _r2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2n.val[0], 1); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r20, _k6789, 0); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r21, _k6789, 1); | |||||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r22, _k6789, 2); | |||||
MEGDNN_SIMD_STOREU(outptr, _outp); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
r2 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
} | |||||
filter += 9; | |||||
} | |||||
} | |||||
void conv_stride2::do_conv_5x5_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); | |||||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); | |||||
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); | |||||
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); | |||||
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); | |||||
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); | |||||
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); | |||||
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); | |||||
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||||
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||||
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||||
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||||
MEGDNN_SIMD_TYPE _r02 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 | |||||
MEGDNN_SIMD_TYPE _r03 = | |||||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 | |||||
MEGDNN_SIMD_TYPE _r04 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 | |||||
MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); | |||||
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); | |||||
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); | |||||
MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); | |||||
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); | |||||
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); | |||||
MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); | |||||
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); | |||||
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); | |||||
MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); | |||||
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); | |||||
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
r2 += 8; | |||||
r3 += 8; | |||||
r4 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
} | |||||
filter += 25; | |||||
} | |||||
} | |||||
void conv_stride2::do_conv_7x7_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
const float* r5 = src_ptr + IW * 5; | |||||
const float* r6 = src_ptr + IW * 6; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 7; | |||||
const float* k2 = filter + 14; | |||||
const float* k3 = filter + 21; | |||||
const float* k4 = filter + 28; | |||||
const float* k5 = filter + 35; | |||||
const float* k6 = filter + 42; | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); | |||||
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); | |||||
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||||
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||||
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||||
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||||
MEGDNN_SIMD_TYPE _r02 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 | |||||
MEGDNN_SIMD_TYPE _r03 = | |||||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 | |||||
MEGDNN_SIMD_TYPE _r04 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 | |||||
MEGDNN_SIMD_TYPE _r05 = | |||||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 2); // 5 7 9 11 | |||||
MEGDNN_SIMD_TYPE _r06 = | |||||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 3); // 6 8 10 12 | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); | |||||
MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); | |||||
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); | |||||
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); | |||||
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); | |||||
MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); | |||||
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); | |||||
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); | |||||
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); | |||||
MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); | |||||
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); | |||||
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); | |||||
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); | |||||
MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); | |||||
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); | |||||
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); | |||||
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); | |||||
MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); | |||||
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); | |||||
MEGDNN_SIMD_TYPE2 _r50_02461357 = MEGDNN_SIMD_LOAD2(r5); | |||||
MEGDNN_SIMD_TYPE2 _r50nx2 = MEGDNN_SIMD_LOAD2(r5 + 8); | |||||
MEGDNN_SIMD_TYPE _r5_8101214 = _r50nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r5_9111315 = _r50nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r50 = _r50_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r51 = _r50_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | |||||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | |||||
MEGDNN_SIMD_TYPE _k45464748 = MEGDNN_SIMD_LOADU(k6 + 3); | |||||
MEGDNN_SIMD_TYPE2 _r60_02461357 = MEGDNN_SIMD_LOAD2(r6); | |||||
MEGDNN_SIMD_TYPE2 _r60nx2 = MEGDNN_SIMD_LOAD2(r6 + 8); | |||||
MEGDNN_SIMD_TYPE _r6_8101214 = _r60nx2.val[0]; | |||||
MEGDNN_SIMD_TYPE _r6_9111315 = _r60nx2.val[1]; | |||||
MEGDNN_SIMD_TYPE _r60 = _r60_02461357.val[0]; | |||||
MEGDNN_SIMD_TYPE _r61 = _r60_02461357.val[1]; | |||||
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 1); | |||||
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 1); | |||||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 2); | |||||
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 2); | |||||
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k45464748, 1); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k45464748, 2); | |||||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k45464748, 3); | |||||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
r2 += 8; | |||||
r3 += 8; | |||||
r4 += 8; | |||||
r5 += 8; | |||||
r6 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
r5 += tail_step; | |||||
r6 += tail_step; | |||||
} | |||||
filter += 49; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -28,7 +28,6 @@ | |||||
#include "include/megdnn/oprs/nn.h" | #include "include/megdnn/oprs/nn.h" | ||||
#include "src/arm_common/conv_bias/f16/algos.h" | #include "src/arm_common/conv_bias/f16/algos.h" | ||||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||||
#include "src/arm_common/conv_bias/int8/stride1.h" | #include "src/arm_common/conv_bias/int8/stride1.h" | ||||
#include "src/arm_common/conv_bias/int8/stride2.h" | #include "src/arm_common/conv_bias/int8/stride2.h" | ||||
#include "src/arm_common/conv_bias/quint8/stride1.h" | #include "src/arm_common/conv_bias/quint8/stride1.h" | ||||
@@ -69,14 +68,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | ||||
#endif | #endif | ||||
AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; | |||||
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; | |||||
AlgoF32DirectNCHW44 f32_direct_nchw44; | |||||
AlgoF32Direct f32_direct; | |||||
AlgoF32DirectStride2 f32_direct_stride2; | |||||
AlgoF32DirectStride1 f32_direct_stride1; | |||||
AlgoI8x8x16Direct i8x8x16_direct; | AlgoI8x8x16Direct i8x8x16_direct; | ||||
AlgoI8x8x16Stride2 i8x8x16_stride2; | AlgoI8x8x16Stride2 i8x8x16_stride2; | ||||
AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; | AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; | ||||
@@ -127,14 +118,6 @@ public: | |||||
m_direct_algos.emplace_back(&i8x8x16_stride2); | m_direct_algos.emplace_back(&i8x8x16_stride2); | ||||
m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | ||||
m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||||
m_direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||||
m_direct_algos.emplace_back(&f32_direct_nchw44); | |||||
m_direct_algos.emplace_back(&f32_direct_stride1); | |||||
m_direct_algos.emplace_back(&f32_direct_stride2); | |||||
m_direct_algos.emplace_back(&f32_direct); | |||||
static CpuOprDelegationStorage<2> storage; | static CpuOprDelegationStorage<2> storage; | ||||
auto matmul_opr = storage.get<MatrixMul, 0>(); | auto matmul_opr = storage.get<MatrixMul, 0>(); | ||||
using MatmulFormat = param::MatrixMul::Format; | using MatmulFormat = param::MatrixMul::Format; | ||||
@@ -145,22 +128,6 @@ public: | |||||
if (is_fallback_or_naive(algo)) | if (is_fallback_or_naive(algo)) | ||||
continue; | continue; | ||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | for (uint32_t tile_size : {16, 8, 24, 32}) { | ||||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_winograd_algos.emplace_back(refhold.back().get()); | |||||
//! uncomment this when low precision mode is done | //! uncomment this when low precision mode is done | ||||
#if 0 | #if 0 | ||||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | ||||
@@ -175,27 +142,6 @@ public: | |||||
m_winograd_algos.emplace_back(refhold.back().get()); | m_winograd_algos.emplace_back(refhold.back().get()); | ||||
} | } | ||||
} | } | ||||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||||
->select_algo_type( | |||||
{AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); | |||||
for (auto&& algo : matmul_algos) { | |||||
if (is_fallback_or_naive(algo)) | |||||
continue; | |||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_winograd_algos.emplace_back(refhold.back().get()); | |||||
} | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | ||||
@@ -49,15 +49,6 @@ private: | |||||
class AlgoS8DirectNCHWNCHW44; | class AlgoS8DirectNCHWNCHW44; | ||||
class AlgoQU8DirectStride1; | class AlgoQU8DirectStride1; | ||||
class AlgoQU8DirectStride2; | class AlgoQU8DirectStride2; | ||||
class AlgoFP32WinogradF23_4x4; | |||||
class AlgoFP32WinogradF63; | |||||
class AlgoFP32WinogradF63_4x4; | |||||
class AlgoFP32WinogradF54; | |||||
class AlgoFP32WinogradF45; | |||||
class AlgoFP32WinogradF23_4x4_NCHW44; | |||||
class AlgoFP32WinogradF63_4x4_NCHW44; | |||||
class AlgoFP32WinogradF73_4x4_NCHW44; | |||||
class AlgoS8ChanWiseStride1NCHW44; | class AlgoS8ChanWiseStride1NCHW44; | ||||
class AlgoS8ChanWiseStride2NCHW44; | class AlgoS8ChanWiseStride2NCHW44; | ||||
@@ -78,12 +69,6 @@ private: | |||||
class AlgoDotS8Direct_NCHW44; | class AlgoDotS8Direct_NCHW44; | ||||
#endif | #endif | ||||
class AlgoF32Direct; | |||||
class AlgoF32DirectStride1; | |||||
class AlgoF32DirectStride2; | |||||
class AlgoF32DirectNCHWNCHW44; | |||||
class AlgoF32ChannelWiseNCHW44; | |||||
class AlgoF32DirectNCHW44; | |||||
class AlgoI8x8x16Direct; | class AlgoI8x8x16Direct; | ||||
class AlgoI8x8x16Stride2; | class AlgoI8x8x16Stride2; | ||||
@@ -10,6 +10,8 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megbrain_build_config.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
#include "src/fallback/matrix_mul/opr_impl.h" | #include "src/fallback/matrix_mul/opr_impl.h" | ||||
@@ -0,0 +1,37 @@ | |||||
/** | |||||
* \file dnn/src/fallback/conv_bias/gi/block_helper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace { | |||||
// block_helper is used to calculate oh block size | |||||
static inline int l2_block_helper( | |||||
const int nthread, const int amount, const int size_per_unit) { | |||||
//! TODO: opt config or dynamic config l2_cache_size for different ARCH | |||||
constexpr int l2_cache_size = 256 * 1024; | |||||
const int block_per_thread = div_ceil(amount, nthread); | |||||
const int best_block = | |||||
std::min(amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); | |||||
const int max_block_num = div_ceil(block_per_thread, best_block); | |||||
const int min_block_num = std::max(max_block_num - 1, 1); | |||||
const int max_block = div_ceil(block_per_thread, max_block_num); | |||||
const int min_block = div_ceil(block_per_thread, min_block_num); | |||||
const int max_loss = std::abs(max_block_num * max_block - block_per_thread); | |||||
const int min_loss = std::abs(min_block_num * min_block - block_per_thread); | |||||
int block = max_loss > min_loss ? min_block : max_block; | |||||
return block; | |||||
} | |||||
} // namespace | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/algos.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/algos.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,23 +10,22 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||||
#include "src/arm_common/conv_bias/fp32/direct.h" | |||||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | |||||
#include "src/arm_common/conv_bias/fp32/do_conv_stride2.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/conv_bias/img2col_helper.h" | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/direct/multi_thread_common.h" | #include "src/fallback/conv_bias/direct/multi_thread_common.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/direct.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/do_conv_stride1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/do_conv_stride2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
/* ======================= AlgoFP32WinogradF23_4x4 ======================== */ | /* ======================= AlgoFP32WinogradF23_4x4 ======================== */ | ||||
@@ -34,10 +33,10 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( | |||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 0) { | |||||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 0, 0) { | |||||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | ||||
return false; | return false; | ||||
using Strategy = winograd::winograd_2x3_4x4_f; | |||||
using Strategy = winograd::winograd_gi_2x3_4x4_f; | |||||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | ||||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | Strategy strategy(param.src_type, param.filter_type, param.dst_type); | ||||
auto&& matmul_param = | auto&& matmul_param = | ||||
@@ -62,8 +61,8 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( | |||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | ||||
AlgoFP32WinogradF23_4x4, winograd::winograd_2x3_4x4_f, | |||||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
AlgoFP32WinogradF23_4x4, winograd::winograd_gi_2x3_4x4_f, | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
/* ======================= AlgoFP32WinogradF63 ======================== */ | /* ======================= AlgoFP32WinogradF63 ======================== */ | ||||
@@ -71,7 +70,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( | |||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) { | |||||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 1, 0) { | |||||
using Strategy = winograd::winograd_6x3_1x1_f; | using Strategy = winograd::winograd_6x3_1x1_f; | ||||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | Strategy strategy(param.src_type, param.filter_type, param.dst_type); | ||||
auto&& matmul_param = | auto&& matmul_param = | ||||
@@ -95,7 +94,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | ||||
AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f, | AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f, | ||||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||||
/* ======================= AlgoFP32WinogradF54 ======================== */ | /* ======================= AlgoFP32WinogradF54 ======================== */ | ||||
@@ -103,7 +102,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( | |||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) { | |||||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 2, 0) { | |||||
using Strategy = winograd::winograd_5x4_1x1_f; | using Strategy = winograd::winograd_5x4_1x1_f; | ||||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | Strategy strategy(param.src_type, param.filter_type, param.dst_type); | ||||
auto&& matmul_param = | auto&& matmul_param = | ||||
@@ -127,7 +126,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | ||||
AlgoFP32WinogradF54, winograd::winograd_5x4_1x1_f, | AlgoFP32WinogradF54, winograd::winograd_5x4_1x1_f, | ||||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||||
/* ======================= AlgoFP32WinogradF45 ======================== */ | /* ======================= AlgoFP32WinogradF45 ======================== */ | ||||
@@ -135,7 +134,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( | |||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) { | |||||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 3, 0) { | |||||
using Strategy = winograd::winograd_4x5_1x1_f; | using Strategy = winograd::winograd_4x5_1x1_f; | ||||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | Strategy strategy(param.src_type, param.filter_type, param.dst_type); | ||||
auto&& matmul_param = | auto&& matmul_param = | ||||
@@ -159,7 +158,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | ||||
AlgoFP32WinogradF45, winograd::winograd_4x5_1x1_f, | AlgoFP32WinogradF45, winograd::winograd_4x5_1x1_f, | ||||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||||
/* ======================= AlgoFP32WinogradF63_4x4 ======================== */ | /* ======================= AlgoFP32WinogradF63_4x4 ======================== */ | ||||
@@ -167,7 +166,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( | |||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 0) { | |||||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 4, 0) { | |||||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | ||||
return false; | return false; | ||||
using Strategy = winograd::winograd_6x3_4x4_f; | using Strategy = winograd::winograd_6x3_4x4_f; | ||||
@@ -197,7 +196,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | ||||
AlgoFP32WinogradF63_4x4, winograd::winograd_6x3_4x4_f, | AlgoFP32WinogradF63_4x4, winograd::winograd_6x3_4x4_f, | ||||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
/* =================== AlgoFP32WinogradF23_4x4_NCHW44 =================== */ | /* =================== AlgoFP32WinogradF23_4x4_NCHW44 =================== */ | ||||
@@ -206,7 +205,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( | |||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
MIDOUT_BEGIN( | MIDOUT_BEGIN( | ||||
megdnn_arm_common_winograd_fp32, | |||||
megdnn_fallback_winograd_fp32, | |||||
midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { | midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { | ||||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | ||||
return false; | return false; | ||||
@@ -236,7 +235,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | ||||
AlgoFP32WinogradF23_4x4_NCHW44, winograd::winograd_F23_mk4_f_nchw44, | AlgoFP32WinogradF23_4x4_NCHW44, winograd::winograd_F23_mk4_f_nchw44, | ||||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
/* =================== AlgoFP32WinogradF63_4x4_NCHW44 ===================== */ | /* =================== AlgoFP32WinogradF63_4x4_NCHW44 ===================== */ | ||||
@@ -245,7 +244,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( | |||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
MIDOUT_BEGIN( | MIDOUT_BEGIN( | ||||
megdnn_arm_common_winograd_fp32, | |||||
megdnn_fallback_winograd_fp32, | |||||
midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { | midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { | ||||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | ||||
return false; | return false; | ||||
@@ -276,7 +275,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | ||||
AlgoFP32WinogradF63_4x4_NCHW44, winograd::winograd_F63_mk4_f_nchw44, | AlgoFP32WinogradF63_4x4_NCHW44, winograd::winograd_F63_mk4_f_nchw44, | ||||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
/* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */ | /* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */ | ||||
@@ -284,7 +283,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( | |||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MIDOUT_BEGIN( | MIDOUT_BEGIN( | ||||
megdnn_arm_common_winograd_fp32, | |||||
megdnn_fallback_winograd_fp32, | |||||
midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) { | midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) { | ||||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | ||||
return false; | return false; | ||||
@@ -314,14 +313,14 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | ||||
AlgoFP32WinogradF73_4x4_NCHW44, winograd::winograd_F73_mk4_f_nchw44, | AlgoFP32WinogradF73_4x4_NCHW44, winograd::winograd_F73_mk4_f_nchw44, | ||||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||||
/* ===================== direct algo ===================== */ | /* ===================== direct algo ===================== */ | ||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); | |||||
MIDOUT_DECL(megdnn_fallback_conv_bias_f32_kimpl); | |||||
bool ConvBiasImpl::AlgoF32Direct::usable( | bool ConvBiasImpl::AlgoF32Direct::usable( | ||||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 0) { | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
auto SH = fm.stride[0], SW = fm.stride[1]; | auto SH = fm.stride[0], SW = fm.stride[1]; | ||||
@@ -341,7 +340,7 @@ bool ConvBiasImpl::AlgoF32Direct::usable( | |||||
return false; | return false; | ||||
} | } | ||||
size_t ConvBiasImpl::AlgoF32Direct::get_workspace(const NCBKernSizeParam& param) const { | size_t ConvBiasImpl::AlgoF32Direct::get_workspace(const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
auto wbundle = fallback::MultithreadDirectConvCommon<float, float>::get_bundle( | auto wbundle = fallback::MultithreadDirectConvCommon<float, float>::get_bundle( | ||||
param, large_group); | param, large_group); | ||||
@@ -426,7 +425,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls( | |||||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns( | SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) { | |||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -435,7 +434,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns( | |||||
/* ===================== stride-1 algo ===================== */ | /* ===================== stride-1 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoF32DirectStride1::usable( | bool ConvBiasImpl::AlgoF32DirectStride1::usable( | ||||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 1) { | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | return param.filter_meta.format == param::ConvBias::Format::NCHW && | ||||
@@ -452,7 +451,7 @@ bool ConvBiasImpl::AlgoF32DirectStride1::usable( | |||||
size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 1) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
auto bundle = | auto bundle = | ||||
fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | ||||
@@ -548,7 +547,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::get_kimpl | |||||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( | SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 2) { | |||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -559,7 +558,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::dispatch_ | |||||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | bool ConvBiasImpl::AlgoF32DirectStride2::usable( | ||||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 0) { | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | return param.filter_meta.format == param::ConvBias::Format::NCHW && | ||||
@@ -575,7 +574,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 1) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
auto bundle = | auto bundle = | ||||
fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | ||||
@@ -670,7 +669,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpl | |||||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 2) { | |||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/algos.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/algos.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -12,11 +12,11 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#include "src/fallback/matrix_mul/opr_impl.h" | #include "src/fallback/matrix_mul/opr_impl.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoFP32WinogradF23_4x4( | AlgoFP32WinogradF23_4x4( | ||||
@@ -31,7 +31,7 @@ public: | |||||
} | } | ||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_4X4_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | ||||
@@ -50,7 +50,7 @@ public: | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | ||||
@@ -67,7 +67,7 @@ public: | |||||
} | } | ||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | ||||
@@ -86,7 +86,7 @@ public: | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F54_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | ||||
@@ -105,7 +105,7 @@ public: | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F45_FP32) | |||||
}; | }; | ||||
//===================== NCHW44 Winograd Support =====================// | //===================== NCHW44 Winograd Support =====================// | ||||
@@ -124,7 +124,7 @@ public: | |||||
} | } | ||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | ||||
@@ -142,7 +142,7 @@ public: | |||||
} | } | ||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | ||||
@@ -160,7 +160,7 @@ public: | |||||
} | } | ||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | |||||
}; | }; | ||||
// ================================================================= // | // ================================================================= // | ||||
@@ -180,7 +180,7 @@ public: | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | ||||
@@ -199,7 +199,7 @@ public: | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_STRD1_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | ||||
@@ -218,7 +218,7 @@ public: | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_STRD2_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | ||||
@@ -238,7 +238,7 @@ public: | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_NCHW44_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | ||||
@@ -258,7 +258,7 @@ public: | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_NCHW_NCHW44_FP32) | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | ||||
@@ -277,10 +277,10 @@ public: | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32) | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_CHWNWISE_NCHW44_F32) | |||||
}; | }; | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#undef MEGDNN_WINOGRAD_ALGO_FUN_DECLARE | #undef MEGDNN_WINOGRAD_ALGO_FUN_DECLARE |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,29 +10,22 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
#if defined(__ARM_FEATURE_FMA) | |||||
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) | |||||
#else | |||||
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) | |||||
#endif | |||||
template <int shift> | template <int shift> | ||||
static inline void shift_src(float32x4_t rsrc[3][4]) { | |||||
float32x4_t t[4]; | |||||
static inline void shift_src(GI_FLOAT32_t rsrc[3][4]) { | |||||
GI_FLOAT32_t t[4]; | |||||
t[0] = rsrc[0][(shift + 0) % 4]; | t[0] = rsrc[0][(shift + 0) % 4]; | ||||
t[1] = rsrc[0][(shift + 1) % 4]; | t[1] = rsrc[0][(shift + 1) % 4]; | ||||
@@ -63,9 +56,9 @@ static inline void shift_src(float32x4_t rsrc[3][4]) { | |||||
} | } | ||||
template <BiasMode bias_mode> | template <BiasMode bias_mode> | ||||
static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) { | |||||
static inline GI_FLOAT32_t load_bias(const float* bias, const GI_FLOAT32_t& init) { | |||||
if (bias_mode == BiasMode::BIAS) { | if (bias_mode == BiasMode::BIAS) { | ||||
return vld1q_f32(bias); | |||||
return GiLoadFloat32(bias); | |||||
} else { | } else { | ||||
return init; | return init; | ||||
} | } | ||||
@@ -76,35 +69,35 @@ struct compute_element { | |||||
template <typename Op> | template <typename Op> | ||||
static inline void call( | static inline void call( | ||||
const float*& src0, const float*& src1, const float*& src2, float*& dst, | const float*& src0, const float*& src1, const float*& src2, float*& dst, | ||||
const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4], | |||||
float32x4_t rfilter[3][3], const Op& op) { | |||||
const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4], | |||||
GI_FLOAT32_t rfilter[3][3], const Op& op) { | |||||
#define RSRC(i, j) rsrc[i][((j) + bw) % 4] | #define RSRC(i, j) rsrc[i][((j) + bw) % 4] | ||||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||||
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init); | |||||
if (has_top) { | if (has_top) { | ||||
RSRC(0, 3) = vld1q_f32(src0 + 8); | |||||
RSRC(0, 3) = GiLoadFloat32(src0 + 8); | |||||
} | } | ||||
{ RSRC(1, 3) = vld1q_f32(src1 + 8); } | |||||
{ RSRC(1, 3) = GiLoadFloat32(src1 + 8); } | |||||
if (has_bottom) { | if (has_bottom) { | ||||
RSRC(2, 3) = vld1q_f32(src2 + 8); | |||||
RSRC(2, 3) = GiLoadFloat32(src2 + 8); | |||||
} | } | ||||
if (has_top) { | if (has_top) { | ||||
rdst = Vfmaq_f32(rdst, RSRC(0, 0), rfilter[0][0]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(0, 1), rfilter[0][1]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(0, 2), rfilter[0][2]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(0, 0), rfilter[0][0]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(0, 1), rfilter[0][1]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(0, 2), rfilter[0][2]); | |||||
} | } | ||||
{ | { | ||||
rdst = Vfmaq_f32(rdst, RSRC(1, 0), rfilter[1][0]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(1, 1), rfilter[1][1]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(1, 2), rfilter[1][2]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(1, 0), rfilter[1][0]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(1, 1), rfilter[1][1]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(1, 2), rfilter[1][2]); | |||||
} | } | ||||
if (has_bottom) { | if (has_bottom) { | ||||
rdst = Vfmaq_f32(rdst, RSRC(2, 0), rfilter[2][0]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(2, 1), rfilter[2][1]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(2, 2), rfilter[2][2]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(2, 0), rfilter[2][0]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(2, 1), rfilter[2][1]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(2, 2), rfilter[2][2]); | |||||
} | } | ||||
vst1q_f32(dst, op(rdst)); | |||||
GiStoreFloat32(dst, op(rdst)); | |||||
if (has_top) { | if (has_top) { | ||||
src0 += 4; | src0 += 4; | ||||
@@ -131,27 +124,27 @@ template <bool has_top, bool has_bottom, BiasMode bias_mode> | |||||
struct compute_element_right { | struct compute_element_right { | ||||
template <typename Op> | template <typename Op> | ||||
static inline void call( | static inline void call( | ||||
float*& dst, const float*& bias, const float32x4_t& init, | |||||
float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) { | |||||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||||
float*& dst, const float*& bias, const GI_FLOAT32_t& init, | |||||
GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) { | |||||
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init); | |||||
if (has_top) { | if (has_top) { | ||||
rdst = Vfmaq_f32(rdst, rsrc[0][0], rfilter[0][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][1]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][2]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[0][0], rfilter[0][0]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][1]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][2]); | |||||
} | } | ||||
{ | { | ||||
rdst = Vfmaq_f32(rdst, rsrc[1][0], rfilter[1][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][1]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][2]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[1][0], rfilter[1][0]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][1]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][2]); | |||||
} | } | ||||
if (has_bottom) { | if (has_bottom) { | ||||
rdst = Vfmaq_f32(rdst, rsrc[2][0], rfilter[2][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][1]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][2]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[2][0], rfilter[2][0]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][1]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][2]); | |||||
} | } | ||||
vst1q_f32(dst, op(rdst)); | |||||
GiStoreFloat32(dst, op(rdst)); | |||||
dst += 4; | dst += 4; | ||||
bias += 4; | bias += 4; | ||||
@@ -162,24 +155,24 @@ template <bool has_top, bool has_bottom, BiasMode bias_mode> | |||||
struct compute_element_right_pad { | struct compute_element_right_pad { | ||||
template <typename Op> | template <typename Op> | ||||
static inline void call( | static inline void call( | ||||
float*& dst, const float*& bias, const float32x4_t& init, | |||||
float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) { | |||||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||||
float*& dst, const float*& bias, const GI_FLOAT32_t& init, | |||||
GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) { | |||||
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init); | |||||
if (has_top) { | if (has_top) { | ||||
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][1]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][0]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][1]); | |||||
} | } | ||||
{ | { | ||||
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][1]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][0]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][1]); | |||||
} | } | ||||
if (has_bottom) { | if (has_bottom) { | ||||
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][1]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][0]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][1]); | |||||
} | } | ||||
vst1q_f32(dst, op(rdst)); | |||||
GiStoreFloat32(dst, op(rdst)); | |||||
dst += 4; | dst += 4; | ||||
bias += 4; | bias += 4; | ||||
} | } | ||||
@@ -190,22 +183,22 @@ struct compute_row { | |||||
template <typename Op> | template <typename Op> | ||||
static inline void call( | static inline void call( | ||||
const float*& src0, const float*& src1, const float*& src2, float*& dst, | const float*& src0, const float*& src1, const float*& src2, float*& dst, | ||||
const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4], | |||||
float32x4_t rfilter[3][3], int W, const Op& op) { | |||||
const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4], | |||||
GI_FLOAT32_t rfilter[3][3], int W, const Op& op) { | |||||
if (has_top) { | if (has_top) { | ||||
rsrc[0][0] = vdupq_n_f32(0); | |||||
rsrc[0][1] = vld1q_f32(src0 + 0); | |||||
rsrc[0][2] = vld1q_f32(src0 + 4); | |||||
rsrc[0][0] = GiZeroFloat32(); | |||||
rsrc[0][1] = GiLoadFloat32(src0 + 0); | |||||
rsrc[0][2] = GiLoadFloat32(src0 + 4); | |||||
} | } | ||||
{ | { | ||||
rsrc[1][0] = vdupq_n_f32(0); | |||||
rsrc[1][1] = vld1q_f32(src1 + 0); | |||||
rsrc[1][2] = vld1q_f32(src1 + 4); | |||||
rsrc[1][0] = GiZeroFloat32(); | |||||
rsrc[1][1] = GiLoadFloat32(src1 + 0); | |||||
rsrc[1][2] = GiLoadFloat32(src1 + 4); | |||||
} | } | ||||
if (has_bottom) { | if (has_bottom) { | ||||
rsrc[2][0] = vdupq_n_f32(0); | |||||
rsrc[2][1] = vld1q_f32(src2 + 0); | |||||
rsrc[2][2] = vld1q_f32(src2 + 4); | |||||
rsrc[2][0] = GiZeroFloat32(); | |||||
rsrc[2][1] = GiLoadFloat32(src2 + 0); | |||||
rsrc[2][2] = GiLoadFloat32(src2 + 4); | |||||
} | } | ||||
int w = 0; | int w = 0; | ||||
@@ -256,27 +249,27 @@ void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1( | |||||
int W) { | int W) { | ||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0); | |||||
GI_FLOAT32_t init = GiZeroFloat32(); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
init = vld1q_f32(bias); | |||||
init = GiLoadFloat32(bias); | |||||
} | } | ||||
const float* src0 = src - W * 4; | const float* src0 = src - W * 4; | ||||
const float* src1 = src; | const float* src1 = src; | ||||
const float* src2 = src + W * 4; | const float* src2 = src + W * 4; | ||||
float32x4_t rfilter[3][3]; | |||||
rfilter[0][0] = vld1q_f32(filter + 0); | |||||
rfilter[0][1] = vld1q_f32(filter + 4); | |||||
rfilter[0][2] = vld1q_f32(filter + 8); | |||||
rfilter[1][0] = vld1q_f32(filter + 12); | |||||
rfilter[1][1] = vld1q_f32(filter + 16); | |||||
rfilter[1][2] = vld1q_f32(filter + 20); | |||||
rfilter[2][0] = vld1q_f32(filter + 24); | |||||
rfilter[2][1] = vld1q_f32(filter + 28); | |||||
rfilter[2][2] = vld1q_f32(filter + 32); | |||||
float32x4_t rsrc[3][4]; | |||||
GI_FLOAT32_t rfilter[3][3]; | |||||
rfilter[0][0] = GiLoadFloat32(filter + 0); | |||||
rfilter[0][1] = GiLoadFloat32(filter + 4); | |||||
rfilter[0][2] = GiLoadFloat32(filter + 8); | |||||
rfilter[1][0] = GiLoadFloat32(filter + 12); | |||||
rfilter[1][1] = GiLoadFloat32(filter + 16); | |||||
rfilter[1][2] = GiLoadFloat32(filter + 20); | |||||
rfilter[2][0] = GiLoadFloat32(filter + 24); | |||||
rfilter[2][1] = GiLoadFloat32(filter + 28); | |||||
rfilter[2][2] = GiLoadFloat32(filter + 32); | |||||
GI_FLOAT32_t rsrc[3][4]; | |||||
compute_row<false, true, bias_mode>::call( | compute_row<false, true, bias_mode>::call( | ||||
src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); | src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -12,11 +12,11 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace channel_wise_nchw44_float { | namespace channel_wise_nchw44_float { | ||||
template <BiasMode bias_mode, typename Op> | template <BiasMode bias_mode, typename Op> | ||||
@@ -25,7 +25,7 @@ void do_conv_kern_3x3_stride1_padding1( | |||||
int W); | int W); | ||||
} // namespace channel_wise_nchw44_float | } // namespace channel_wise_nchw44_float | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,29 +10,22 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
#if defined(__ARM_FEATURE_FMA) | |||||
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) | |||||
#else | |||||
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) | |||||
#endif | |||||
template <int shift> | template <int shift> | ||||
static inline void shift_src(float32x4_t rsrc[6]) { | |||||
float32x4_t t[6]; | |||||
static inline void shift_src(GI_FLOAT32_t rsrc[6]) { | |||||
GI_FLOAT32_t t[6]; | |||||
t[0] = rsrc[(shift + 0) % 6]; | t[0] = rsrc[(shift + 0) % 6]; | ||||
t[1] = rsrc[(shift + 1) % 6]; | t[1] = rsrc[(shift + 1) % 6]; | ||||
@@ -48,18 +41,18 @@ static inline void shift_src(float32x4_t rsrc[6]) { | |||||
rsrc[5] = t[5]; | rsrc[5] = t[5]; | ||||
} | } | ||||
static inline void load_filter(const float* filter, float32x4_t rfilter[5]) { | |||||
rfilter[0] = vld1q_f32(filter + 0); | |||||
rfilter[1] = vld1q_f32(filter + 4); | |||||
rfilter[2] = vld1q_f32(filter + 8); | |||||
rfilter[3] = vld1q_f32(filter + 12); | |||||
rfilter[4] = vld1q_f32(filter + 16); | |||||
static inline void load_filter(const float* filter, GI_FLOAT32_t rfilter[5]) { | |||||
rfilter[0] = GiLoadFloat32(filter + 0); | |||||
rfilter[1] = GiLoadFloat32(filter + 4); | |||||
rfilter[2] = GiLoadFloat32(filter + 8); | |||||
rfilter[3] = GiLoadFloat32(filter + 12); | |||||
rfilter[4] = GiLoadFloat32(filter + 16); | |||||
} | } | ||||
template <BiasMode bias_mode> | template <BiasMode bias_mode> | ||||
static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) { | |||||
static inline GI_FLOAT32_t load_bias(const float* bias, const GI_FLOAT32_t& init) { | |||||
if (bias_mode == BiasMode::BIAS) { | if (bias_mode == BiasMode::BIAS) { | ||||
return vld1q_f32(bias); | |||||
return GiLoadFloat32(bias); | |||||
} else { | } else { | ||||
return init; | return init; | ||||
} | } | ||||
@@ -69,27 +62,28 @@ template <int BW, int bw, BiasMode bias_mode, bool need_load_bias, bool need_do_ | |||||
struct compute_element { | struct compute_element { | ||||
template <typename Op> | template <typename Op> | ||||
static inline void call( | static inline void call( | ||||
const float*& src, float*& dst, const float*& bias, const float32x4_t& init, | |||||
float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) { | |||||
const float*& src, float*& dst, const float*& bias, | |||||
const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], | |||||
const Op& op) { | |||||
#define RSRC(i) rsrc[((i) + bw) % 6] | #define RSRC(i) rsrc[((i) + bw) % 6] | ||||
float32x4_t rdst; | |||||
GI_FLOAT32_t rdst; | |||||
if (need_load_bias) { | if (need_load_bias) { | ||||
rdst = load_bias<bias_mode>(bias, init); | rdst = load_bias<bias_mode>(bias, init); | ||||
} else { | } else { | ||||
rdst = vld1q_f32(dst); | |||||
rdst = GiLoadFloat32(dst); | |||||
} | } | ||||
RSRC(5) = vld1q_f32(src + 12); | |||||
RSRC(5) = GiLoadFloat32(src + 12); | |||||
rdst = Vfmaq_f32(rdst, RSRC(0), rfilter[0]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(1), rfilter[1]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(2), rfilter[2]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(3), rfilter[3]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(4), rfilter[4]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(0), rfilter[0]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(1), rfilter[1]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(2), rfilter[2]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(3), rfilter[3]); | |||||
rdst = GiMlaqFloat32(rdst, RSRC(4), rfilter[4]); | |||||
if (need_do_op) { | if (need_do_op) { | ||||
rdst = op(rdst); | rdst = op(rdst); | ||||
} | } | ||||
vst1q_f32(dst, rdst); | |||||
GiStoreFloat32(dst, rdst); | |||||
src += 4; | src += 4; | ||||
dst += 4; | dst += 4; | ||||
@@ -110,29 +104,29 @@ template <size_t padding, BiasMode bias_mode, bool need_load_bias, bool need_do_ | |||||
struct compute_element_right { | struct compute_element_right { | ||||
template <typename Op> | template <typename Op> | ||||
static inline void call( | static inline void call( | ||||
float*& dst, const float*& bias, const float32x4_t& init, | |||||
float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) { | |||||
float32x4_t rdst; | |||||
float*& dst, const float*& bias, const GI_FLOAT32_t& init, | |||||
GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], const Op& op) { | |||||
GI_FLOAT32_t rdst; | |||||
if (need_load_bias) { | if (need_load_bias) { | ||||
rdst = load_bias<bias_mode>(bias, init); | rdst = load_bias<bias_mode>(bias, init); | ||||
} else { | } else { | ||||
rdst = vld1q_f32(dst); | |||||
rdst = GiLoadFloat32(dst); | |||||
} | } | ||||
rdst = Vfmaq_f32(rdst, rsrc[0 + padding], rfilter[0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[1 + padding], rfilter[1]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[2 + padding], rfilter[2]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[0 + padding], rfilter[0]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[1 + padding], rfilter[1]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[2 + padding], rfilter[2]); | |||||
if (padding < 2) { | if (padding < 2) { | ||||
rdst = Vfmaq_f32(rdst, rsrc[3 + padding], rfilter[3]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[3 + padding], rfilter[3]); | |||||
} | } | ||||
if (padding < 1) { | if (padding < 1) { | ||||
rdst = Vfmaq_f32(rdst, rsrc[4 + padding], rfilter[4]); | |||||
rdst = GiMlaqFloat32(rdst, rsrc[4 + padding], rfilter[4]); | |||||
} | } | ||||
if (need_do_op) { | if (need_do_op) { | ||||
rdst = op(rdst); | rdst = op(rdst); | ||||
} | } | ||||
vst1q_f32(dst, rdst); | |||||
GiStoreFloat32(dst, rdst); | |||||
dst += 4; | dst += 4; | ||||
bias += 4; | bias += 4; | ||||
@@ -143,13 +137,13 @@ template <BiasMode bias_mode, bool need_load_bias, bool need_do_op> | |||||
struct compute_row_src_1x5 { | struct compute_row_src_1x5 { | ||||
template <typename Op> | template <typename Op> | ||||
static inline void call( | static inline void call( | ||||
const float* src, float* dst, const float* bias, const float32x4_t& init, | |||||
float32x4_t rsrc[6], float32x4_t rfilter[5], int W, const Op& op) { | |||||
rsrc[0] = vdupq_n_f32(0); | |||||
rsrc[1] = vdupq_n_f32(0); | |||||
rsrc[2] = vld1q_f32(src + 0); | |||||
rsrc[3] = vld1q_f32(src + 4); | |||||
rsrc[4] = vld1q_f32(src + 8); | |||||
const float* src, float* dst, const float* bias, const GI_FLOAT32_t& init, | |||||
GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], int W, const Op& op) { | |||||
rsrc[0] = GiZeroFloat32(); | |||||
rsrc[1] = GiZeroFloat32(); | |||||
rsrc[2] = GiLoadFloat32(src + 0); | |||||
rsrc[3] = GiLoadFloat32(src + 4); | |||||
rsrc[4] = GiLoadFloat32(src + 8); | |||||
int w = 0; | int w = 0; | ||||
@@ -190,8 +184,8 @@ struct compute_row { | |||||
template <typename Op> | template <typename Op> | ||||
static inline void call( | static inline void call( | ||||
const float*& src, float*& dst, const float* filter, const float*& bias, | const float*& src, float*& dst, const float* filter, const float*& bias, | ||||
const float32x4_t& init, float32x4_t rsrc[6], float32x4_t rfilter[5], int W, | |||||
const Op& op) { | |||||
const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], | |||||
int W, const Op& op) { | |||||
if (top_padding < 1) { | if (top_padding < 1) { | ||||
load_filter(filter + 0, rfilter); | load_filter(filter + 0, rfilter); | ||||
compute_row_src_1x5<bias_mode, top_padding == 0, false>::call( | compute_row_src_1x5<bias_mode, top_padding == 0, false>::call( | ||||
@@ -235,13 +229,13 @@ void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2( | |||||
int W) { | int W) { | ||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0); | |||||
GI_FLOAT32_t init = GiZeroFloat32(); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
init = vld1q_f32(bias); | |||||
init = GiLoadFloat32(bias); | |||||
} | } | ||||
float32x4_t rsrc[6]; | |||||
float32x4_t rfilter[5]; | |||||
GI_FLOAT32_t rsrc[6]; | |||||
GI_FLOAT32_t rfilter[5]; | |||||
compute_row<2, 0, bias_mode>::call( | compute_row<2, 0, bias_mode>::call( | ||||
src, dst, filter, bias, init, rsrc, rfilter, W, op); | src, dst, filter, bias, init, rsrc, rfilter, W, op); |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -12,11 +12,11 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace channel_wise_nchw44_float { | namespace channel_wise_nchw44_float { | ||||
template <BiasMode bias_mode, typename Op> | template <BiasMode bias_mode, typename Op> | ||||
@@ -25,7 +25,7 @@ void do_conv_kern_5x5_stride1_padding2( | |||||
int W); | int W); | ||||
} // namespace channel_wise_nchw44_float | } // namespace channel_wise_nchw44_float | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,14 +10,14 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||||
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
using conv_fun = std::function<void( | using conv_fun = std::function<void( | ||||
const float* src, const float* filter, const float* bias, float* dst, | const float* src, const float* filter, const float* bias, float* dst, | ||||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | const size_t IH, const size_t IW, const size_t OH, const size_t OW, |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/int8/direct.cpp | |||||
* \file dnn/src/fallback/conv_bias/int8/direct.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,29 +10,28 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
template <int size> | template <int size> | ||||
void load_vec(float32x4_t* dst, const float* src); | |||||
void load_vec(GI_FLOAT32_t* dst, const float* src); | |||||
#define cb(i) dst[i] = vld1q_f32(src + i * 4); | |||||
#define LOAD_MACRO(n) \ | |||||
template <> \ | |||||
inline void load_vec<n>(float32x4_t * dst, const float* src) { \ | |||||
UNROLL_CALL_NOWRAPPER(n, cb); \ | |||||
#define cb(i) dst[i] = GiLoadFloat32(src + i * 4); | |||||
#define LOAD_MACRO(n) \ | |||||
template <> \ | |||||
inline void load_vec<n>(GI_FLOAT32_t * dst, const float* src) { \ | |||||
UNROLL_CALL_NOWRAPPER(n, cb); \ | |||||
} | } | ||||
LOAD_MACRO(2); | LOAD_MACRO(2); | ||||
LOAD_MACRO(3); | LOAD_MACRO(3); | ||||
@@ -46,14 +45,14 @@ LOAD_MACRO(9); | |||||
#undef LOAD_MACRO | #undef LOAD_MACRO | ||||
template <int size> | template <int size> | ||||
void compute_vec(float32x4_t& dst, float32x4_t* src, float32x4_t* filter); | |||||
void compute_vec(GI_FLOAT32_t& dst, GI_FLOAT32_t* src, GI_FLOAT32_t* filter); | |||||
#define cb(i) dst = vmlaq_f32(dst, src[i], filter[i]); | |||||
#define COMPUTE_MACRO(n) \ | |||||
template <> \ | |||||
inline void compute_vec<n>( \ | |||||
float32x4_t & dst, float32x4_t * src, float32x4_t * filter) { \ | |||||
UNROLL_CALL_NOWRAPPER(n, cb); \ | |||||
#define cb(i) dst = GiMlaqFloat32(dst, src[i], filter[i]); | |||||
#define COMPUTE_MACRO(n) \ | |||||
template <> \ | |||||
inline void compute_vec<n>( \ | |||||
GI_FLOAT32_t & dst, GI_FLOAT32_t * src, GI_FLOAT32_t * filter) { \ | |||||
UNROLL_CALL_NOWRAPPER(n, cb); \ | |||||
} | } | ||||
COMPUTE_MACRO(2); | COMPUTE_MACRO(2); | ||||
COMPUTE_MACRO(3); | COMPUTE_MACRO(3); | ||||
@@ -64,20 +63,20 @@ COMPUTE_MACRO(5); | |||||
template <BiasMode bias_mode, int size> | template <BiasMode bias_mode, int size> | ||||
struct load_bias_vec; | struct load_bias_vec; | ||||
#define cb_bias(i) dst[i] = vld1q_f32((bptr) + i * 4); | |||||
#define cb_bias(i) dst[i] = GiLoadFloat32((bptr) + i * 4); | |||||
#define cb_init(i) dst[i] = init; | #define cb_init(i) dst[i] = init; | ||||
#define INIT_BIAS_MACRO(n) \ | |||||
template <BiasMode bias_mode> \ | |||||
struct load_bias_vec<bias_mode, n> { \ | |||||
static void impl( \ | |||||
float32x4_t* dst, const float32x4_t& init, const float* bptr) { \ | |||||
if (bias_mode == BiasMode::BIAS) { \ | |||||
UNROLL_CALL_NOWRAPPER(n, cb_bias); \ | |||||
} else { \ | |||||
UNROLL_CALL_NOWRAPPER(n, cb_init); \ | |||||
} \ | |||||
} \ | |||||
#define INIT_BIAS_MACRO(n) \ | |||||
template <BiasMode bias_mode> \ | |||||
struct load_bias_vec<bias_mode, n> { \ | |||||
static void impl( \ | |||||
GI_FLOAT32_t* dst, const GI_FLOAT32_t& init, const float* bptr) { \ | |||||
if (bias_mode == BiasMode::BIAS) { \ | |||||
UNROLL_CALL_NOWRAPPER(n, cb_bias); \ | |||||
} else { \ | |||||
UNROLL_CALL_NOWRAPPER(n, cb_init); \ | |||||
} \ | |||||
} \ | |||||
}; | }; | ||||
INIT_BIAS_MACRO(1); | INIT_BIAS_MACRO(1); | ||||
@@ -91,7 +90,7 @@ INIT_BIAS_MACRO(4); | |||||
#define COMPUTE_PADDING_KERNEL() \ | #define COMPUTE_PADDING_KERNEL() \ | ||||
do { \ | do { \ | ||||
int iw = ow * stride - PW; \ | int iw = ow * stride - PW; \ | ||||
float32x4_t result; \ | |||||
GI_FLOAT32_t result; \ | |||||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4 + ow * 4); \ | load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4 + ow * 4); \ | ||||
for (int kh = 0; kh < fh; kh++) { \ | for (int kh = 0; kh < fh; kh++) { \ | ||||
if (kh + ih < 0 || kh + ih >= static_cast<int>(IH)) \ | if (kh + ih < 0 || kh + ih >= static_cast<int>(IH)) \ | ||||
@@ -100,7 +99,8 @@ INIT_BIAS_MACRO(4); | |||||
if (kw + iw < 0 || kw + iw >= static_cast<int>(IW)) \ | if (kw + iw < 0 || kw + iw >= static_cast<int>(IW)) \ | ||||
continue; \ | continue; \ | ||||
const float* sptr = src + (kh + ih) * IW * 4 + (kw + iw) * 4; \ | const float* sptr = src + (kh + ih) * IW * 4 + (kw + iw) * 4; \ | ||||
result = vmlaq_f32(result, kernel[kh * fh + kw], vld1q_f32(sptr)); \ | |||||
result = GiMlaqFloat32( \ | |||||
result, kernel[kh * fh + kw], GiLoadFloat32(sptr)); \ | |||||
} \ | } \ | ||||
} \ | } \ | ||||
float* output = dst + oh * OW * 4 + ow * 4; \ | float* output = dst + oh * OW * 4 + ow * 4; \ | ||||
@@ -113,7 +113,7 @@ struct PaddingCompute { | |||||
const float* src, const float* bias, float* dst, const int fh, | const float* src, const float* bias, float* dst, const int fh, | ||||
const int stride, const size_t IH, const size_t IW, const size_t OH, | const int stride, const size_t IH, const size_t IW, const size_t OH, | ||||
const size_t OW, const size_t PH, const size_t PW, | const size_t OW, const size_t PH, const size_t PW, | ||||
const float32x4_t* kernel, const float32x4_t& init) { | |||||
const GI_FLOAT32_t* kernel, const GI_FLOAT32_t& init) { | |||||
size_t oh_start = (PH + stride - 1) / stride; | size_t oh_start = (PH + stride - 1) / stride; | ||||
size_t ow_start = (PW + stride - 1) / stride; | size_t ow_start = (PW + stride - 1) / stride; | ||||
size_t oh_end = (IH + PH - fh) / stride + 1; | size_t oh_end = (IH + PH - fh) / stride + 1; | ||||
@@ -148,7 +148,7 @@ struct PaddingComputeK3P1 { | |||||
static void compute( | static void compute( | ||||
const float* src, const float* bias, float* dst, const size_t stride, | const float* src, const float* bias, float* dst, const size_t stride, | ||||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | const size_t IH, const size_t IW, const size_t OH, const size_t OW, | ||||
const float32x4_t* kernel, const float32x4_t& init) { | |||||
const GI_FLOAT32_t* kernel, const GI_FLOAT32_t& init) { | |||||
constexpr size_t PH = 1, PW = 1, FH = 3; | constexpr size_t PH = 1, PW = 1, FH = 3; | ||||
size_t oh_start = (PH + stride - 1) / stride; | size_t oh_start = (PH + stride - 1) / stride; | ||||
size_t ow_start = (PW + stride - 1) / stride; | size_t ow_start = (PW + stride - 1) / stride; | ||||
@@ -162,39 +162,39 @@ struct PaddingComputeK3P1 { | |||||
Op op; | Op op; | ||||
// line one left | // line one left | ||||
{ | { | ||||
float32x4_t result; | |||||
GI_FLOAT32_t result; | |||||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias); | load_bias_vec<bias_mode, 1>::impl(&result, init, bias); | ||||
result = vmlaq_f32(result, kernel[4], vld1q_f32(src)); | |||||
result = vmlaq_f32(result, kernel[5], vld1q_f32(src + 4)); | |||||
result = vmlaq_f32(result, kernel[7], vld1q_f32(src + IW * 4)); | |||||
result = vmlaq_f32(result, kernel[8], vld1q_f32(src + IW * 4 + 4)); | |||||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(src)); | |||||
result = GiMlaqFloat32(result, kernel[5], GiLoadFloat32(src + 4)); | |||||
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(src + IW * 4)); | |||||
result = GiMlaqFloat32(result, kernel[8], GiLoadFloat32(src + IW * 4 + 4)); | |||||
float* output = dst; | float* output = dst; | ||||
op(result, output); | op(result, output); | ||||
} | } | ||||
// line one mid | // line one mid | ||||
for (size_t ow = ow_start; ow < ow_end; ow++) { | for (size_t ow = ow_start; ow < ow_end; ow++) { | ||||
int iw = ow * stride - PW; | int iw = ow * stride - PW; | ||||
float32x4_t result; | |||||
GI_FLOAT32_t result; | |||||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + ow * 4); | load_bias_vec<bias_mode, 1>::impl(&result, init, bias + ow * 4); | ||||
const float* sptr = src + iw * 4; | const float* sptr = src + iw * 4; | ||||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr)); | |||||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4)); | |||||
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + 8)); | |||||
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4)); | |||||
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4)); | |||||
result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + IW * 4 + 8)); | |||||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr)); | |||||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + 4)); | |||||
result = GiMlaqFloat32(result, kernel[5], GiLoadFloat32(sptr + 8)); | |||||
result = GiMlaqFloat32(result, kernel[6], GiLoadFloat32(sptr + IW * 4)); | |||||
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(sptr + IW * 4 + 4)); | |||||
result = GiMlaqFloat32(result, kernel[8], GiLoadFloat32(sptr + IW * 4 + 8)); | |||||
float* output = dst + ow * 4; | float* output = dst + ow * 4; | ||||
op(result, output); | op(result, output); | ||||
} | } | ||||
// line one right | // line one right | ||||
if (OW != ow_end) { | if (OW != ow_end) { | ||||
float32x4_t result; | |||||
GI_FLOAT32_t result; | |||||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + (OW - 1) * 4); | load_bias_vec<bias_mode, 1>::impl(&result, init, bias + (OW - 1) * 4); | ||||
const float* sptr = src + (ow_end * stride - PW) * 4; | const float* sptr = src + (ow_end * stride - PW) * 4; | ||||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr)); | |||||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4)); | |||||
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4)); | |||||
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4)); | |||||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr)); | |||||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + 4)); | |||||
result = GiMlaqFloat32(result, kernel[6], GiLoadFloat32(sptr + IW * 4)); | |||||
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(sptr + IW * 4 + 4)); | |||||
float* output = dst + ow_end * 4; | float* output = dst + ow_end * 4; | ||||
op(result, output); | op(result, output); | ||||
} | } | ||||
@@ -203,30 +203,36 @@ struct PaddingComputeK3P1 { | |||||
int ih = oh * stride - PH; | int ih = oh * stride - PH; | ||||
// left | // left | ||||
{ | { | ||||
float32x4_t result; | |||||
GI_FLOAT32_t result; | |||||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4); | load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4); | ||||
const float* sptr = src + ih * IW * 4; | const float* sptr = src + ih * IW * 4; | ||||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr)); | |||||
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4)); | |||||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4)); | |||||
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4)); | |||||
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4)); | |||||
result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + 2 * IW * 4 + 4)); | |||||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr)); | |||||
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 4)); | |||||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + IW * 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[7], GiLoadFloat32(sptr + 2 * IW * 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[8], GiLoadFloat32(sptr + 2 * IW * 4 + 4)); | |||||
float* output = dst + oh * OW * 4; | float* output = dst + oh * OW * 4; | ||||
op(result, output); | op(result, output); | ||||
} | } | ||||
// right | // right | ||||
if (OW != ow_end) { | if (OW != ow_end) { | ||||
float32x4_t result; | |||||
GI_FLOAT32_t result; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&result, init, bias + oh * OW * 4 + (OW - 1) * 4); | &result, init, bias + oh * OW * 4 + (OW - 1) * 4); | ||||
const float* sptr = src + ih * IW * 4 + (ow_end * stride - PW) * 4; | const float* sptr = src + ih * IW * 4 + (ow_end * stride - PW) * 4; | ||||
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); | |||||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); | |||||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); | |||||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); | |||||
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + 2 * IW * 4)); | |||||
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4 + 4)); | |||||
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); | |||||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); | |||||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[6], GiLoadFloat32(sptr + 2 * IW * 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[7], GiLoadFloat32(sptr + 2 * IW * 4 + 4)); | |||||
float* output = dst + oh * OW * 4 + ow_end * 4; | float* output = dst + oh * OW * 4 + ow_end * 4; | ||||
op(result, output); | op(result, output); | ||||
} | } | ||||
@@ -235,43 +241,47 @@ struct PaddingComputeK3P1 { | |||||
if (OH != oh_end) { | if (OH != oh_end) { | ||||
size_t oh = OH - 1; | size_t oh = OH - 1; | ||||
{ | { | ||||
float32x4_t result; | |||||
GI_FLOAT32_t result; | |||||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4); | load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4); | ||||
const float* sptr = src + (oh_end * stride - PH) * IW * 4; | const float* sptr = src + (oh_end * stride - PH) * IW * 4; | ||||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr)); | |||||
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4)); | |||||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4)); | |||||
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4)); | |||||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr)); | |||||
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 4)); | |||||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + IW * 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 4)); | |||||
float* output = dst + oh_end * OW * 4; | float* output = dst + oh_end * OW * 4; | ||||
op(result, output); | op(result, output); | ||||
} | } | ||||
// last line mid | // last line mid | ||||
for (size_t ow = ow_start; ow < ow_end; ow++) { | for (size_t ow = ow_start; ow < ow_end; ow++) { | ||||
int iw = ow * stride - PW; | int iw = ow * stride - PW; | ||||
float32x4_t result; | |||||
GI_FLOAT32_t result; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&result, init, bias + oh * OW * 4 + ow * 4); | &result, init, bias + oh * OW * 4 + ow * 4); | ||||
const float* sptr = src + (oh_end * stride - PH) * IW * 4 + iw * 4; | const float* sptr = src + (oh_end * stride - PH) * IW * 4 + iw * 4; | ||||
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); | |||||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); | |||||
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 8)); | |||||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); | |||||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); | |||||
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 8)); | |||||
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); | |||||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); | |||||
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 8)); | |||||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 8)); | |||||
float* output = dst + oh_end * OW * 4 + ow * 4; | float* output = dst + oh_end * OW * 4 + ow * 4; | ||||
op(result, output); | op(result, output); | ||||
} | } | ||||
// last line right | // last line right | ||||
if (OW != ow_end) { | if (OW != ow_end) { | ||||
float32x4_t result; | |||||
GI_FLOAT32_t result; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&result, init, bias + oh * OW * 4 + (OW - 1) * 4); | &result, init, bias + oh * OW * 4 + (OW - 1) * 4); | ||||
const float* sptr = src + (oh_end * stride - PH) * IW * 4 + | const float* sptr = src + (oh_end * stride - PH) * IW * 4 + | ||||
(ow_end * stride - PW) * 4; | (ow_end * stride - PW) * 4; | ||||
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); | |||||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); | |||||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); | |||||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); | |||||
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); | |||||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); | |||||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); | |||||
result = GiMlaqFloat32( | |||||
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); | |||||
float* output = dst + oh_end * OW * 4 + ow_end * 4; | float* output = dst + oh_end * OW * 4 + ow_end * 4; | ||||
op(result, output); | op(result, output); | ||||
} | } | ||||
@@ -286,12 +296,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||||
const float* src, const float* filter, const float* bias, float* dst, | const float* src, const float* filter, const float* bias, float* dst, | ||||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | const size_t IH, const size_t IW, const size_t OH, const size_t OW, | ||||
const size_t PH, const size_t PW) { | const size_t PH, const size_t PW) { | ||||
float32x4_t kernel[4]; | |||||
GI_FLOAT32_t kernel[4]; | |||||
load_vec<4>(kernel, filter); | load_vec<4>(kernel, filter); | ||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0.f); | |||||
GI_FLOAT32_t init = GiZeroFloat32(); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
init = vld1q_f32(bias); | |||||
init = GiLoadFloat32(bias); | |||||
} | } | ||||
size_t oh_start = PH; | size_t oh_start = PH; | ||||
size_t ow_start = PW; | size_t ow_start = PW; | ||||
@@ -315,12 +325,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||||
size_t iw = ow - ow_start; | size_t iw = ow - ow_start; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2][4]; | |||||
GI_FLOAT32_t dst_v[2][4]; | |||||
load_bias_vec<bias_mode, 4>::impl( | load_bias_vec<bias_mode, 4>::impl( | ||||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 4>::impl( | load_bias_vec<bias_mode, 4>::impl( | ||||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t src_v[3][5]; | |||||
GI_FLOAT32_t src_v[3][5]; | |||||
load_vec<5>(src_v[0], input); | load_vec<5>(src_v[0], input); | ||||
COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); | COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); | ||||
load_vec<5>(src_v[1], input + IW * 4); | load_vec<5>(src_v[1], input + IW * 4); | ||||
@@ -338,12 +348,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||||
size_t iw = ow - ow_start; | size_t iw = ow - ow_start; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2]; | |||||
GI_FLOAT32_t dst_v[2]; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | &dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t src_v[3][2]; | |||||
GI_FLOAT32_t src_v[3][2]; | |||||
load_vec<2>(src_v[0], input); | load_vec<2>(src_v[0], input); | ||||
compute_vec<2>(dst_v[0], &src_v[0][0], &kernel[0]); | compute_vec<2>(dst_v[0], &src_v[0][0], &kernel[0]); | ||||
load_vec<2>(src_v[1], input + IW * 4); | load_vec<2>(src_v[1], input + IW * 4); | ||||
@@ -363,10 +373,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||||
size_t iw = ow - ow_start; | size_t iw = ow - ow_start; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[1][4]; | |||||
GI_FLOAT32_t dst_v[1][4]; | |||||
load_bias_vec<bias_mode, 4>::impl( | load_bias_vec<bias_mode, 4>::impl( | ||||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][5]; | |||||
GI_FLOAT32_t src_v[2][5]; | |||||
load_vec<5>(src_v[0], input); | load_vec<5>(src_v[0], input); | ||||
COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); | COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); | ||||
load_vec<5>(src_v[1], input + IW * 4); | load_vec<5>(src_v[1], input + IW * 4); | ||||
@@ -379,10 +389,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||||
size_t iw = ow - ow_start; | size_t iw = ow - ow_start; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v; | |||||
GI_FLOAT32_t dst_v; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | &dst_v, init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][2]; | |||||
GI_FLOAT32_t src_v[2][2]; | |||||
load_vec<2>(src_v[0], input); | load_vec<2>(src_v[0], input); | ||||
compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); | compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); | ||||
load_vec<2>(src_v[1], input + IW * 4); | load_vec<2>(src_v[1], input + IW * 4); | ||||
@@ -405,12 +415,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||||
return; | return; | ||||
} | } | ||||
float32x4_t kernel[9]; | |||||
GI_FLOAT32_t kernel[9]; | |||||
load_vec<9>(kernel, filter); | load_vec<9>(kernel, filter); | ||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0.f); | |||||
GI_FLOAT32_t init = GiZeroFloat32(); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
init = vld1q_f32(bias); | |||||
init = GiLoadFloat32(bias); | |||||
} | } | ||||
size_t oh_start = PH; | size_t oh_start = PH; | ||||
size_t ow_start = PW; | size_t ow_start = PW; | ||||
@@ -428,12 +438,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||||
size_t iw = ow - PW; | size_t iw = ow - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2][4]; | |||||
GI_FLOAT32_t dst_v[2][4]; | |||||
load_bias_vec<bias_mode, 4>::impl( | load_bias_vec<bias_mode, 4>::impl( | ||||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 4>::impl( | load_bias_vec<bias_mode, 4>::impl( | ||||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][6]; | |||||
GI_FLOAT32_t src_v[2][6]; | |||||
load_vec<6>(src_v[0], input); | load_vec<6>(src_v[0], input); | ||||
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); | compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); | ||||
compute_vec<3>(dst_v[0][1], &src_v[0][1], &kernel[0]); | compute_vec<3>(dst_v[0][1], &src_v[0][1], &kernel[0]); | ||||
@@ -472,12 +482,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||||
size_t iw = ow - PW; | size_t iw = ow - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2]; | |||||
GI_FLOAT32_t dst_v[2]; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | &dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][3]; | |||||
GI_FLOAT32_t src_v[2][3]; | |||||
load_vec<3>(src_v[0], input); | load_vec<3>(src_v[0], input); | ||||
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | ||||
load_vec<3>(src_v[1], input + IW * 4); | load_vec<3>(src_v[1], input + IW * 4); | ||||
@@ -500,10 +510,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||||
size_t iw = ow - PW; | size_t iw = ow - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[4]; | |||||
GI_FLOAT32_t dst_v[4]; | |||||
load_bias_vec<bias_mode, 4>::impl( | load_bias_vec<bias_mode, 4>::impl( | ||||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | &dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][6]; | |||||
GI_FLOAT32_t src_v[2][6]; | |||||
load_vec<6>(src_v[0], input); | load_vec<6>(src_v[0], input); | ||||
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | ||||
compute_vec<3>(dst_v[1], &src_v[0][1], &kernel[0]); | compute_vec<3>(dst_v[1], &src_v[0][1], &kernel[0]); | ||||
@@ -526,10 +536,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||||
size_t iw = ow - PW; | size_t iw = ow - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v; | |||||
GI_FLOAT32_t dst_v; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | &dst_v, init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t src_v[3][3]; | |||||
GI_FLOAT32_t src_v[3][3]; | |||||
load_vec<3>(src_v[0], input); | load_vec<3>(src_v[0], input); | ||||
compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); | compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); | ||||
load_vec<3>(src_v[1], input + IW * 4); | load_vec<3>(src_v[1], input + IW * 4); | ||||
@@ -553,9 +563,9 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||||
} | } | ||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0.f); | |||||
GI_FLOAT32_t init = GiZeroFloat32(); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
init = vld1q_f32(bias); | |||||
init = GiLoadFloat32(bias); | |||||
} | } | ||||
size_t oh_start = PH; | size_t oh_start = PH; | ||||
size_t ow_start = PW; | size_t ow_start = PW; | ||||
@@ -564,7 +574,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||||
if (PH || PW) { | if (PH || PW) { | ||||
PaddingCompute<bias_mode, Op>::compute( | PaddingCompute<bias_mode, Op>::compute( | ||||
src, bias, dst, 5, 1, IH, IW, OH, OW, PH, PW, | src, bias, dst, 5, 1, IH, IW, OH, OW, PH, PW, | ||||
reinterpret_cast<const float32x4_t*>(filter), init); | |||||
reinterpret_cast<const GI_FLOAT32_t*>(filter), init); | |||||
} | } | ||||
size_t oh = oh_start; | size_t oh = oh_start; | ||||
for (; oh + 1 < oh_end; oh += 2) { | for (; oh + 1 < oh_end; oh += 2) { | ||||
@@ -574,13 +584,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||||
size_t iw = ow - PW; | size_t iw = ow - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2][2]; | |||||
GI_FLOAT32_t dst_v[2][2]; | |||||
load_bias_vec<bias_mode, 2>::impl( | load_bias_vec<bias_mode, 2>::impl( | ||||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 2>::impl( | load_bias_vec<bias_mode, 2>::impl( | ||||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t kernel[2][5]; | |||||
float32x4_t src_v[2][6]; | |||||
GI_FLOAT32_t kernel[2][5]; | |||||
GI_FLOAT32_t src_v[2][6]; | |||||
#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ | #define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ | ||||
load_vec<5>(kernel0, filter + i * 5 * 4); \ | load_vec<5>(kernel0, filter + i * 5 * 4); \ | ||||
load_vec<6>(src, input + i * IW * 4); \ | load_vec<6>(src, input + i * IW * 4); \ | ||||
@@ -613,13 +623,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||||
size_t iw = ow - PW; | size_t iw = ow - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2][1]; | |||||
GI_FLOAT32_t dst_v[2][1]; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t kernel[2][5]; | |||||
float32x4_t src_v[2][5]; | |||||
GI_FLOAT32_t kernel[2][5]; | |||||
GI_FLOAT32_t src_v[2][5]; | |||||
#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ | #define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ | ||||
load_vec<5>(kernel0, filter + i * 5 * 4); \ | load_vec<5>(kernel0, filter + i * 5 * 4); \ | ||||
load_vec<6>(src, input + i * IW * 4); \ | load_vec<6>(src, input + i * IW * 4); \ | ||||
@@ -652,11 +662,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||||
size_t iw = ow - PW; | size_t iw = ow - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[1][2]; | |||||
GI_FLOAT32_t dst_v[1][2]; | |||||
load_bias_vec<bias_mode, 2>::impl( | load_bias_vec<bias_mode, 2>::impl( | ||||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t kernel[2][5]; | |||||
float32x4_t src_v[2][6]; | |||||
GI_FLOAT32_t kernel[2][5]; | |||||
GI_FLOAT32_t src_v[2][6]; | |||||
#define COMPUTE_5X5_2(i, dst, src, kernel) \ | #define COMPUTE_5X5_2(i, dst, src, kernel) \ | ||||
load_vec<5>(kernel, filter + i * 5 * 4); \ | load_vec<5>(kernel, filter + i * 5 * 4); \ | ||||
load_vec<6>(src, input + i * IW * 4); \ | load_vec<6>(src, input + i * IW * 4); \ | ||||
@@ -679,11 +689,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||||
size_t iw = ow - PW; | size_t iw = ow - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v; | |||||
GI_FLOAT32_t dst_v; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | &dst_v, init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t kernel[2][5]; | |||||
float32x4_t src_v[2][5]; | |||||
GI_FLOAT32_t kernel[2][5]; | |||||
GI_FLOAT32_t src_v[2][5]; | |||||
#define COMPUTE_5X5_1(i, dst, src, kernel) \ | #define COMPUTE_5X5_1(i, dst, src, kernel) \ | ||||
load_vec<5>(kernel, filter + i * 5 * 4); \ | load_vec<5>(kernel, filter + i * 5 * 4); \ | ||||
load_vec<6>(src, input + i * IW * 4); \ | load_vec<6>(src, input + i * IW * 4); \ | ||||
@@ -709,12 +719,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( | |||||
const float* src, const float* filter, const float* bias, float* dst, | const float* src, const float* filter, const float* bias, float* dst, | ||||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | const size_t IH, const size_t IW, const size_t OH, const size_t OW, | ||||
const size_t PH, const size_t PW) { | const size_t PH, const size_t PW) { | ||||
float32x4_t kernel[4]; | |||||
GI_FLOAT32_t kernel[4]; | |||||
load_vec<4>(kernel, filter); | load_vec<4>(kernel, filter); | ||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0.f); | |||||
GI_FLOAT32_t init = GiZeroFloat32(); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
init = vld1q_f32(bias); | |||||
init = GiLoadFloat32(bias); | |||||
} | } | ||||
size_t oh_start = (PH + 1) / 2; | size_t oh_start = (PH + 1) / 2; | ||||
size_t ow_start = (PW + 1) / 2; | size_t ow_start = (PW + 1) / 2; | ||||
@@ -737,10 +747,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( | |||||
size_t iw = ow * 2 - PW; | size_t iw = ow * 2 - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[4]; | |||||
GI_FLOAT32_t dst_v[4]; | |||||
load_bias_vec<bias_mode, 4>::impl( | load_bias_vec<bias_mode, 4>::impl( | ||||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | &dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][8]; | |||||
GI_FLOAT32_t src_v[2][8]; | |||||
load_vec<8>(src_v[0], input); | load_vec<8>(src_v[0], input); | ||||
COMPUTE_2X2(dst_v, src_v[0], &kernel[0]); | COMPUTE_2X2(dst_v, src_v[0], &kernel[0]); | ||||
load_vec<8>(src_v[1], input + IW * 4); | load_vec<8>(src_v[1], input + IW * 4); | ||||
@@ -753,10 +763,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( | |||||
size_t iw = ow * 2 - PW; | size_t iw = ow * 2 - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v; | |||||
GI_FLOAT32_t dst_v; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | &dst_v, init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][2]; | |||||
GI_FLOAT32_t src_v[2][2]; | |||||
load_vec<2>(src_v[0], input); | load_vec<2>(src_v[0], input); | ||||
compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); | compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); | ||||
load_vec<2>(src_v[1], input + IW * 4); | load_vec<2>(src_v[1], input + IW * 4); | ||||
@@ -773,12 +783,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||||
const float* src, const float* filter, const float* bias, float* dst, | const float* src, const float* filter, const float* bias, float* dst, | ||||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | const size_t IH, const size_t IW, const size_t OH, const size_t OW, | ||||
const size_t PH, const size_t PW) { | const size_t PH, const size_t PW) { | ||||
float32x4_t kernel[9]; | |||||
GI_FLOAT32_t kernel[9]; | |||||
load_vec<9>(kernel, filter); | load_vec<9>(kernel, filter); | ||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0.f); | |||||
GI_FLOAT32_t init = GiZeroFloat32(); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
init = vld1q_f32(bias); | |||||
init = GiLoadFloat32(bias); | |||||
} | } | ||||
size_t oh_start = (PH + 1) / 2; | size_t oh_start = (PH + 1) / 2; | ||||
size_t ow_start = (PW + 1) / 2; | size_t ow_start = (PW + 1) / 2; | ||||
@@ -799,12 +809,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||||
size_t iw = ow * 2 - PW; | size_t iw = ow * 2 - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2][2]; | |||||
GI_FLOAT32_t dst_v[2][2]; | |||||
load_bias_vec<bias_mode, 2>::impl( | load_bias_vec<bias_mode, 2>::impl( | ||||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 2>::impl( | load_bias_vec<bias_mode, 2>::impl( | ||||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][5]; | |||||
GI_FLOAT32_t src_v[2][5]; | |||||
load_vec<5>(src_v[0], input); | load_vec<5>(src_v[0], input); | ||||
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); | compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); | ||||
compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[0]); | compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[0]); | ||||
@@ -830,12 +840,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||||
size_t iw = ow * 2 - PW; | size_t iw = ow * 2 - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2]; | |||||
GI_FLOAT32_t dst_v[2]; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | &dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t src_v[2][3]; | |||||
GI_FLOAT32_t src_v[2][3]; | |||||
load_vec<3>(src_v[0], input); | load_vec<3>(src_v[0], input); | ||||
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | ||||
load_vec<3>(src_v[1], input + IW * 4); | load_vec<3>(src_v[1], input + IW * 4); | ||||
@@ -859,10 +869,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||||
size_t iw = ow * 2 - PW; | size_t iw = ow * 2 - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2]; | |||||
GI_FLOAT32_t dst_v[2]; | |||||
load_bias_vec<bias_mode, 2>::impl( | load_bias_vec<bias_mode, 2>::impl( | ||||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | &dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t src_v[3][5]; | |||||
GI_FLOAT32_t src_v[3][5]; | |||||
load_vec<5>(src_v[0], input); | load_vec<5>(src_v[0], input); | ||||
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | ||||
compute_vec<3>(dst_v[1], &src_v[0][2], &kernel[0]); | compute_vec<3>(dst_v[1], &src_v[0][2], &kernel[0]); | ||||
@@ -878,10 +888,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||||
size_t iw = ow * 2 - PW; | size_t iw = ow * 2 - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v; | |||||
GI_FLOAT32_t dst_v; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | &dst_v, init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t src_v[3][3]; | |||||
GI_FLOAT32_t src_v[3][3]; | |||||
load_vec<3>(src_v[0], input); | load_vec<3>(src_v[0], input); | ||||
compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); | compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); | ||||
load_vec<3>(src_v[1], input + IW * 4); | load_vec<3>(src_v[1], input + IW * 4); | ||||
@@ -899,9 +909,9 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | const size_t IH, const size_t IW, const size_t OH, const size_t OW, | ||||
const size_t PH, const size_t PW) { | const size_t PH, const size_t PW) { | ||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0.f); | |||||
GI_FLOAT32_t init = GiZeroFloat32(); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
init = vld1q_f32(bias); | |||||
init = GiLoadFloat32(bias); | |||||
} | } | ||||
constexpr size_t stride = 2; | constexpr size_t stride = 2; | ||||
size_t oh_start = (PH + stride - 1) / stride; | size_t oh_start = (PH + stride - 1) / stride; | ||||
@@ -911,7 +921,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||||
if (PH || PW) { | if (PH || PW) { | ||||
PaddingCompute<bias_mode, Op>::compute( | PaddingCompute<bias_mode, Op>::compute( | ||||
src, bias, dst, 5, stride, IH, IW, OH, OW, PH, PW, | src, bias, dst, 5, stride, IH, IW, OH, OW, PH, PW, | ||||
reinterpret_cast<const float32x4_t*>(filter), init); | |||||
reinterpret_cast<const GI_FLOAT32_t*>(filter), init); | |||||
} | } | ||||
size_t oh = oh_start; | size_t oh = oh_start; | ||||
for (; oh + 1 < oh_end; oh += 2) { | for (; oh + 1 < oh_end; oh += 2) { | ||||
@@ -921,13 +931,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||||
size_t iw = ow * stride - PW; | size_t iw = ow * stride - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2][2]; | |||||
GI_FLOAT32_t dst_v[2][2]; | |||||
load_bias_vec<bias_mode, 2>::impl( | load_bias_vec<bias_mode, 2>::impl( | ||||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 2>::impl( | load_bias_vec<bias_mode, 2>::impl( | ||||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t kernel[3][5]; | |||||
float32x4_t src_v[2][7]; | |||||
GI_FLOAT32_t kernel[3][5]; | |||||
GI_FLOAT32_t src_v[2][7]; | |||||
#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ | #define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ | ||||
load_vec<5>(kernel0, filter + i * 5 * 4); \ | load_vec<5>(kernel0, filter + i * 5 * 4); \ | ||||
load_vec<7>(src, input + i * IW * 4); \ | load_vec<7>(src, input + i * IW * 4); \ | ||||
@@ -965,13 +975,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||||
size_t iw = ow * stride - PW; | size_t iw = ow * stride - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v[2]; | |||||
GI_FLOAT32_t dst_v[2]; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | &dst_v[0], init, bias + oh * OW * 4 + ow * 4); | ||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | ||||
float32x4_t kernel[3][5]; | |||||
float32x4_t src_v[2][5]; | |||||
GI_FLOAT32_t kernel[3][5]; | |||||
GI_FLOAT32_t src_v[2][5]; | |||||
#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ | #define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ | ||||
load_vec<5>(kernel0, filter + i * 5 * 4); \ | load_vec<5>(kernel0, filter + i * 5 * 4); \ | ||||
load_vec<5>(src, input + i * IW * 4); \ | load_vec<5>(src, input + i * IW * 4); \ | ||||
@@ -1010,11 +1020,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||||
size_t iw = ow * stride - PW; | size_t iw = ow * stride - PW; | ||||
const float* input = src + ih * IW * 4 + iw * 4; | const float* input = src + ih * IW * 4 + iw * 4; | ||||
float* output = dst + oh * OW * 4 + ow * 4; | float* output = dst + oh * OW * 4 + ow * 4; | ||||
float32x4_t dst_v; | |||||
GI_FLOAT32_t dst_v; | |||||
load_bias_vec<bias_mode, 1>::impl( | load_bias_vec<bias_mode, 1>::impl( | ||||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | &dst_v, init, bias + oh * OW * 4 + ow * 4); | ||||
float32x4_t kernel[2][5]; | |||||
float32x4_t src_v[2][5]; | |||||
GI_FLOAT32_t kernel[2][5]; | |||||
GI_FLOAT32_t src_v[2][5]; | |||||
#define COMPUTE_5X5_1(i, dst, src, kernel) \ | #define COMPUTE_5X5_1(i, dst, src, kernel) \ | ||||
load_vec<5>(kernel, filter + i * 5 * 4); \ | load_vec<5>(kernel, filter + i * 5 * 4); \ | ||||
load_vec<6>(src, input + i * IW * 4); \ | load_vec<6>(src, input + i * IW * 4); \ |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -12,11 +12,11 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace channel_wise_nchw44_float { | namespace channel_wise_nchw44_float { | ||||
#define KERN(stride, i) \ | #define KERN(stride, i) \ | ||||
@@ -37,7 +37,7 @@ KERN(stride2, 5) | |||||
#undef KERN | #undef KERN | ||||
} // namespace channel_wise_nchw44_float | } // namespace channel_wise_nchw44_float | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/direct.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/direct.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,18 +9,18 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct.h" | |||||
#include <cstring> | #include <cstring> | ||||
#include "include/megdnn/oprs.h" | #include "include/megdnn/oprs.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
MIDOUT_DECL(megdnn_arm_conv_f32) | |||||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
MIDOUT_DECL(megdnn_gi_conv_f32) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
using namespace fp32; | using namespace fp32; | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -34,65 +34,65 @@ struct do_pixel_proxy { | |||||
const int ow); | const int ow); | ||||
}; | }; | ||||
#define cb_load(i) data = vld1q_lane_f32(dst + i, data, i); | |||||
#define LOAD_OUT \ | |||||
if (width < 4) { \ | |||||
auto load_less_4 = [](float* dst, float32x4_t& data) { \ | |||||
if (width == 1u) { \ | |||||
UNROLL_CALL_NOWRAPPER(1, cb_load); \ | |||||
} else if (width == 2u) { \ | |||||
UNROLL_CALL_NOWRAPPER(2, cb_load); \ | |||||
} else if (width == 3u) { \ | |||||
UNROLL_CALL_NOWRAPPER(3, cb_load); \ | |||||
} \ | |||||
}; \ | |||||
if (height >= 1) \ | |||||
load_less_4(dst + 0 * OW, out0); \ | |||||
if (height >= 2) \ | |||||
load_less_4(dst + 1 * OW, out1); \ | |||||
if (height >= 3) \ | |||||
load_less_4(dst + 2 * OW, out2); \ | |||||
if (height >= 4) \ | |||||
load_less_4(dst + 3 * OW, out3); \ | |||||
} else { \ | |||||
if (height > 0) \ | |||||
out0 = vld1q_f32(dst + 0 * OW); \ | |||||
if (height > 1) \ | |||||
out1 = vld1q_f32(dst + 1 * OW); \ | |||||
if (height > 2) \ | |||||
out2 = vld1q_f32(dst + 2 * OW); \ | |||||
if (height > 3) \ | |||||
out3 = vld1q_f32(dst + 3 * OW); \ | |||||
} | |||||
#define cb_store(i) vst1q_lane_f32(dst + i, data, i); | |||||
#define STORE_OUT \ | |||||
#define cb_load(i) data = GiLd1qLaneFloat32(dst + i, data, i); | |||||
#define LOAD_OUT \ | |||||
if (width < 4) { \ | if (width < 4) { \ | ||||
auto store_less_4 = [](float* dst, float32x4_t& data) { \ | |||||
auto load_less_4 = [](float* dst, GI_FLOAT32_t& data) { \ | |||||
if (width == 1u) { \ | if (width == 1u) { \ | ||||
UNROLL_CALL_NOWRAPPER(1, cb_store); \ | |||||
UNROLL_CALL_NOWRAPPER(1, cb_load); \ | |||||
} else if (width == 2u) { \ | } else if (width == 2u) { \ | ||||
UNROLL_CALL_NOWRAPPER(2, cb_store); \ | |||||
UNROLL_CALL_NOWRAPPER(2, cb_load); \ | |||||
} else if (width == 3u) { \ | } else if (width == 3u) { \ | ||||
UNROLL_CALL_NOWRAPPER(3, cb_store); \ | |||||
UNROLL_CALL_NOWRAPPER(3, cb_load); \ | |||||
} \ | } \ | ||||
}; \ | }; \ | ||||
if (height >= 1) \ | if (height >= 1) \ | ||||
store_less_4(dst + 0 * OW, out0); \ | |||||
load_less_4(dst + 0 * OW, out0); \ | |||||
if (height >= 2) \ | if (height >= 2) \ | ||||
store_less_4(dst + 1 * OW, out1); \ | |||||
load_less_4(dst + 1 * OW, out1); \ | |||||
if (height >= 3) \ | if (height >= 3) \ | ||||
store_less_4(dst + 2 * OW, out2); \ | |||||
load_less_4(dst + 2 * OW, out2); \ | |||||
if (height >= 4) \ | if (height >= 4) \ | ||||
store_less_4(dst + 3 * OW, out3); \ | |||||
load_less_4(dst + 3 * OW, out3); \ | |||||
} else { \ | } else { \ | ||||
if (height >= 1) \ | |||||
vst1q_f32(dst + 0 * OW, out0); \ | |||||
if (height >= 2) \ | |||||
vst1q_f32(dst + 1 * OW, out1); \ | |||||
if (height >= 3) \ | |||||
vst1q_f32(dst + 2 * OW, out2); \ | |||||
if (height >= 4) \ | |||||
vst1q_f32(dst + 3 * OW, out3); \ | |||||
if (height > 0) \ | |||||
out0 = GiLoadFloat32(dst + 0 * OW); \ | |||||
if (height > 1) \ | |||||
out1 = GiLoadFloat32(dst + 1 * OW); \ | |||||
if (height > 2) \ | |||||
out2 = GiLoadFloat32(dst + 2 * OW); \ | |||||
if (height > 3) \ | |||||
out3 = GiLoadFloat32(dst + 3 * OW); \ | |||||
} | |||||
#define cb_store(i) GiStoreLane##i##Float32(dst + i, data); | |||||
#define STORE_OUT \ | |||||
if (width < 4) { \ | |||||
auto store_less_4 = [](float* dst, GI_FLOAT32_t& data) { \ | |||||
if (width == 1u) { \ | |||||
UNROLL_CALL_NOWRAPPER(1, cb_store); \ | |||||
} else if (width == 2u) { \ | |||||
UNROLL_CALL_NOWRAPPER(2, cb_store); \ | |||||
} else if (width == 3u) { \ | |||||
UNROLL_CALL_NOWRAPPER(3, cb_store); \ | |||||
} \ | |||||
}; \ | |||||
if (height >= 1) \ | |||||
store_less_4(dst + 0 * OW, out0); \ | |||||
if (height >= 2) \ | |||||
store_less_4(dst + 1 * OW, out1); \ | |||||
if (height >= 3) \ | |||||
store_less_4(dst + 2 * OW, out2); \ | |||||
if (height >= 4) \ | |||||
store_less_4(dst + 3 * OW, out3); \ | |||||
} else { \ | |||||
if (height >= 1) \ | |||||
GiStoreFloat32(dst + 0 * OW, out0); \ | |||||
if (height >= 2) \ | |||||
GiStoreFloat32(dst + 1 * OW, out1); \ | |||||
if (height >= 3) \ | |||||
GiStoreFloat32(dst + 2 * OW, out2); \ | |||||
if (height >= 4) \ | |||||
GiStoreFloat32(dst + 3 * OW, out3); \ | |||||
} | } | ||||
template <int height, int width> | template <int height, int width> | ||||
@@ -104,33 +104,33 @@ struct do_pixel_proxy<1, height, width> { | |||||
(void)IH; | (void)IH; | ||||
(void)OH; | (void)OH; | ||||
const int ih = oh, iw = ow; | const int ih = oh, iw = ow; | ||||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp; | |||||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp; | |||||
src += ih * IW + iw; | src += ih * IW + iw; | ||||
dst += oh * OW + ow; | dst += oh * OW + ow; | ||||
LOAD_OUT; | LOAD_OUT; | ||||
for (int fw = 0; fw < FW; ++fw) { | for (int fw = 0; fw < FW; ++fw) { | ||||
const float* src_dd = src + fw; | const float* src_dd = src + fw; | ||||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 0 * IW); | |||||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr0); | |||||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||||
if (height > 1) | if (height > 1) | ||||
inp = vld1q_f32(src_dd + 1 * IW); | |||||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr0); | |||||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||||
if (height > 2) | if (height > 2) | ||||
inp = vld1q_f32(src_dd + 2 * IW); | |||||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr0); | |||||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||||
if (height > 3) | if (height > 3) | ||||
inp = vld1q_f32(src_dd + 3 * IW); | |||||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr0); | |||||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||||
} | } | ||||
STORE_OUT; | STORE_OUT; | ||||
} | } | ||||
@@ -145,45 +145,45 @@ struct do_pixel_proxy<2, height, width> { | |||||
(void)IH; | (void)IH; | ||||
(void)OH; | (void)OH; | ||||
const int ih = oh, iw = ow; | const int ih = oh, iw = ow; | ||||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp; | |||||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp; | |||||
src += ih * IW + iw; | src += ih * IW + iw; | ||||
dst += oh * OW + ow; | dst += oh * OW + ow; | ||||
LOAD_OUT; | LOAD_OUT; | ||||
for (int fw = 0; fw < FW; ++fw) { | for (int fw = 0; fw < FW; ++fw) { | ||||
const float* src_dd = src + fw; | const float* src_dd = src + fw; | ||||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 0 * IW); | |||||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr0); | |||||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 1 * IW); | |||||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr1); | |||||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr0); | |||||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||||
if (height > 1) | if (height > 1) | ||||
inp = vld1q_f32(src_dd + 2 * IW); | |||||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr1); | |||||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr0); | |||||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||||
if (height > 2) | if (height > 2) | ||||
inp = vld1q_f32(src_dd + 3 * IW); | |||||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr1); | |||||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr0); | |||||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||||
if (height > 3) | if (height > 3) | ||||
inp = vld1q_f32(src_dd + 4 * IW); | |||||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr1); | |||||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||||
} | } | ||||
STORE_OUT; | STORE_OUT; | ||||
} | } | ||||
@@ -198,57 +198,57 @@ struct do_pixel_proxy<3, height, width> { | |||||
(void)IH; | (void)IH; | ||||
(void)OH; | (void)OH; | ||||
const int ih = oh, iw = ow; | const int ih = oh, iw = ow; | ||||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp; | |||||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp; | |||||
src += ih * IW + iw; | src += ih * IW + iw; | ||||
dst += oh * OW + ow; | dst += oh * OW + ow; | ||||
LOAD_OUT; | LOAD_OUT; | ||||
for (int fw = 0; fw < FW; ++fw) { | for (int fw = 0; fw < FW; ++fw) { | ||||
const float* src_dd = src + fw; | const float* src_dd = src + fw; | ||||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 0 * IW); | |||||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr0); | |||||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 1 * IW); | |||||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr1); | |||||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr0); | |||||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 2 * IW); | |||||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr2); | |||||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr1); | |||||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr0); | |||||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||||
if (height > 1) | if (height > 1) | ||||
inp = vld1q_f32(src_dd + 3 * IW); | |||||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr2); | |||||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr1); | |||||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr0); | |||||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||||
if (height > 2) | if (height > 2) | ||||
inp = vld1q_f32(src_dd + 4 * IW); | |||||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr2); | |||||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr1); | |||||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||||
if (height > 3) | if (height > 3) | ||||
inp = vld1q_f32(src_dd + 5 * IW); | |||||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr2); | |||||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||||
} | } | ||||
STORE_OUT; | STORE_OUT; | ||||
} | } | ||||
@@ -263,69 +263,69 @@ struct do_pixel_proxy<4, height, width> { | |||||
(void)IH; | (void)IH; | ||||
(void)OH; | (void)OH; | ||||
const int ih = oh, iw = ow; | const int ih = oh, iw = ow; | ||||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp; | |||||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp; | |||||
src += ih * IW + iw; | src += ih * IW + iw; | ||||
dst += oh * OW + ow; | dst += oh * OW + ow; | ||||
LOAD_OUT; | LOAD_OUT; | ||||
for (int fw = 0; fw < FW; ++fw) { | for (int fw = 0; fw < FW; ++fw) { | ||||
const float* src_dd = src + fw; | const float* src_dd = src + fw; | ||||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||||
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 0 * IW); | |||||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr0); | |||||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 1 * IW); | |||||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr1); | |||||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr0); | |||||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 2 * IW); | |||||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr2); | |||||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr1); | |||||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr0); | |||||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 3 * IW); | |||||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr3); | |||||
out0 = GiMlaqFloat32(out0, inp, kr3); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr2); | |||||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr1); | |||||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr0); | |||||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||||
if (height > 1) | if (height > 1) | ||||
inp = vld1q_f32(src_dd + 4 * IW); | |||||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr3); | |||||
out1 = GiMlaqFloat32(out1, inp, kr3); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr2); | |||||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr1); | |||||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||||
if (height > 2) | if (height > 2) | ||||
inp = vld1q_f32(src_dd + 5 * IW); | |||||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr3); | |||||
out2 = GiMlaqFloat32(out2, inp, kr3); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr2); | |||||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||||
if (height > 3) | if (height > 3) | ||||
inp = vld1q_f32(src_dd + 6 * IW); | |||||
inp = GiLoadFloat32(src_dd + 6 * IW); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr3); | |||||
out3 = GiMlaqFloat32(out3, inp, kr3); | |||||
} | } | ||||
STORE_OUT; | STORE_OUT; | ||||
} | } | ||||
@@ -340,81 +340,81 @@ struct do_pixel_proxy<5, height, width> { | |||||
(void)IH; | (void)IH; | ||||
(void)OH; | (void)OH; | ||||
const int ih = oh, iw = ow; | const int ih = oh, iw = ow; | ||||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp; | |||||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp; | |||||
src += ih * IW + iw; | src += ih * IW + iw; | ||||
dst += oh * OW + ow; | dst += oh * OW + ow; | ||||
LOAD_OUT; | LOAD_OUT; | ||||
for (int fw = 0; fw < FW; ++fw) { | for (int fw = 0; fw < FW; ++fw) { | ||||
const float* src_dd = src + fw; | const float* src_dd = src + fw; | ||||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||||
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); | |||||
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 0 * IW); | |||||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr0); | |||||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 1 * IW); | |||||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr1); | |||||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr0); | |||||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 2 * IW); | |||||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr2); | |||||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr1); | |||||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr0); | |||||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 3 * IW); | |||||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr3); | |||||
out0 = GiMlaqFloat32(out0, inp, kr3); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr2); | |||||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr1); | |||||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr0); | |||||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 4 * IW); | |||||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr4); | |||||
out0 = GiMlaqFloat32(out0, inp, kr4); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr3); | |||||
out1 = GiMlaqFloat32(out1, inp, kr3); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr2); | |||||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr1); | |||||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||||
if (height > 1) | if (height > 1) | ||||
inp = vld1q_f32(src_dd + 5 * IW); | |||||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr4); | |||||
out1 = GiMlaqFloat32(out1, inp, kr4); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr3); | |||||
out2 = GiMlaqFloat32(out2, inp, kr3); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr2); | |||||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||||
if (height > 2) | if (height > 2) | ||||
inp = vld1q_f32(src_dd + 6 * IW); | |||||
inp = GiLoadFloat32(src_dd + 6 * IW); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr4); | |||||
out2 = GiMlaqFloat32(out2, inp, kr4); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr3); | |||||
out3 = GiMlaqFloat32(out3, inp, kr3); | |||||
if (height > 3) | if (height > 3) | ||||
inp = vld1q_f32(src_dd + 7 * IW); | |||||
inp = GiLoadFloat32(src_dd + 7 * IW); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr4); | |||||
out3 = GiMlaqFloat32(out3, inp, kr4); | |||||
} | } | ||||
STORE_OUT; | STORE_OUT; | ||||
} | } | ||||
@@ -429,94 +429,94 @@ struct do_pixel_proxy<6, height, width> { | |||||
(void)IH; | (void)IH; | ||||
(void)OH; | (void)OH; | ||||
const int ih = oh, iw = ow; | const int ih = oh, iw = ow; | ||||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, | |||||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, | |||||
inp; | inp; | ||||
src += ih * IW + iw; | src += ih * IW + iw; | ||||
dst += oh * OW + ow; | dst += oh * OW + ow; | ||||
LOAD_OUT; | LOAD_OUT; | ||||
for (int fw = 0; fw < FW; ++fw) { | for (int fw = 0; fw < FW; ++fw) { | ||||
const float* src_dd = src + fw; | const float* src_dd = src + fw; | ||||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||||
kr5 = vdupq_n_f32(filter[5 * FW + fw]); | |||||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||||
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); | |||||
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); | |||||
kr5 = GiBroadcastFloat32(filter[5 * FW + fw]); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 0 * IW); | |||||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr0); | |||||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 1 * IW); | |||||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr1); | |||||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr0); | |||||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 2 * IW); | |||||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr2); | |||||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr1); | |||||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr0); | |||||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 3 * IW); | |||||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr3); | |||||
out0 = GiMlaqFloat32(out0, inp, kr3); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr2); | |||||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr1); | |||||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr0); | |||||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 4 * IW); | |||||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr4); | |||||
out0 = GiMlaqFloat32(out0, inp, kr4); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr3); | |||||
out1 = GiMlaqFloat32(out1, inp, kr3); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr2); | |||||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr1); | |||||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 5 * IW); | |||||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr5); | |||||
out0 = GiMlaqFloat32(out0, inp, kr5); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr4); | |||||
out1 = GiMlaqFloat32(out1, inp, kr4); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr3); | |||||
out2 = GiMlaqFloat32(out2, inp, kr3); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr2); | |||||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||||
if (height > 1) | if (height > 1) | ||||
inp = vld1q_f32(src_dd + 6 * IW); | |||||
inp = GiLoadFloat32(src_dd + 6 * IW); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr5); | |||||
out1 = GiMlaqFloat32(out1, inp, kr5); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr4); | |||||
out2 = GiMlaqFloat32(out2, inp, kr4); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr3); | |||||
out3 = GiMlaqFloat32(out3, inp, kr3); | |||||
if (height > 2) | if (height > 2) | ||||
inp = vld1q_f32(src_dd + 7 * IW); | |||||
inp = GiLoadFloat32(src_dd + 7 * IW); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr5); | |||||
out2 = GiMlaqFloat32(out2, inp, kr5); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr4); | |||||
out3 = GiMlaqFloat32(out3, inp, kr4); | |||||
if (height > 3) | if (height > 3) | ||||
inp = vld1q_f32(src_dd + 8 * IW); | |||||
inp = GiLoadFloat32(src_dd + 8 * IW); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr5); | |||||
out3 = GiMlaqFloat32(out3, inp, kr5); | |||||
} | } | ||||
STORE_OUT; | STORE_OUT; | ||||
} | } | ||||
@@ -531,106 +531,106 @@ struct do_pixel_proxy<7, height, width> { | |||||
(void)IH; | (void)IH; | ||||
(void)OH; | (void)OH; | ||||
const int ih = oh, iw = ow; | const int ih = oh, iw = ow; | ||||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, | |||||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, | |||||
kr6, inp; | kr6, inp; | ||||
src += ih * IW + iw; | src += ih * IW + iw; | ||||
dst += oh * OW + ow; | dst += oh * OW + ow; | ||||
LOAD_OUT; | LOAD_OUT; | ||||
for (int fw = 0; fw < FW; ++fw) { | for (int fw = 0; fw < FW; ++fw) { | ||||
const float* src_dd = src + fw; | const float* src_dd = src + fw; | ||||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||||
kr5 = vdupq_n_f32(filter[5 * FW + fw]); | |||||
kr6 = vdupq_n_f32(filter[6 * FW + fw]); | |||||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||||
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); | |||||
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); | |||||
kr5 = GiBroadcastFloat32(filter[5 * FW + fw]); | |||||
kr6 = GiBroadcastFloat32(filter[6 * FW + fw]); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 0 * IW); | |||||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr0); | |||||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 1 * IW); | |||||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr1); | |||||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr0); | |||||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 2 * IW); | |||||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr2); | |||||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr1); | |||||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr0); | |||||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 3 * IW); | |||||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr3); | |||||
out0 = GiMlaqFloat32(out0, inp, kr3); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr2); | |||||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr1); | |||||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr0); | |||||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 4 * IW); | |||||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr4); | |||||
out0 = GiMlaqFloat32(out0, inp, kr4); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr3); | |||||
out1 = GiMlaqFloat32(out1, inp, kr3); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr2); | |||||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr1); | |||||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 5 * IW); | |||||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr5); | |||||
out0 = GiMlaqFloat32(out0, inp, kr5); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr4); | |||||
out1 = GiMlaqFloat32(out1, inp, kr4); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr3); | |||||
out2 = GiMlaqFloat32(out2, inp, kr3); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr2); | |||||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||||
if (height > 0) | if (height > 0) | ||||
inp = vld1q_f32(src_dd + 6 * IW); | |||||
inp = GiLoadFloat32(src_dd + 6 * IW); | |||||
if (height > 0) | if (height > 0) | ||||
out0 = vmlaq_f32(out0, inp, kr6); | |||||
out0 = GiMlaqFloat32(out0, inp, kr6); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr5); | |||||
out1 = GiMlaqFloat32(out1, inp, kr5); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr4); | |||||
out2 = GiMlaqFloat32(out2, inp, kr4); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr3); | |||||
out3 = GiMlaqFloat32(out3, inp, kr3); | |||||
if (height > 1) | if (height > 1) | ||||
inp = vld1q_f32(src_dd + 7 * IW); | |||||
inp = GiLoadFloat32(src_dd + 7 * IW); | |||||
if (height > 1) | if (height > 1) | ||||
out1 = vmlaq_f32(out1, inp, kr6); | |||||
out1 = GiMlaqFloat32(out1, inp, kr6); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr5); | |||||
out2 = GiMlaqFloat32(out2, inp, kr5); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr4); | |||||
out3 = GiMlaqFloat32(out3, inp, kr4); | |||||
if (height > 2) | if (height > 2) | ||||
inp = vld1q_f32(src_dd + 8 * IW); | |||||
inp = GiLoadFloat32(src_dd + 8 * IW); | |||||
if (height > 2) | if (height > 2) | ||||
out2 = vmlaq_f32(out2, inp, kr6); | |||||
out2 = GiMlaqFloat32(out2, inp, kr6); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr5); | |||||
out3 = GiMlaqFloat32(out3, inp, kr5); | |||||
if (height > 3) | if (height > 3) | ||||
inp = vld1q_f32(src_dd + 9 * IW); | |||||
inp = GiLoadFloat32(src_dd + 9 * IW); | |||||
if (height > 3) | if (height > 3) | ||||
out3 = vmlaq_f32(out3, inp, kr6); | |||||
out3 = GiMlaqFloat32(out3, inp, kr6); | |||||
} | } | ||||
STORE_OUT; | STORE_OUT; | ||||
} | } | ||||
@@ -836,31 +836,31 @@ void conv_bias::kern_direct( | |||||
} while (0) | } while (0) | ||||
switch (FH) { | switch (FH) { | ||||
case 1: | case 1: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(0)) { GAO(1); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 2: | case 2: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(1)) { GAO(2); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 3: | case 3: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(2)) { GAO(3); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 4: | case 4: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(3)) { GAO(4); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 5: | case 5: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(4)) { GAO(5); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 6: | case 6: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(5)) { GAO(6); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 7: | case 7: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(6)) { GAO(7); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
} | } | ||||
@@ -872,31 +872,31 @@ void conv_bias::kern_direct( | |||||
} while (0) | } while (0) | ||||
switch (FH) { | switch (FH) { | ||||
case 1: | case 1: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(0)) { GAO(1); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 2: | case 2: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(1)) { GAO(2); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 3: | case 3: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(2)) { GAO(3); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 4: | case 4: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(3)) { GAO(4); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 5: | case 5: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(4)) { GAO(5); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 6: | case 6: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(5)) { GAO(6); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
case 7: | case 7: | ||||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } | |||||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(6)) { GAO(7); } | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
break; | break; | ||||
} | } |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/direct.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/direct.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -13,7 +13,7 @@ | |||||
#include <cstddef> | #include <cstddef> | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace fp32 { | namespace fp32 { | ||||
namespace conv_bias { | namespace conv_bias { | ||||
@@ -23,7 +23,7 @@ void kern_direct( | |||||
} // namespace conv_bias | } // namespace conv_bias | ||||
} // namespace fp32 | } // namespace fp32 | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -11,12 +11,12 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" | |||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace conv_bias { | namespace conv_bias { | ||||
template <> | template <> | ||||
void pack_src_fp32_nchw44<1>( | void pack_src_fp32_nchw44<1>( | ||||
@@ -51,23 +51,23 @@ static inline void odd_even_split_iw8_even( | |||||
const int src_offset = src_idx * ic_step; | const int src_offset = src_idx * ic_step; | ||||
const int even_offset = iw_idx / 2 * ic_step; | const int even_offset = iw_idx / 2 * ic_step; | ||||
const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | ||||
float32x4_t temp[8]; | |||||
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); | |||||
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); | |||||
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); | |||||
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); | |||||
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); | |||||
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); | |||||
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); | |||||
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); | |||||
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]); | |||||
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]); | |||||
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]); | |||||
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]); | |||||
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]); | |||||
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]); | |||||
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); | |||||
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); | |||||
GI_FLOAT32_t temp[8]; | |||||
temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step); | |||||
temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step); | |||||
temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step); | |||||
temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step); | |||||
temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step); | |||||
temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step); | |||||
temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step); | |||||
temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step); | |||||
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[0]); | |||||
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[2]); | |||||
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[4]); | |||||
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[6]); | |||||
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[1]); | |||||
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[3]); | |||||
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[5]); | |||||
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[7]); | |||||
} | } | ||||
static inline void odd_even_split_iw8_odd( | static inline void odd_even_split_iw8_odd( | ||||
@@ -77,23 +77,23 @@ static inline void odd_even_split_iw8_odd( | |||||
const int src_offset = src_idx * ic_step; | const int src_offset = src_idx * ic_step; | ||||
const int even_offset = (iw_idx + 1) / 2 * ic_step; | const int even_offset = (iw_idx + 1) / 2 * ic_step; | ||||
const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | ||||
float32x4_t temp[8]; | |||||
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); | |||||
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); | |||||
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); | |||||
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); | |||||
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); | |||||
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); | |||||
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); | |||||
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); | |||||
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]); | |||||
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]); | |||||
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]); | |||||
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]); | |||||
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]); | |||||
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]); | |||||
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]); | |||||
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]); | |||||
GI_FLOAT32_t temp[8]; | |||||
temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step); | |||||
temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step); | |||||
temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step); | |||||
temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step); | |||||
temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step); | |||||
temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step); | |||||
temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step); | |||||
temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step); | |||||
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[0]); | |||||
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[2]); | |||||
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[4]); | |||||
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[6]); | |||||
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[1]); | |||||
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[3]); | |||||
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[5]); | |||||
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[7]); | |||||
} | } | ||||
} // namespace | } // namespace | ||||
@@ -104,7 +104,7 @@ void pack_src_fp32_nchw44<2>( | |||||
const int pad_top, const int pad_bottom, const int ic, const int ic_stride) { | const int pad_top, const int pad_bottom, const int ic, const int ic_stride) { | ||||
constexpr int ic_step = 4; | constexpr int ic_step = 4; | ||||
int odd_start = megdnn::div_ceil(iw2, 2); | int odd_start = megdnn::div_ceil(iw2, 2); | ||||
float32x4_t zero_v = vdupq_n_f32(0.f); | |||||
GI_FLOAT32_t zero_v = GiZeroFloat32(); | |||||
MEGDNN_MARK_USED_VAR(ph); | MEGDNN_MARK_USED_VAR(ph); | ||||
bool even_start = pw % 2 == 0; | bool even_start = pw % 2 == 0; | ||||
rep_step(ic_idx, ic, ic_step) { | rep_step(ic_idx, ic, ic_step) { | ||||
@@ -115,9 +115,10 @@ void pack_src_fp32_nchw44<2>( | |||||
int iw_idx = 0; | int iw_idx = 0; | ||||
rep(idx, pw) { | rep(idx, pw) { | ||||
if (iw_idx % 2 == 0) { | if (iw_idx % 2 == 0) { | ||||
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||||
GiStoreFloat32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||||
} else { | } else { | ||||
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); | |||||
GiStoreFloat32( | |||||
sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); | |||||
} | } | ||||
++iw_idx; | ++iw_idx; | ||||
} | } | ||||
@@ -136,21 +137,22 @@ void pack_src_fp32_nchw44<2>( | |||||
} | } | ||||
for (; src_idx < iw; ++src_idx) { | for (; src_idx < iw; ++src_idx) { | ||||
if (iw_idx % 2 == 0) { | if (iw_idx % 2 == 0) { | ||||
vst1q_f32( | |||||
GiStoreFloat32( | |||||
sptr_base + iw_idx / 2 * ic_step, | sptr_base + iw_idx / 2 * ic_step, | ||||
vld1q_f32(sptr + src_idx * ic_step)); | |||||
GiLoadFloat32(sptr + src_idx * ic_step)); | |||||
} else { | } else { | ||||
vst1q_f32( | |||||
GiStoreFloat32( | |||||
sptr_base + (odd_start + iw_idx / 2) * ic_step, | sptr_base + (odd_start + iw_idx / 2) * ic_step, | ||||
vld1q_f32(sptr + src_idx * ic_step)); | |||||
GiLoadFloat32(sptr + src_idx * ic_step)); | |||||
} | } | ||||
++iw_idx; | ++iw_idx; | ||||
} | } | ||||
rep(idx, pad_right) { | rep(idx, pad_right) { | ||||
if (iw_idx % 2 == 0) { | if (iw_idx % 2 == 0) { | ||||
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||||
GiStoreFloat32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||||
} else { | } else { | ||||
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); | |||||
GiStoreFloat32( | |||||
sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); | |||||
} | } | ||||
++iw_idx; | ++iw_idx; | ||||
} | } | ||||
@@ -163,7 +165,7 @@ void pack_src_fp32_nchw44<2>( | |||||
} | } | ||||
} // namespace conv_bias | } // namespace conv_bias | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_BIAS(2); | INSTANTIATION_CONV_S1_BIAS(2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(2); | INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_NO_BIAS(2); | INSTANTIATION_CONV_S1_NO_BIAS(2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_BIAS(2); | INSTANTIATION_CONV_S2_BIAS(2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(2); | INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_NO_BIAS(2); | INSTANTIATION_CONV_S2_NO_BIAS(2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_BIAS(3); | INSTANTIATION_CONV_S1_BIAS(3); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(3); | INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(3); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_NO_BIAS(3); | INSTANTIATION_CONV_S1_NO_BIAS(3); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_BIAS(3); | INSTANTIATION_CONV_S2_BIAS(3); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(3); | INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(3); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_NO_BIAS(3); | INSTANTIATION_CONV_S2_NO_BIAS(3); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_BIAS(5); | INSTANTIATION_CONV_S1_BIAS(5); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(5); | INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(5); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_NO_BIAS(5); | INSTANTIATION_CONV_S1_NO_BIAS(5); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_BIAS(5); | INSTANTIATION_CONV_S2_BIAS(5); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(5); | INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(5); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_NO_BIAS(5); | INSTANTIATION_CONV_S2_NO_BIAS(5); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_BIAS(7); | INSTANTIATION_CONV_S1_BIAS(7); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(7); | INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(7); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
INSTANTIATION_CONV_S1_NO_BIAS(7); | INSTANTIATION_CONV_S1_NO_BIAS(7); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_BIAS(7); | INSTANTIATION_CONV_S2_BIAS(7); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(7); | INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(7); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
INSTANTIATION_CONV_S2_NO_BIAS(7); | INSTANTIATION_CONV_S2_NO_BIAS(7); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -12,16 +12,15 @@ | |||||
*/ | */ | ||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" | |||||
#include "src/fallback/conv_bias/gi/intrinsic_helper.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
template < | template < | ||||
@@ -39,13 +38,13 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> { | |||||
}; | }; | ||||
#define cb2(step, lane, ow_block) \ | #define cb2(step, lane, ow_block) \ | ||||
c[0][step] = vfmaq_laneq_f32( \ | |||||
c[0][step] = GiSimdFmaLane( \ | |||||
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ | c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ | ||||
c[1][step] = vfmaq_laneq_f32( \ | |||||
c[1][step] = GiSimdFmaLane( \ | |||||
c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); | c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); | ||||
#define cb(step, lane, ow_block) \ | |||||
c[0][step] = vfmaq_laneq_f32( \ | |||||
#define cb(step, lane, ow_block) \ | |||||
c[0][step] = GiSimdFmaLane( \ | |||||
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); | c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); | ||||
#define SHIFT_CAL_HELPER(ow_block, remain_w) \ | #define SHIFT_CAL_HELPER(ow_block, remain_w) \ | ||||
@@ -122,7 +121,7 @@ public: | |||||
template < | template < | ||||
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | ||||
int ow_block> | int ow_block> | ||||
struct KerNeonXXs1Nchw44FP32 { | |||||
struct KerGiXXs1Nchw44FP32 { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -130,7 +129,7 @@ struct KerNeonXXs1Nchw44FP32 { | |||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | ||||
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | const float32_t* src_ptr_origin, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -147,20 +146,20 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | |||||
GI_FLOAT32_t c[c_dim][ow_block]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | ||||
float32x4_t src[ow_block]; | |||||
float32x4_t weight[c_dim][ic_step]; | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | |||||
GI_FLOAT32_t src[ow_block]; | |||||
GI_FLOAT32_t weight[c_dim][ic_step]; | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | |||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); | |||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
@@ -172,7 +171,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | ||||
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | const float32_t* src_ptr_origin, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -189,24 +188,24 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | |||||
GI_FLOAT32_t c[c_dim][ow_block]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | ||||
float32x4_t src[ow_block]; | |||||
float32x4_t weight[c_dim][ic_step]; | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | |||||
GI_FLOAT32_t src[ow_block]; | |||||
GI_FLOAT32_t weight[c_dim][ic_step]; | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | |||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); | |||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | |||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); | |||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
@@ -217,7 +216,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
} | } | ||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | ||||
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | const float32_t* src_ptr_origin, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -234,36 +233,36 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | |||||
GI_FLOAT32_t c[c_dim][ow_block]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | ||||
float32x4_t src[ow_block]; | |||||
float32x4_t weight[c_dim][ic_step]; | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | |||||
GI_FLOAT32_t src[ow_block]; | |||||
GI_FLOAT32_t weight[c_dim][ic_step]; | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | |||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); | |||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | |||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); | |||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | |||||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step); | |||||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | |||||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step); | |||||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
@@ -275,7 +274,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | ||||
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | const float32_t* src_ptr_origin, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -292,46 +291,46 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | |||||
GI_FLOAT32_t c[c_dim][ow_block]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | ||||
float32x4_t src[ow_block]; | |||||
float32x4_t weight[c_dim][ic_step]; | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | |||||
GI_FLOAT32_t src[ow_block]; | |||||
GI_FLOAT32_t weight[c_dim][ic_step]; | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | |||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); | |||||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | |||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); | |||||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | |||||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step); | |||||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | |||||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step); | |||||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); | |||||
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[4] = GiLoadFloat32(src_ptr + (ow_block + 4) * ic_step); | |||||
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<5, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<5, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); | |||||
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[5] = GiLoadFloat32(src_ptr + (ow_block + 5) * ic_step); | |||||
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<6, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<6, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
@@ -352,10 +351,10 @@ void conv_bias::conv_direct_fp32_nchw44( | |||||
constexpr int fh = filter_size; | constexpr int fh = filter_size; | ||||
constexpr int fw = filter_size; | constexpr int fw = filter_size; | ||||
constexpr int ic_step = 4; | constexpr int ic_step = 4; | ||||
#if MEGDNN_ARMV7 | |||||
constexpr int big_oc_step = 4; | |||||
#else | |||||
#if MEGDNN_AARCH64 | |||||
constexpr int big_oc_step = 8; | constexpr int big_oc_step = 8; | ||||
#else | |||||
constexpr int big_oc_step = 4; | |||||
#endif | #endif | ||||
constexpr int oc_step = 4; | constexpr int oc_step = 4; | ||||
constexpr int ih_step = 1; | constexpr int ih_step = 1; | ||||
@@ -381,9 +380,9 @@ void conv_bias::conv_direct_fp32_nchw44( | |||||
switch (ow_remain) { | switch (ow_remain) { | ||||
#define cb(step) \ | #define cb(step) \ | ||||
case step: \ | case step: \ | ||||
kern_big_oc_remain = KerNeonXXs1Nchw44FP32< \ | |||||
kern_big_oc_remain = KerGiXXs1Nchw44FP32< \ | |||||
bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ | bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ | ||||
kern_small_oc_remain = KerNeonXXs1Nchw44FP32< \ | |||||
kern_small_oc_remain = KerGiXXs1Nchw44FP32< \ | |||||
bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ | bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ | ||||
break; | break; | ||||
@@ -402,7 +401,7 @@ void conv_bias::conv_direct_fp32_nchw44( | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
const int bias_offset = | const int bias_offset = | ||||
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | ||||
KerNeonXXs1Nchw44FP32< | |||||
KerGiXXs1Nchw44FP32< | |||||
bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: | bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: | ||||
impl(src + src_offset, filter + weight_offset, | impl(src + src_offset, filter + weight_offset, | ||||
bias + bias_offset, dst + dst_offset, ic, ih, iw, | bias + bias_offset, dst + dst_offset, ic, ih, iw, | ||||
@@ -434,7 +433,7 @@ void conv_bias::conv_direct_fp32_nchw44( | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
const int bias_offset = | const int bias_offset = | ||||
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | ||||
KerNeonXXs1Nchw44FP32< | |||||
KerGiXXs1Nchw44FP32< | |||||
bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: | bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: | ||||
impl(src + src_offset, filter + weight_offset, | impl(src + src_offset, filter + weight_offset, | ||||
bias + bias_offset, dst + dst_offset, ic, ih, iw, | bias + bias_offset, dst + dst_offset, ic, ih, iw, |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -12,16 +12,15 @@ | |||||
*/ | */ | ||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" | |||||
#include "src/fallback/conv_bias/gi/intrinsic_helper.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
template < | template < | ||||
@@ -39,13 +38,13 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> { | |||||
}; | }; | ||||
#define cb2(step, lane, ow_block) \ | #define cb2(step, lane, ow_block) \ | ||||
c[0][step] = vfmaq_laneq_f32( \ | |||||
c[0][step] = GiSimdFmaLane( \ | |||||
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ | c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ | ||||
c[1][step] = vfmaq_laneq_f32( \ | |||||
c[1][step] = GiSimdFmaLane( \ | |||||
c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); | c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); | ||||
#define cb(step, lane, ow_block) \ | |||||
c[0][step] = vfmaq_laneq_f32( \ | |||||
#define cb(step, lane, ow_block) \ | |||||
c[0][step] = GiSimdFmaLane( \ | |||||
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); | c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); | ||||
#define SHIFT_CAL_HELPER(ow_block, remain_w) \ | #define SHIFT_CAL_HELPER(ow_block, remain_w) \ | ||||
@@ -122,7 +121,7 @@ public: | |||||
template < | template < | ||||
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | ||||
int ow_block> | int ow_block> | ||||
struct KerNeonXXs2Nchw44FP32 { | |||||
struct KerGiXXs2Nchw44FP32 { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -130,7 +129,7 @@ struct KerNeonXXs2Nchw44FP32 { | |||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | ||||
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | const float32_t* src_ptr_origin, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -147,36 +146,36 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | |||||
GI_FLOAT32_t c[c_dim][ow_block]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | ||||
float32x4_t src[ow_block]; | |||||
float32x4_t weight[c_dim][4]; | |||||
GI_FLOAT32_t src[ow_block]; | |||||
GI_FLOAT32_t weight[c_dim][4]; | |||||
/////////row 0///////////// | /////////row 0///////////// | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
/////////row 1///////////// | /////////row 1///////////// | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
@@ -188,7 +187,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | ||||
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | const float32_t* src_ptr_origin, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -205,62 +204,62 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | |||||
GI_FLOAT32_t c[c_dim][ow_block]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | ||||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | ||||
float32x4_t src[ow_block]; | |||||
float32x4_t weight[c_dim][4]; | |||||
GI_FLOAT32_t src[ow_block]; | |||||
GI_FLOAT32_t weight[c_dim][4]; | |||||
/////////row 0///////////// | /////////row 0///////////// | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
/////////row 1///////////// | /////////row 1///////////// | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
//////////row 2///////////// | //////////row 2///////////// | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
@@ -272,7 +271,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | ||||
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | const float32_t* src_ptr_origin, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -289,7 +288,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | |||||
GI_FLOAT32_t c[c_dim][ow_block]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
@@ -297,28 +296,28 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | ||||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | ||||
float32x4_t src[ow_block]; | |||||
float32x4_t weight[c_dim][4]; | |||||
GI_FLOAT32_t src[ow_block]; | |||||
GI_FLOAT32_t weight[c_dim][4]; | |||||
// even element | // even element | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | |||||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len); | |||||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
// odd element | // odd element | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | |||||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len); | |||||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
@@ -337,7 +336,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
* src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9 | * src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9 | ||||
**/ | **/ | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | ||||
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | const float32_t* src_ptr_origin, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -354,7 +353,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
const int ld_src_iw = iw * oc_step; | const int ld_src_iw = iw * oc_step; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][ow_block]; | |||||
GI_FLOAT32_t c[c_dim][ow_block]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
@@ -362,36 +361,36 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | ||||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | ||||
float32x4_t src[ow_block]; | |||||
float32x4_t weight[c_dim][4]; | |||||
GI_FLOAT32_t src[ow_block]; | |||||
GI_FLOAT32_t weight[c_dim][4]; | |||||
// even element | // even element | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | |||||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len); | |||||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); | |||||
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * simd_len); | |||||
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
// odd element | // odd element | ||||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | |||||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len); | |||||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); | |||||
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||||
src[1] = GiLoadFloat32(src_ptr_odd + (ow_block + 1) * simd_len); | |||||
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | ||||
@@ -414,10 +413,10 @@ void conv_bias::conv_direct_fp32_nchw44( | |||||
constexpr int fh = filter_size; | constexpr int fh = filter_size; | ||||
constexpr int fw = filter_size; | constexpr int fw = filter_size; | ||||
constexpr int ic_step = 4; | constexpr int ic_step = 4; | ||||
#if MEGDNN_ARMV7 | |||||
constexpr int big_oc_step = 4; | |||||
#else | |||||
#if MEGDNN_AARCH64 | |||||
constexpr int big_oc_step = 8; | constexpr int big_oc_step = 8; | ||||
#else | |||||
constexpr int big_oc_step = 4; | |||||
#endif | #endif | ||||
constexpr int oc_step = 4; | constexpr int oc_step = 4; | ||||
constexpr int ih_step = 1; | constexpr int ih_step = 1; | ||||
@@ -444,9 +443,9 @@ void conv_bias::conv_direct_fp32_nchw44( | |||||
switch (ow_remain) { | switch (ow_remain) { | ||||
#define cb(step) \ | #define cb(step) \ | ||||
case step: \ | case step: \ | ||||
kern_big_oc_remain = KerNeonXXs2Nchw44FP32< \ | |||||
kern_big_oc_remain = KerGiXXs2Nchw44FP32< \ | |||||
bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ | bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ | ||||
kern_small_oc_remain = KerNeonXXs2Nchw44FP32< \ | |||||
kern_small_oc_remain = KerGiXXs2Nchw44FP32< \ | |||||
bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ | bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ | ||||
break; | break; | ||||
@@ -469,7 +468,7 @@ void conv_bias::conv_direct_fp32_nchw44( | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
const int bias_offset = | const int bias_offset = | ||||
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | ||||
KerNeonXXs2Nchw44FP32< | |||||
KerGiXXs2Nchw44FP32< | |||||
bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: | bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: | ||||
impl(src + src_offset, filter + weight_offset, | impl(src + src_offset, filter + weight_offset, | ||||
bias + bias_offset, dst + dst_offset, ic, ih, iw, | bias + bias_offset, dst + dst_offset, ic, ih, iw, | ||||
@@ -510,7 +509,7 @@ void conv_bias::conv_direct_fp32_nchw44( | |||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
const int bias_offset = | const int bias_offset = | ||||
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | ||||
KerNeonXXs2Nchw44FP32< | |||||
KerGiXXs2Nchw44FP32< | |||||
bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: | bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: | ||||
impl(src + src_offset, filter + weight_offset, | impl(src + src_offset, filter + weight_offset, | ||||
bias + bias_offset, dst + dst_offset, ic, ih, iw, | bias + bias_offset, dst + dst_offset, ic, ih, iw, |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BIAS(2, 1); | INSTANCE_CONV_BIAS(2, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 1); | INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_NO_BIAS(2, 1); | INSTANCE_CONV_NO_BIAS(2, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BIAS(2, 2); | INSTANCE_CONV_BIAS(2, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 2); | INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_NO_BIAS(2, 2); | INSTANCE_CONV_NO_BIAS(2, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BIAS(3, 1); | INSTANCE_CONV_BIAS(3, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 1); | INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_NO_BIAS(3, 1); | INSTANCE_CONV_NO_BIAS(3, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BIAS(3, 2); | INSTANCE_CONV_BIAS(3, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 2); | INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_NO_BIAS(3, 2); | INSTANCE_CONV_NO_BIAS(3, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BIAS(5, 1); | INSTANCE_CONV_BIAS(5, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 1); | INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_NO_BIAS(5, 1); | INSTANCE_CONV_NO_BIAS(5, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BIAS(5, 2); | INSTANCE_CONV_BIAS(5, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 2); | INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_NO_BIAS(5, 2); | INSTANCE_CONV_NO_BIAS(5, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BIAS(7, 1); | INSTANCE_CONV_BIAS(7, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 1); | INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_NO_BIAS(7, 1); | INSTANCE_CONV_NO_BIAS(7, 1); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BIAS(7, 2); | INSTANCE_CONV_BIAS(7, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 2); | INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp | |||||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,6 +10,6 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
INSTANCE_CONV_NO_BIAS(7, 2); | INSTANCE_CONV_NO_BIAS(7, 2); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -11,20 +11,19 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h" | |||||
#include "src/fallback/conv_bias/gi/intrinsic_helper.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#if MEGDNN_ARMV7 | #if MEGDNN_ARMV7 | ||||
#include "src/armv7/matrix_mul/asm/common.h" | #include "src/armv7/matrix_mul/asm/common.h" | ||||
#endif | #endif | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
/** | /** | ||||
@@ -50,15 +49,15 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, stride, 0, T, T2, T3> { | |||||
}; | }; | ||||
#define cb(step) \ | #define cb(step) \ | ||||
c[0][step] = vfmaq_laneq_f32( \ | |||||
c[0][step] = GiSimdFmaLane( \ | |||||
c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ | c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ | ||||
(step * stride + src_idx) % 4); \ | (step * stride + src_idx) % 4); \ | ||||
c[1][step] = vfmaq_laneq_f32( \ | |||||
c[1][step] = GiSimdFmaLane( \ | |||||
c[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 4], \ | c[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 4], \ | ||||
(step * stride + src_idx) % 4); | (step * stride + src_idx) % 4); | ||||
#define cb2(step) \ | #define cb2(step) \ | ||||
c[0][step] = vfmaq_laneq_f32( \ | |||||
c[0][step] = GiSimdFmaLane( \ | |||||
c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ | c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ | ||||
(step * stride + src_idx) % 4); | (step * stride + src_idx) % 4); | ||||
@@ -127,7 +126,7 @@ public: | |||||
template < | template < | ||||
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | ||||
int stride, int ow_block, int tag = CpuTag::DEFAULT_CPU_TAG> | int stride, int ow_block, int tag = CpuTag::DEFAULT_CPU_TAG> | ||||
struct KerNeonXXs2NchwNchw44FP32 { | |||||
struct KerGiXXs2NchwNchw44FP32 { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -136,8 +135,7 @@ struct KerNeonXXs2NchwNchw44FP32 { | |||||
template < | template < | ||||
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | ||||
int ow_block> | int ow_block> | ||||
struct KerNeonXXs2NchwNchw44FP32< | |||||
bias_mode, Op, remain_w, 7, oc_block, stride, ow_block> { | |||||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -154,16 +152,16 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
const int ld_weight_ic = oc_step * filter_size * filter_size; | const int ld_weight_ic = oc_step * filter_size * filter_size; | ||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][8]; | |||||
GI_FLOAT32_t c[c_dim][8]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
float32x4_t src[src_reg_size]; | |||||
float32x4_t weight[c_dim][filter_size]; | |||||
GI_FLOAT32_t src[src_reg_size]; | |||||
GI_FLOAT32_t weight[c_dim][filter_size]; | |||||
#define KERNEL_CB(step) \ | #define KERNEL_CB(step) \ | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + step * iw, 0); \ | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \ | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \ | |||||
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | ||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ | cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ | ||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ | cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ | ||||
@@ -186,8 +184,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
template < | template < | ||||
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | ||||
int ow_block> | int ow_block> | ||||
struct KerNeonXXs2NchwNchw44FP32< | |||||
bias_mode, Op, remain_w, 5, oc_block, stride, ow_block> { | |||||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -204,16 +201,16 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
const int ld_weight_ic = oc_step * filter_size * filter_size; | const int ld_weight_ic = oc_step * filter_size * filter_size; | ||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][8]; | |||||
GI_FLOAT32_t c[c_dim][8]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
float32x4_t src[src_reg_size]; | |||||
float32x4_t weight[c_dim][filter_size]; | |||||
GI_FLOAT32_t src[src_reg_size]; | |||||
GI_FLOAT32_t weight[c_dim][filter_size]; | |||||
#define KERNEL_CB(step) \ | #define KERNEL_CB(step) \ | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + step * iw, 0); \ | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \ | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \ | |||||
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | ||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ | cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ | ||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ | cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ | ||||
@@ -233,8 +230,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
template < | template < | ||||
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | ||||
int ow_block> | int ow_block> | ||||
struct KerNeonXXs2NchwNchw44FP32< | |||||
bias_mode, Op, remain_w, 3, oc_block, stride, ow_block> { | |||||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -251,32 +247,32 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
const int ld_weight_ic = oc_step * filter_size * filter_size; | const int ld_weight_ic = oc_step * filter_size * filter_size; | ||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][8]; | |||||
GI_FLOAT32_t c[c_dim][8]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
float32x4_t src[src_reg_size]; | |||||
float32x4_t weight[c_dim][filter_size]; | |||||
GI_FLOAT32_t src[src_reg_size]; | |||||
GI_FLOAT32_t weight[c_dim][filter_size]; | |||||
// row 0 | // row 0 | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | ||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | ||||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | ||||
// row 1 | // row 1 | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + iw, 0); | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | ||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | ||||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | ||||
// row 2 | // row 2 | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | |||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>( | |||||
src, src_ptr + 2 * iw, 0); | src, src_ptr + 2 * iw, 0); | ||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); | weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | ||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | ||||
@@ -292,7 +288,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
#if MEGDNN_ARMV7 | #if MEGDNN_ARMV7 | ||||
template <BiasMode bias_mode, typename Op> | template <BiasMode bias_mode, typename Op> | ||||
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> { | |||||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -310,7 +306,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> { | |||||
const int ld_src_ic_skip_bytes = | const int ld_src_ic_skip_bytes = | ||||
iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; | iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[1][8]; | |||||
GI_FLOAT32_t c[1][8]; | |||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | ||||
const int img_stride = ih * iw; | const int img_stride = ih * iw; | ||||
constexpr int filter_stride = filter_size * filter_size * oc_step; | constexpr int filter_stride = filter_size * filter_size * oc_step; | ||||
@@ -464,8 +460,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> { | |||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op> | template <BiasMode bias_mode, typename Op> | ||||
struct KerNeonXXs2NchwNchw44FP32< | |||||
bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU_TAG> { | |||||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU_TAG> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -483,7 +478,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
const int ld_src_ic_skip_bytes = | const int ld_src_ic_skip_bytes = | ||||
iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; | iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[1][8]; | |||||
GI_FLOAT32_t c[1][8]; | |||||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | ||||
/** | /** | ||||
* c q8-q15 | * c q8-q15 | ||||
@@ -626,8 +621,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
template < | template < | ||||
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | ||||
int ow_block> | int ow_block> | ||||
struct KerNeonXXs2NchwNchw44FP32< | |||||
bias_mode, Op, remain_w, 2, oc_block, stride, ow_block> { | |||||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, ow_block> { | |||||
static void impl( | static void impl( | ||||
const float32_t* src_ptr, const float32_t* weight_ptr, | const float32_t* src_ptr, const float32_t* weight_ptr, | ||||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | ||||
@@ -644,22 +638,22 @@ struct KerNeonXXs2NchwNchw44FP32< | |||||
const int ld_weight_ic = oc_step * filter_size * filter_size; | const int ld_weight_ic = oc_step * filter_size * filter_size; | ||||
const int ld_src_ic = ih * iw; | const int ld_src_ic = ih * iw; | ||||
constexpr int c_dim = OCHelper<oc_block>::val; | constexpr int c_dim = OCHelper<oc_block>::val; | ||||
float32x4_t c[c_dim][8]; | |||||
GI_FLOAT32_t c[c_dim][8]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | ||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | ||||
float32x4_t src[src_reg_size]; | |||||
float32x4_t weight[c_dim][filter_size]; | |||||
GI_FLOAT32_t src[src_reg_size]; | |||||
GI_FLOAT32_t weight[c_dim][filter_size]; | |||||
// row 0 | // row 0 | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | ||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | ||||
// row 1 | // row 1 | ||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + iw, 0); | |||||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||||
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | ||||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | ||||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | ||||
@@ -711,9 +705,9 @@ struct ConvDirectFp32NchwNchw44 { | |||||
switch (ow_remain) { | switch (ow_remain) { | ||||
#define cb(step) \ | #define cb(step) \ | ||||
case step: \ | case step: \ | ||||
kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \ | |||||
kern_big_oc_remain = KerGiXXs2NchwNchw44FP32< \ | |||||
bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ | bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ | ||||
kern_small_oc_remain = KerNeonXXs2NchwNchw44FP32< \ | |||||
kern_small_oc_remain = KerGiXXs2NchwNchw44FP32< \ | |||||
bias_mode, Op, step, filter_size, oc_step, stride, ow_step>::impl; \ | bias_mode, Op, step, filter_size, oc_step, stride, ow_step>::impl; \ | ||||
break; | break; | ||||
@@ -731,7 +725,7 @@ struct ConvDirectFp32NchwNchw44 { | |||||
ic_step * pack_iw_len; | ic_step * pack_iw_len; | ||||
const int dst_offset = | const int dst_offset = | ||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
KerNeonXXs2NchwNchw44FP32< | |||||
KerGiXXs2NchwNchw44FP32< | |||||
bias_mode, Op, ow_step, filter_size, big_oc_step, stride, | bias_mode, Op, ow_step, filter_size, big_oc_step, stride, | ||||
ow_step>:: | ow_step>:: | ||||
impl(src + src_offset, filter + weight_offset, | impl(src + src_offset, filter + weight_offset, | ||||
@@ -760,7 +754,7 @@ struct ConvDirectFp32NchwNchw44 { | |||||
ic_step * pack_iw_len; | ic_step * pack_iw_len; | ||||
const int dst_offset = | const int dst_offset = | ||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
KerNeonXXs2NchwNchw44FP32< | |||||
KerGiXXs2NchwNchw44FP32< | |||||
bias_mode, Op, ow_step, filter_size, oc_step, stride, | bias_mode, Op, ow_step, filter_size, oc_step, stride, | ||||
ow_step>:: | ow_step>:: | ||||
impl(src + src_offset, filter + weight_offset, | impl(src + src_offset, filter + weight_offset, | ||||
@@ -819,7 +813,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> { | |||||
switch (ow_remain) { | switch (ow_remain) { | ||||
#define cb(step) \ | #define cb(step) \ | ||||
case step: \ | case step: \ | ||||
kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \ | |||||
kern_big_oc_remain = KerGiXXs2NchwNchw44FP32< \ | |||||
bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ | bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ | ||||
break; | break; | ||||
@@ -849,7 +843,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> { | |||||
ic_step * pack_iw_len; | ic_step * pack_iw_len; | ||||
const int dst_offset = | const int dst_offset = | ||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
KerNeonXXs2NchwNchw44FP32< | |||||
KerGiXXs2NchwNchw44FP32< | |||||
bias_mode, Op, ow_step, filter_size, big_oc_step, | bias_mode, Op, ow_step, filter_size, big_oc_step, | ||||
stride, ow_step, CpuTag::A7_TAG>:: | stride, ow_step, CpuTag::A7_TAG>:: | ||||
impl(src + src_offset, filter + weight_offset, | impl(src + src_offset, filter + weight_offset, | ||||
@@ -878,7 +872,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> { | |||||
ic_step * pack_iw_len; | ic_step * pack_iw_len; | ||||
const int dst_offset = | const int dst_offset = | ||||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | ||||
KerNeonXXs2NchwNchw44FP32< | |||||
KerGiXXs2NchwNchw44FP32< | |||||
bias_mode, Op, ow_step, filter_size, big_oc_step, | bias_mode, Op, ow_step, filter_size, big_oc_step, | ||||
stride, ow_step>:: | stride, ow_step>:: | ||||
impl(src + src_offset, filter + weight_offset, | impl(src + src_offset, filter + weight_offset, |
@@ -0,0 +1,723 @@ | |||||
/** | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include <algorithm> | |||||
#include "src/fallback/conv_bias/gi/fp32/do_conv_stride1.h" | |||||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_fallback_conv_bias_f32_convs1) | |||||
using namespace megdnn; | |||||
using namespace fallback; | |||||
using namespace fp32; | |||||
using namespace conv_stride1; | |||||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
void conv_stride1::do_conv_2x2_stride1( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
//! unroll of 2 | |||||
size_t ic = 0; | |||||
for (; ic + 1 < IC; ic += 2) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
const float* src_ptr1 = src_ptr + IW * IH; | |||||
float* outptr = dst; | |||||
const float* r00 = src_ptr; | |||||
const float* r01 = src_ptr + IW; | |||||
const float* r10 = src_ptr1; | |||||
const float* r11 = src_ptr1 + IW; | |||||
const float* k0 = filter + ic * 4; | |||||
const float* k1 = k0 + 4; | |||||
GI_FLOAT32_t _k0 = GiLoadFloat32(k0); | |||||
GI_FLOAT32_t _k1 = GiLoadFloat32(k1); | |||||
rep(h, OH) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
GI_FLOAT32_t _r000 = GiLoadFloat32(r00); | |||||
GI_FLOAT32_t _r010 = GiLoadFloat32(r01); | |||||
GI_FLOAT32_t _r001 = GiLoadFloat32(r00 + 1); | |||||
GI_FLOAT32_t _r011 = GiLoadFloat32(r01 + 1); | |||||
GI_FLOAT32_t _r100 = GiLoadFloat32(r10); | |||||
GI_FLOAT32_t _r110 = GiLoadFloat32(r11); | |||||
GI_FLOAT32_t _r101 = GiLoadFloat32(r10 + 1); | |||||
GI_FLOAT32_t _r111 = GiLoadFloat32(r11 + 1); | |||||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||||
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r000, _k0, 0); | |||||
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r001, _k0, 1); | |||||
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r010, _k0, 0); | |||||
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r011, _k0, 1); | |||||
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r100, _k1, 0); | |||||
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r101, _k1, 1); | |||||
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r110, _k1, 0); | |||||
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r111, _k1, 1); | |||||
GiStoreFloat32(outptr, _sum); | |||||
r00 += 4; | |||||
r01 += 4; | |||||
r10 += 4; | |||||
r11 += 4; | |||||
outptr += 4; | |||||
} | |||||
r00 += tail_step; | |||||
r01 += tail_step; | |||||
r10 += tail_step; | |||||
r11 += tail_step; | |||||
} | |||||
} | |||||
for (; ic < IC; ic++) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* k0 = filter + ic * 4; | |||||
GI_FLOAT32_t _k0 = GiBroadcastFloat32(k0[0]); | |||||
GI_FLOAT32_t _k1 = GiBroadcastFloat32(k0[1]); | |||||
GI_FLOAT32_t _k2 = GiBroadcastFloat32(k0[2]); | |||||
GI_FLOAT32_t _k3 = GiBroadcastFloat32(k0[3]); | |||||
rep(h, OH) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||||
GI_FLOAT32_t _r01 = GiLoadFloat32(r0 + 1); | |||||
GI_FLOAT32_t _r11 = GiLoadFloat32(r1 + 1); | |||||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_t _sum2; | |||||
_sum = GiMlaqFloat32(_sum, _r00, _k0); | |||||
_sum2 = GiMultiplyFloat32(_r01, _k1); | |||||
_sum = GiMlaqFloat32(_sum, _r10, _k2); | |||||
_sum2 = GiMlaqFloat32(_sum2, _r11, _k3); | |||||
_sum = GiAddFloat32(_sum, _sum2); | |||||
GiStoreFloat32(outptr, _sum); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
} | |||||
} | |||||
} | |||||
void conv_stride1::do_conv_3x3_stride1( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
float* outptr2 = outptr + OW; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 3; | |||||
const float* k2 = filter + 5; | |||||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||||
GI_FLOAT32_t _k3456 = GiLoadFloat32(k1); | |||||
GI_FLOAT32_t _k5678 = GiLoadFloat32(k2); | |||||
GI_FLOAT32_t _k6789 = GiExtqFloat32(_k5678, _k5678, 1); | |||||
size_t h = 0; | |||||
for (; h + 1 < OH; h += 2) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
GI_FLOAT32_t _sum1 = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_t _sum2 = GiBroadcastFloat32(0.f); | |||||
GI_FLOAT32_t _sum3 = GiLoadFloat32(outptr2); | |||||
GI_FLOAT32_t _sum4 = GiBroadcastFloat32(0.f); | |||||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||||
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 4); | |||||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r00n, 1); | |||||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r00n, 2); | |||||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||||
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 4); | |||||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r10n, 1); | |||||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r10n, 2); | |||||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||||
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 4); | |||||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r20n, 1); | |||||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r20n, 2); | |||||
GI_FLOAT32_t _r30 = GiLoadFloat32(r3); | |||||
GI_FLOAT32_t _r30n = GiLoadFloat32LowHalf(r3 + 4); | |||||
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r30n, 1); | |||||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r30n, 2); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r00, _k0123, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r01, _k0123, 1); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r02, _k0123, 2); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r10, _k3456, 0); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r11, _k3456, 1); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r12, _k3456, 2); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r20, _k6789, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r21, _k6789, 1); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r22, _k6789, 2); | |||||
_sum3 = GiSimdFmaLane(_sum3, _r10, _k0123, 0); | |||||
_sum4 = GiSimdFmaLane(_sum4, _r11, _k0123, 1); | |||||
_sum3 = GiSimdFmaLane(_sum3, _r12, _k0123, 2); | |||||
_sum4 = GiSimdFmaLane(_sum4, _r20, _k3456, 0); | |||||
_sum3 = GiSimdFmaLane(_sum3, _r21, _k3456, 1); | |||||
_sum4 = GiSimdFmaLane(_sum4, _r22, _k3456, 2); | |||||
_sum3 = GiSimdFmaLane(_sum3, _r30, _k6789, 0); | |||||
_sum4 = GiSimdFmaLane(_sum4, _r31, _k6789, 1); | |||||
_sum3 = GiSimdFmaLane(_sum3, _r32, _k6789, 2); | |||||
_sum1 = GiAddFloat32(_sum1, _sum2); | |||||
_sum3 = GiAddFloat32(_sum3, _sum4); | |||||
GiStoreFloat32(outptr, _sum1); | |||||
GiStoreFloat32(outptr2, _sum3); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
outptr += 4; | |||||
outptr2 += 4; | |||||
} | |||||
r0 += tail_step + IW; | |||||
r1 += tail_step + IW; | |||||
r2 += tail_step + IW; | |||||
r3 += tail_step + IW; | |||||
outptr += OW; | |||||
outptr2 += OW; | |||||
} | |||||
for (; h < OH; h++) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
GI_FLOAT32_t _sum1 = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_t _sum2 = GiBroadcastFloat32(0.f); | |||||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||||
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 4); | |||||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r00n, 1); | |||||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r00n, 2); | |||||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||||
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 4); | |||||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r10n, 1); | |||||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r10n, 2); | |||||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||||
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 4); | |||||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r20n, 1); | |||||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r20n, 2); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r00, _k0123, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r01, _k0123, 1); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r02, _k0123, 2); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r10, _k3456, 0); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r11, _k3456, 1); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r12, _k3456, 2); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r20, _k6789, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r21, _k6789, 1); | |||||
_sum1 = GiSimdFmaLane(_sum1, _r22, _k6789, 2); | |||||
_sum1 = GiAddFloat32(_sum1, _sum2); | |||||
GiStoreFloat32(outptr, _sum1); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
} | |||||
filter += 9; | |||||
} | |||||
} | |||||
void conv_stride1::do_conv_5x5_stride1( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
float* outptr2 = outptr + OW; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
const float* r5 = src_ptr + IW * 5; | |||||
GI_FLOAT32_t _k0123 = GiLoadFloat32(filter); | |||||
GI_FLOAT32_t _k4567 = GiLoadFloat32(filter + 4); | |||||
GI_FLOAT32_t _k891011 = GiLoadFloat32(filter + 8); | |||||
GI_FLOAT32_t _k12131415 = GiLoadFloat32(filter + 12); | |||||
GI_FLOAT32_t _k16171819 = GiLoadFloat32(filter + 16); | |||||
GI_FLOAT32_t _k20212223 = GiLoadFloat32(filter + 20); | |||||
GI_FLOAT32_t _k24242424 = GiBroadcastFloat32(filter[24]); | |||||
size_t h = 0; | |||||
for (; h + 1 < OH; h += 2) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_t _sum2 = GiLoadFloat32(outptr2); | |||||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||||
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); | |||||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); | |||||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); | |||||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); | |||||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||||
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); | |||||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); | |||||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); | |||||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); | |||||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||||
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); | |||||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); | |||||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); | |||||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); | |||||
GI_FLOAT32_t _r30 = GiLoadFloat32(r3); | |||||
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); | |||||
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); | |||||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); | |||||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); | |||||
GI_FLOAT32_t _r40 = GiLoadFloat32(r4); | |||||
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); | |||||
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); | |||||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); | |||||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); | |||||
GI_FLOAT32_t _r50 = GiLoadFloat32(r5); | |||||
GI_FLOAT32_t _r54 = GiLoadFloat32(r5 + 4); | |||||
GI_FLOAT32_t _r51 = GiExtqFloat32(_r50, _r54, 1); | |||||
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r54, 2); | |||||
GI_FLOAT32_t _r53 = GiExtqFloat32(_r50, _r54, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r10, _k0123, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r11, _k0123, 1); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r12, _k0123, 2); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r13, _k0123, 3); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r14, _k4567, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r20, _k4567, 1); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r21, _k4567, 2); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r22, _k4567, 3); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r23, _k891011, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r24, _k891011, 1); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r30, _k891011, 2); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r31, _k891011, 3); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r32, _k12131415, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r33, _k12131415, 1); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r34, _k12131415, 2); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r40, _k12131415, 3); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r41, _k16171819, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r42, _k16171819, 1); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r43, _k16171819, 2); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r44, _k16171819, 3); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r50, _k20212223, 0); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r51, _k20212223, 1); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r52, _k20212223, 2); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r53, _k20212223, 3); | |||||
_sum2 = GiSimdFmaLane(_sum2, _r54, _k24242424, 0); | |||||
GiStoreFloat32(outptr, _sum); | |||||
GiStoreFloat32(outptr2, _sum2); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
r4 += 4; | |||||
r5 += 4; | |||||
outptr += 4; | |||||
outptr2 += 4; | |||||
} | |||||
r0 += tail_step + IW; | |||||
r1 += tail_step + IW; | |||||
r2 += tail_step + IW; | |||||
r3 += tail_step + IW; | |||||
r4 += tail_step + IW; | |||||
r5 += tail_step + IW; | |||||
outptr += OW; | |||||
outptr2 += OW; | |||||
} | |||||
for (; h < OH; h++) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||||
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); | |||||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); | |||||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); | |||||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); | |||||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||||
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); | |||||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); | |||||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); | |||||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); | |||||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||||
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); | |||||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); | |||||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); | |||||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); | |||||
GI_FLOAT32_t _r30 = GiLoadFloat32(r3); | |||||
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); | |||||
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); | |||||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); | |||||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); | |||||
GI_FLOAT32_t _r40 = GiLoadFloat32(r4); | |||||
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); | |||||
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); | |||||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); | |||||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); | |||||
GiStoreFloat32(outptr, _sum); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
r4 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
} | |||||
filter += 25; | |||||
} | |||||
} | |||||
void conv_stride1::do_conv_7x7_stride1( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - OW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
const float* r5 = src_ptr + IW * 5; | |||||
const float* r6 = src_ptr + IW * 6; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 7; | |||||
const float* k2 = filter + 14; | |||||
const float* k3 = filter + 21; | |||||
const float* k4 = filter + 28; | |||||
const float* k5 = filter + 35; | |||||
const float* k6 = filter + 42; | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int width = OW >> 2; | |||||
rep(i, width) { | |||||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||||
GI_FLOAT32_t _k4567 = GiLoadFloat32(k0 + 4); | |||||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); // 0 1 2 3 | |||||
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); // 4 5 6 7 | |||||
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 8); // 8 9 10 11 | |||||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); // 1 2 3 4 | |||||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); // 2 3 4 5 | |||||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); // 3 4 5 6 | |||||
GI_FLOAT32_t _r05 = GiExtqFloat32(_r04, _r00n, 1); // 5 6 7 8 | |||||
GI_FLOAT32_t _r06 = GiExtqFloat32(_r04, _r00n, 2); // 6 7 8 9 | |||||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r05, _k4567, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r06, _k4567, 2); | |||||
GI_FLOAT32_t _k78910 = GiLoadFloat32(k1); | |||||
GI_FLOAT32_t _k11121314 = GiLoadFloat32(k1 + 4); | |||||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||||
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); | |||||
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 8); | |||||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); | |||||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); | |||||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); | |||||
GI_FLOAT32_t _r15 = GiExtqFloat32(_r14, _r10n, 1); | |||||
GI_FLOAT32_t _r16 = GiExtqFloat32(_r14, _r10n, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r10, _k78910, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r11, _k78910, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r12, _k78910, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r13, _k78910, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r14, _k11121314, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r15, _k11121314, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r16, _k11121314, 2); | |||||
GI_FLOAT32_t _k14151617 = GiLoadFloat32(k2); | |||||
GI_FLOAT32_t _k18192021 = GiLoadFloat32(k2 + 4); | |||||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||||
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); | |||||
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 8); | |||||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); | |||||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); | |||||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); | |||||
GI_FLOAT32_t _r25 = GiExtqFloat32(_r24, _r20n, 1); | |||||
GI_FLOAT32_t _r26 = GiExtqFloat32(_r24, _r20n, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r20, _k14151617, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r21, _k14151617, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r22, _k14151617, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r23, _k14151617, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r24, _k18192021, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r25, _k18192021, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r26, _k18192021, 2); | |||||
GI_FLOAT32_t _k21222324 = GiLoadFloat32(k3); | |||||
GI_FLOAT32_t _k25262728 = GiLoadFloat32(k3 + 4); | |||||
GI_FLOAT32_t _r30 = GiLoadFloat32(r3); | |||||
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); | |||||
GI_FLOAT32_t _r30n = GiLoadFloat32(r3 + 8); | |||||
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); | |||||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); | |||||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); | |||||
GI_FLOAT32_t _r35 = GiExtqFloat32(_r34, _r30n, 1); | |||||
GI_FLOAT32_t _r36 = GiExtqFloat32(_r34, _r30n, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r30, _k21222324, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r31, _k21222324, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r32, _k21222324, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r33, _k21222324, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r34, _k25262728, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r35, _k25262728, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r36, _k25262728, 2); | |||||
GI_FLOAT32_t _k28293031 = GiLoadFloat32(k4); | |||||
GI_FLOAT32_t _k32333435 = GiLoadFloat32(k4 + 4); | |||||
GI_FLOAT32_t _r40 = GiLoadFloat32(r4); | |||||
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); | |||||
GI_FLOAT32_t _r40n = GiLoadFloat32(r4 + 8); | |||||
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); | |||||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); | |||||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); | |||||
GI_FLOAT32_t _r45 = GiExtqFloat32(_r44, _r40n, 1); | |||||
GI_FLOAT32_t _r46 = GiExtqFloat32(_r44, _r40n, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r40, _k28293031, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r41, _k28293031, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r42, _k28293031, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r43, _k28293031, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r44, _k32333435, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r45, _k32333435, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r46, _k32333435, 2); | |||||
GI_FLOAT32_t _k35363738 = GiLoadFloat32(k5); | |||||
GI_FLOAT32_t _k39404142 = GiLoadFloat32(k5 + 4); | |||||
GI_FLOAT32_t _r50 = GiLoadFloat32(r5); | |||||
GI_FLOAT32_t _r54 = GiLoadFloat32(r5 + 4); | |||||
GI_FLOAT32_t _r50n = GiLoadFloat32(r5 + 8); | |||||
GI_FLOAT32_t _r51 = GiExtqFloat32(_r50, _r54, 1); | |||||
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r54, 2); | |||||
GI_FLOAT32_t _r53 = GiExtqFloat32(_r50, _r54, 3); | |||||
GI_FLOAT32_t _r55 = GiExtqFloat32(_r54, _r50n, 1); | |||||
GI_FLOAT32_t _r56 = GiExtqFloat32(_r54, _r50n, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r50, _k35363738, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r51, _k35363738, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r52, _k35363738, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r53, _k35363738, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r54, _k39404142, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r55, _k39404142, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r56, _k39404142, 2); | |||||
GI_FLOAT32_t _k42434445 = GiLoadFloat32(k6); | |||||
GI_FLOAT32_t _k46474849 = | |||||
GiLd1qLaneFloat32(k6 + 4 + 2, GiLoadFloat32LowHalf(k6 + 4), 2); | |||||
GI_FLOAT32_t _r60 = GiLoadFloat32(r6); | |||||
GI_FLOAT32_t _r64 = GiLoadFloat32(r6 + 4); | |||||
GI_FLOAT32_t _r60n = GiLoadFloat32(r6 + 8); | |||||
GI_FLOAT32_t _r61 = GiExtqFloat32(_r60, _r64, 1); | |||||
GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r64, 2); | |||||
GI_FLOAT32_t _r63 = GiExtqFloat32(_r60, _r64, 3); | |||||
GI_FLOAT32_t _r65 = GiExtqFloat32(_r64, _r60n, 1); | |||||
GI_FLOAT32_t _r66 = GiExtqFloat32(_r64, _r60n, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r60, _k42434445, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r61, _k42434445, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r62, _k42434445, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r63, _k42434445, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r64, _k46474849, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r65, _k46474849, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r66, _k46474849, 2); | |||||
GiStoreFloat32(outptr, _sum); | |||||
r0 += 4; | |||||
r1 += 4; | |||||
r2 += 4; | |||||
r3 += 4; | |||||
r4 += 4; | |||||
r5 += 4; | |||||
r6 += 4; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
r5 += tail_step; | |||||
r6 += tail_step; | |||||
} | |||||
filter += 49; | |||||
} | |||||
} | |||||
#include "src/common/simd_macro/epilogue.h" | |||||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -13,7 +13,7 @@ | |||||
#include <cstddef> | #include <cstddef> | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace fp32 { | namespace fp32 { | ||||
namespace conv_stride1 { | namespace conv_stride1 { | ||||
@@ -31,7 +31,7 @@ void do_conv_7x7_stride1( | |||||
size_t OH, size_t OW, size_t IC); | size_t OH, size_t OW, size_t IC); | ||||
} // namespace conv_stride1 | } // namespace conv_stride1 | ||||
} // namespace fp32 | } // namespace fp32 | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -0,0 +1,503 @@ | |||||
/** | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include <algorithm> | |||||
#include "./do_conv_stride2.h" | |||||
#include "midout.h" | |||||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
MIDOUT_DECL(megdnn_fallback_conv_bias_f32_convs2) | |||||
using namespace megdnn; | |||||
using namespace fallback; | |||||
using namespace fp32; | |||||
using namespace conv_stride2; | |||||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
void conv_stride2::do_conv_2x2_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* k0 = filter; | |||||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||||
rep(h, OH) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
GI_FLOAT32_t _outp = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0); | |||||
GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6 | |||||
GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7 | |||||
_outp = GiSimdFmaLane(_outp, _r00, _k0123, 0); | |||||
_outp = GiSimdFmaLane(_outp, _r01, _k0123, 1); | |||||
GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1); | |||||
GI_FLOAT32_t _r10 = _r1.val[0]; | |||||
GI_FLOAT32_t _r11 = _r1.val[1]; | |||||
_outp = GiSimdFmaLane(_outp, _r10, _k0123, 2); | |||||
_outp = GiSimdFmaLane(_outp, _r11, _k0123, 3); | |||||
GiStoreFloat32(outptr, _outp); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
} | |||||
filter += 4; | |||||
} | |||||
} | |||||
void conv_stride2::do_conv_3x3_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 3; | |||||
const float* k2 = filter + 5; | |||||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||||
GI_FLOAT32_t _k3456 = GiLoadFloat32(k1); | |||||
GI_FLOAT32_t _k5678 = GiLoadFloat32(k2); | |||||
GI_FLOAT32_t _k6789 = GiExtqFloat32(_k5678, _k5678, 1); | |||||
rep(h, OH) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
GI_FLOAT32_t _outp = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0); | |||||
GI_FLOAT32_V2_t _r0n = GiLd2qFloat32(r0 + 8); | |||||
GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6 | |||||
GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7 | |||||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0n.val[0], 1); // 2 4 6 8 | |||||
_outp = GiSimdFmaLane(_outp, _r00, _k0123, 0); | |||||
_outp = GiSimdFmaLane(_outp, _r01, _k0123, 1); | |||||
_outp = GiSimdFmaLane(_outp, _r02, _k0123, 2); | |||||
GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1); | |||||
GI_FLOAT32_V2_t _r1n = GiLd2qFloat32(r1 + 8); | |||||
GI_FLOAT32_t _r10 = _r1.val[0]; | |||||
GI_FLOAT32_t _r11 = _r1.val[1]; | |||||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1n.val[0], 1); | |||||
_outp = GiSimdFmaLane(_outp, _r10, _k3456, 0); | |||||
_outp = GiSimdFmaLane(_outp, _r11, _k3456, 1); | |||||
_outp = GiSimdFmaLane(_outp, _r12, _k3456, 2); | |||||
GI_FLOAT32_V2_t _r2 = GiLd2qFloat32(r2); | |||||
GI_FLOAT32_V2_t _r2n = GiLd2qFloat32(r2 + 8); | |||||
GI_FLOAT32_t _r20 = _r2.val[0]; | |||||
GI_FLOAT32_t _r21 = _r2.val[1]; | |||||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2n.val[0], 1); | |||||
_outp = GiSimdFmaLane(_outp, _r20, _k6789, 0); | |||||
_outp = GiSimdFmaLane(_outp, _r21, _k6789, 1); | |||||
_outp = GiSimdFmaLane(_outp, _r22, _k6789, 2); | |||||
GiStoreFloat32(outptr, _outp); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
r2 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
} | |||||
filter += 9; | |||||
} | |||||
} | |||||
void conv_stride2::do_conv_5x5_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
GI_FLOAT32_t _k0123 = GiLoadFloat32(filter); | |||||
GI_FLOAT32_t _k4567 = GiLoadFloat32(filter + 4); | |||||
GI_FLOAT32_t _k891011 = GiLoadFloat32(filter + 8); | |||||
GI_FLOAT32_t _k12131415 = GiLoadFloat32(filter + 12); | |||||
GI_FLOAT32_t _k16171819 = GiLoadFloat32(filter + 16); | |||||
GI_FLOAT32_t _k20212223 = GiLoadFloat32(filter + 20); | |||||
GI_FLOAT32_t _k24242424 = GiBroadcastFloat32(filter[24]); | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0); | |||||
GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8); | |||||
GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||||
GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||||
GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||||
GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8 | |||||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9 | |||||
GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10 | |||||
GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1); | |||||
GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8); | |||||
GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0]; | |||||
GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1]; | |||||
GI_FLOAT32_t _r10 = _r10_02461357.val[0]; | |||||
GI_FLOAT32_t _r11 = _r10_02461357.val[1]; | |||||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1); | |||||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1); | |||||
GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2); | |||||
GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2); | |||||
GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8); | |||||
GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0]; | |||||
GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1]; | |||||
GI_FLOAT32_t _r20 = _r20_02461357.val[0]; | |||||
GI_FLOAT32_t _r21 = _r20_02461357.val[1]; | |||||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1); | |||||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1); | |||||
GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2); | |||||
GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3); | |||||
GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8); | |||||
GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0]; | |||||
GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1]; | |||||
GI_FLOAT32_t _r30 = _r30_02461357.val[0]; | |||||
GI_FLOAT32_t _r31 = _r30_02461357.val[1]; | |||||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1); | |||||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1); | |||||
GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2); | |||||
GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4); | |||||
GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8); | |||||
GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0]; | |||||
GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1]; | |||||
GI_FLOAT32_t _r40 = _r40_02461357.val[0]; | |||||
GI_FLOAT32_t _r41 = _r40_02461357.val[1]; | |||||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1); | |||||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1); | |||||
GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); | |||||
GiStoreFloat32(outptr, _sum); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
r2 += 8; | |||||
r3 += 8; | |||||
r4 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
} | |||||
filter += 25; | |||||
} | |||||
} | |||||
void conv_stride2::do_conv_7x7_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | |||||
rep(ic, IC) { | |||||
const float* src_ptr = src + IW * IH * ic; | |||||
float* outptr = dst; | |||||
const float* r0 = src_ptr; | |||||
const float* r1 = src_ptr + IW; | |||||
const float* r2 = src_ptr + IW * 2; | |||||
const float* r3 = src_ptr + IW * 3; | |||||
const float* r4 = src_ptr + IW * 4; | |||||
const float* r5 = src_ptr + IW * 5; | |||||
const float* r6 = src_ptr + IW * 6; | |||||
const float* k0 = filter; | |||||
const float* k1 = filter + 7; | |||||
const float* k2 = filter + 14; | |||||
const float* k3 = filter + 21; | |||||
const float* k4 = filter + 28; | |||||
const float* k5 = filter + 35; | |||||
const float* k6 = filter + 42; | |||||
for (size_t i = 0; i < OH; i++) { | |||||
int nn = OW >> 2; | |||||
rep(i, nn) { | |||||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||||
GI_FLOAT32_t _k4567 = GiLoadFloat32(k0 + 4); | |||||
GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0); | |||||
GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8); | |||||
GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||||
GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||||
GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||||
GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8 | |||||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9 | |||||
GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10 | |||||
GI_FLOAT32_t _r05 = GiExtqFloat32(_r01, _r0_9111315, 2); // 5 7 9 11 | |||||
GI_FLOAT32_t _r06 = GiExtqFloat32(_r00, _r0_8101214, 3); // 6 8 10 12 | |||||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r05, _k4567, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r06, _k4567, 2); | |||||
GI_FLOAT32_t _k78910 = GiLoadFloat32(k1); | |||||
GI_FLOAT32_t _k11121314 = GiLoadFloat32(k1 + 4); | |||||
GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1); | |||||
GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8); | |||||
GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0]; | |||||
GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1]; | |||||
GI_FLOAT32_t _r10 = _r10_02461357.val[0]; | |||||
GI_FLOAT32_t _r11 = _r10_02461357.val[1]; | |||||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1); | |||||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1); | |||||
GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2); | |||||
GI_FLOAT32_t _r15 = GiExtqFloat32(_r11, _r1_9111315, 2); | |||||
GI_FLOAT32_t _r16 = GiExtqFloat32(_r10, _r1_8101214, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r10, _k78910, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r11, _k78910, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r12, _k78910, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r13, _k78910, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r14, _k11121314, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r15, _k11121314, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r16, _k11121314, 2); | |||||
GI_FLOAT32_t _k14151617 = GiLoadFloat32(k2); | |||||
GI_FLOAT32_t _k18192021 = GiLoadFloat32(k2 + 4); | |||||
GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2); | |||||
GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8); | |||||
GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0]; | |||||
GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1]; | |||||
GI_FLOAT32_t _r20 = _r20_02461357.val[0]; | |||||
GI_FLOAT32_t _r21 = _r20_02461357.val[1]; | |||||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1); | |||||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1); | |||||
GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2); | |||||
GI_FLOAT32_t _r25 = GiExtqFloat32(_r21, _r2_9111315, 2); | |||||
GI_FLOAT32_t _r26 = GiExtqFloat32(_r20, _r2_8101214, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r20, _k14151617, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r21, _k14151617, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r22, _k14151617, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r23, _k14151617, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r24, _k18192021, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r25, _k18192021, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r26, _k18192021, 2); | |||||
GI_FLOAT32_t _k21222324 = GiLoadFloat32(k3); | |||||
GI_FLOAT32_t _k25262728 = GiLoadFloat32(k3 + 4); | |||||
GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3); | |||||
GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8); | |||||
GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0]; | |||||
GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1]; | |||||
GI_FLOAT32_t _r30 = _r30_02461357.val[0]; | |||||
GI_FLOAT32_t _r31 = _r30_02461357.val[1]; | |||||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1); | |||||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1); | |||||
GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2); | |||||
GI_FLOAT32_t _r35 = GiExtqFloat32(_r31, _r3_9111315, 2); | |||||
GI_FLOAT32_t _r36 = GiExtqFloat32(_r30, _r3_8101214, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r30, _k21222324, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r31, _k21222324, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r32, _k21222324, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r33, _k21222324, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r34, _k25262728, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r35, _k25262728, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r36, _k25262728, 2); | |||||
GI_FLOAT32_t _k28293031 = GiLoadFloat32(k4); | |||||
GI_FLOAT32_t _k32333435 = GiLoadFloat32(k4 + 4); | |||||
GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4); | |||||
GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8); | |||||
GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0]; | |||||
GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1]; | |||||
GI_FLOAT32_t _r40 = _r40_02461357.val[0]; | |||||
GI_FLOAT32_t _r41 = _r40_02461357.val[1]; | |||||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1); | |||||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1); | |||||
GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2); | |||||
GI_FLOAT32_t _r45 = GiExtqFloat32(_r41, _r4_9111315, 2); | |||||
GI_FLOAT32_t _r46 = GiExtqFloat32(_r40, _r4_8101214, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r40, _k28293031, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r41, _k28293031, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r42, _k28293031, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r43, _k28293031, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r44, _k32333435, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r45, _k32333435, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r46, _k32333435, 2); | |||||
GI_FLOAT32_t _k35363738 = GiLoadFloat32(k5); | |||||
GI_FLOAT32_t _k39404142 = GiLoadFloat32(k5 + 4); | |||||
GI_FLOAT32_V2_t _r50_02461357 = GiLd2qFloat32(r5); | |||||
GI_FLOAT32_V2_t _r50nx2 = GiLd2qFloat32(r5 + 8); | |||||
GI_FLOAT32_t _r5_8101214 = _r50nx2.val[0]; | |||||
GI_FLOAT32_t _r5_9111315 = _r50nx2.val[1]; | |||||
GI_FLOAT32_t _r50 = _r50_02461357.val[0]; | |||||
GI_FLOAT32_t _r51 = _r50_02461357.val[1]; | |||||
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r5_8101214, 1); | |||||
GI_FLOAT32_t _r53 = GiExtqFloat32(_r51, _r5_9111315, 1); | |||||
GI_FLOAT32_t _r54 = GiExtqFloat32(_r50, _r5_8101214, 2); | |||||
GI_FLOAT32_t _r55 = GiExtqFloat32(_r51, _r5_9111315, 2); | |||||
GI_FLOAT32_t _r56 = GiExtqFloat32(_r50, _r5_8101214, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r50, _k35363738, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r51, _k35363738, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r52, _k35363738, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r53, _k35363738, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r54, _k39404142, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r55, _k39404142, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r56, _k39404142, 2); | |||||
GI_FLOAT32_t _k42434445 = GiLoadFloat32(k6); | |||||
GI_FLOAT32_t _k45464748 = GiLoadFloat32(k6 + 3); | |||||
GI_FLOAT32_V2_t _r60_02461357 = GiLd2qFloat32(r6); | |||||
GI_FLOAT32_V2_t _r60nx2 = GiLd2qFloat32(r6 + 8); | |||||
GI_FLOAT32_t _r6_8101214 = _r60nx2.val[0]; | |||||
GI_FLOAT32_t _r6_9111315 = _r60nx2.val[1]; | |||||
GI_FLOAT32_t _r60 = _r60_02461357.val[0]; | |||||
GI_FLOAT32_t _r61 = _r60_02461357.val[1]; | |||||
GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r6_8101214, 1); | |||||
GI_FLOAT32_t _r63 = GiExtqFloat32(_r61, _r6_9111315, 1); | |||||
GI_FLOAT32_t _r64 = GiExtqFloat32(_r60, _r6_8101214, 2); | |||||
GI_FLOAT32_t _r65 = GiExtqFloat32(_r61, _r6_9111315, 2); | |||||
GI_FLOAT32_t _r66 = GiExtqFloat32(_r60, _r6_8101214, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r60, _k42434445, 0); | |||||
_sum = GiSimdFmaLane(_sum, _r61, _k42434445, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r62, _k42434445, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r63, _k42434445, 3); | |||||
_sum = GiSimdFmaLane(_sum, _r64, _k45464748, 1); | |||||
_sum = GiSimdFmaLane(_sum, _r65, _k45464748, 2); | |||||
_sum = GiSimdFmaLane(_sum, _r66, _k45464748, 3); | |||||
GiStoreFloat32(outptr, _sum); | |||||
r0 += 8; | |||||
r1 += 8; | |||||
r2 += 8; | |||||
r3 += 8; | |||||
r4 += 8; | |||||
r5 += 8; | |||||
r6 += 8; | |||||
outptr += 4; | |||||
} | |||||
r0 += tail_step; | |||||
r1 += tail_step; | |||||
r2 += tail_step; | |||||
r3 += tail_step; | |||||
r4 += tail_step; | |||||
r5 += tail_step; | |||||
r6 += tail_step; | |||||
} | |||||
filter += 49; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -13,7 +13,7 @@ | |||||
#include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace fp32 { | namespace fp32 { | ||||
namespace conv_stride2 { | namespace conv_stride2 { | ||||
void do_conv_2x2_stride2( | void do_conv_2x2_stride2( | ||||
@@ -30,7 +30,7 @@ void do_conv_7x7_stride2( | |||||
size_t OH, size_t OW, size_t IC); | size_t OH, size_t OW, size_t IC); | ||||
} // namespace conv_stride2 | } // namespace conv_stride2 | ||||
} // namespace fp32 | } // namespace fp32 | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -11,21 +11,21 @@ | |||||
*/ | */ | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/block_helper.h" | |||||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||||
#include "src/fallback/conv_bias/gi/block_helper.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
using conv_fun = std::function<void( | using conv_fun = std::function<void( | ||||
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, | const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, | ||||
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, | const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, | ||||
const CpuNDRange& ncb_range)>; | const CpuNDRange& ncb_range)>; | ||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1) | |||||
MIDOUT_DECL(megdnn_fallback_conv_bias_fp32_nchw44_stride1) | |||||
namespace { | namespace { | ||||
static inline size_t get_perthread_cache_bytes( | static inline size_t get_perthread_cache_bytes( | ||||
@@ -156,7 +156,7 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable( | |||||
size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN( | MIDOUT_BEGIN( | ||||
megdnn_arm_common_conv_bias_fp32_nchw44_stride1, | |||||
megdnn_fallback_conv_bias_fp32_nchw44_stride1, | |||||
midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) { | midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) { | ||||
return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
} | } | ||||
@@ -175,7 +175,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_k | |||||
// shape runtime | // shape runtime | ||||
#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ | #define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ | ||||
MIDOUT_BEGIN( \ | MIDOUT_BEGIN( \ | ||||
megdnn_arm_common_conv_bias_fp32_nchw44_stride1, \ | |||||
megdnn_fallback_conv_bias_fp32_nchw44_stride1, \ | |||||
midout_iv(#filter #bias_mode #stride #op##_hash)) { \ | midout_iv(#filter #bias_mode #stride #op##_hash)) { \ | ||||
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \ | do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \ | ||||
} \ | } \ |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_stride1_nchw44_kern.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -10,10 +10,10 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace conv_bias { | namespace conv_bias { | ||||
template <BiasMode bias_mode, typename Op, int filter_size, int stride> | template <BiasMode bias_mode, typename Op, int filter_size, int stride> | ||||
@@ -28,5 +28,5 @@ void pack_src_fp32_nchw44( | |||||
const int pad_top, const int pad_bottom, const int ic, const int ic_stride); | const int pad_top, const int pad_bottom, const int ic, const int ic_stride); | ||||
} // namespace conv_bias | } // namespace conv_bias | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn |
@@ -1,6 +1,6 @@ | |||||
/** | /** | ||||
* \file | * \file | ||||
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp | |||||
dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -12,21 +12,21 @@ | |||||
*/ | */ | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/common/nchw_nchwxx_valid.h" | #include "src/common/nchw_nchwxx_valid.h" | ||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
using conv_fun = std::function<void( | using conv_fun = std::function<void( | ||||
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, | const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, | ||||
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, | const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, | ||||
const CpuNDRange& ncb_range)>; | const CpuNDRange& ncb_range)>; | ||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44) | |||||
MIDOUT_DECL(megdnn_fallback_conv_bias_fp32_nchw_nchw44) | |||||
namespace { | namespace { | ||||
static inline int block_helper( | static inline int block_helper( | ||||
const int nthread, const int amount, const int per_unit_bytes) { | const int nthread, const int amount, const int per_unit_bytes) { | ||||
@@ -195,7 +195,7 @@ bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( | |||||
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN( | MIDOUT_BEGIN( | ||||
megdnn_arm_common_conv_bias_fp32_nchw_nchw44, | |||||
megdnn_fallback_conv_bias_fp32_nchw_nchw44, | |||||
midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) { | midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) { | ||||
return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
} | } | ||||
@@ -214,7 +214,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectNCHWNCHW44:: | |||||
// shape runtime | // shape runtime | ||||
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ | #define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ | ||||
MIDOUT_BEGIN( \ | MIDOUT_BEGIN( \ | ||||
megdnn_arm_common_conv_bias_fp32_nchw_nchw44, \ | |||||
megdnn_fallback_conv_bias_fp32_nchw_nchw44, \ | |||||
midout_iv(#stride #filter #bias_mode #op##_hash)) { \ | midout_iv(#stride #filter #bias_mode #op##_hash)) { \ | ||||
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \ | do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \ | ||||
} \ | } \ |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -11,15 +11,14 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/gi/intrinsic_helper.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace fp32_direct_nchw_nchw44 { | namespace fp32_direct_nchw_nchw44 { | ||||
static inline void pack_weight_fp32_nchw_nchw44( | static inline void pack_weight_fp32_nchw_nchw44( | ||||
@@ -34,8 +33,8 @@ static inline void pack_weight_fp32_nchw_nchw44( | |||||
for (int kh_idx = 0; kh_idx < kh; ++kh_idx) { | for (int kh_idx = 0; kh_idx < kh; ++kh_idx) { | ||||
for (int kw_idx = 0; kw_idx < kw; ++kw_idx) { | for (int kw_idx = 0; kw_idx < kw; ++kw_idx) { | ||||
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { | for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { | ||||
float32x4_t vsrc = vld1q_f32(in_ptr_oc); | |||||
vst1q_f32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc); | |||||
GI_FLOAT32_t vsrc = GiLoadFloat32(in_ptr_oc); | |||||
GiStoreFloat32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc); | |||||
in_ptr_oc += oc_step; | in_ptr_oc += oc_step; | ||||
} | } | ||||
dst_ptr_oc += oc_step; | dst_ptr_oc += oc_step; | ||||
@@ -51,6 +50,6 @@ void conv_direct_fp32_nchw_nchw44( | |||||
const int, const int); | const int, const int); | ||||
} // namespace fp32_direct_nchw_nchw44 | } // namespace fp32_direct_nchw_nchw44 | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/filter_transform.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -11,14 +11,13 @@ | |||||
#pragma once | #pragma once | ||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> | template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> | ||||
struct FilterTransform6X3 { | struct FilterTransform6X3 { | ||||
@@ -65,8 +64,8 @@ struct FilterTransform6X3 { | |||||
Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3); | Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3); | ||||
Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1); | Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1); | ||||
float32x4_t zeros = vdupq_n_f32(0.0f); | |||||
g2.value = vextq_f32(g2.value, zeros, 1); | |||||
GI_FLOAT32_t zeros = GiZeroFloat32(); | |||||
g2.value = GiExtqFloat32(g2.value, zeros, 1); | |||||
#define cb(i) Vector<float, 4> wd##i; | #define cb(i) Vector<float, 4> wd##i; | ||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
@@ -106,7 +105,6 @@ struct FilterTransform6X3 { | |||||
} | } | ||||
#else | #else | ||||
#define cb(i) \ | #define cb(i) \ | ||||
do { \ | do { \ | ||||
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ | mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ | ||||
@@ -128,7 +126,7 @@ struct FilterTransform6X3 { | |||||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ | mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ | ||||
mid_buf1 += 8; \ | mid_buf1 += 8; \ | ||||
} while (0); | } while (0); | ||||
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) | |||||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value) | |||||
float* mid_buf1 = transform_mid_buf; | float* mid_buf1 = transform_mid_buf; | ||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
@@ -154,7 +152,7 @@ struct FilterTransform6X3 { | |||||
#undef FILTER_TRANSFORM | #undef FILTER_TRANSFORM | ||||
#undef GET_VECTOR_ELEM | #undef GET_VECTOR_ELEM | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -0,0 +1,196 @@ | |||||
/** | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/helper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
namespace megdnn { | |||||
namespace fallback { | |||||
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | |||||
GI_FLOAT32_V2_t a0, a1; | |||||
a0.val[0] = GiLoadFloat32(src + 0 * lda); | |||||
a0.val[1] = GiLoadFloat32(src + 1 * lda); | |||||
a1.val[0] = GiLoadFloat32(src + 2 * lda); | |||||
a1.val[1] = GiLoadFloat32(src + 3 * lda); | |||||
GI_FLOAT32_V2_t b0 = GiZipqFloat32(a0.val[0], a1.val[0]); | |||||
GI_FLOAT32_V2_t b1 = GiZipqFloat32(a0.val[1], a1.val[1]); | |||||
GI_FLOAT32_V2_t c0 = GiZipqFloat32(b0.val[0], b1.val[0]); | |||||
GI_FLOAT32_V2_t c1 = GiZipqFloat32(b0.val[1], b1.val[1]); | |||||
GiStoreFloat32(dst + 0 * ldb, c0.val[0]); | |||||
GiStoreFloat32(dst + 1 * ldb, c0.val[1]); | |||||
GiStoreFloat32(dst + 2 * ldb, c1.val[0]); | |||||
GiStoreFloat32(dst + 3 * ldb, c1.val[1]); | |||||
} | |||||
} // namespace fallback | |||||
} // namespace megdnn | |||||
#define MATRIX_MUL4x4(sum, a, b) \ | |||||
sum##0 = GiMlaqLowLaneFloat32(sum##0, b##0, a##0, 0); \ | |||||
sum##0 = GiMlaqLowLaneFloat32(sum##0, b##1, a##0, 1); \ | |||||
sum##0 = GiMlaqHighLaneFloat32(sum##0, b##2, a##0, 2); \ | |||||
sum##0 = GiMlaqHighLaneFloat32(sum##0, b##3, a##0, 3); \ | |||||
sum##1 = GiMlaqLowLaneFloat32(sum##1, b##0, a##1, 0); \ | |||||
sum##1 = GiMlaqLowLaneFloat32(sum##1, b##1, a##1, 1); \ | |||||
sum##1 = GiMlaqHighLaneFloat32(sum##1, b##2, a##1, 2); \ | |||||
sum##1 = GiMlaqHighLaneFloat32(sum##1, b##3, a##1, 3); \ | |||||
sum##2 = GiMlaqLowLaneFloat32(sum##2, b##0, a##2, 0); \ | |||||
sum##2 = GiMlaqLowLaneFloat32(sum##2, b##1, a##2, 1); \ | |||||
sum##2 = GiMlaqHighLaneFloat32(sum##2, b##2, a##2, 2); \ | |||||
sum##2 = GiMlaqHighLaneFloat32(sum##2, b##3, a##2, 3); \ | |||||
sum##3 = GiMlaqLowLaneFloat32(sum##3, b##0, a##3, 0); \ | |||||
sum##3 = GiMlaqLowLaneFloat32(sum##3, b##1, a##3, 1); \ | |||||
sum##3 = GiMlaqHighLaneFloat32(sum##3, b##2, a##3, 2); \ | |||||
sum##3 = GiMlaqHighLaneFloat32(sum##3, b##3, a##3, 3); | |||||
#define CONCAT(a, idx) a##idx | |||||
#if MEGDNN_AARCH64 | |||||
//! ret and a are type Vector<float, 8> | |||||
#define TRANSPOSE_8x8(a, ret) \ | |||||
do { \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \ | |||||
auto b4 = GiZipqFloat32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \ | |||||
auto b5 = GiZipqFloat32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \ | |||||
auto b6 = GiZipqFloat32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \ | |||||
auto b7 = GiZipqFloat32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \ | |||||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||||
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||||
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||||
CONCAT(ret, 4).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 4).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||||
CONCAT(ret, 5).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 5).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||||
CONCAT(ret, 6).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 6).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||||
CONCAT(ret, 7).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 7).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||||
} while (0); | |||||
#define TRANSPOSE_8x3(a, ret) \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); | |||||
#define TRANSPOSE_8x4(a, ret) \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||||
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||||
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[1]))); | |||||
#else | |||||
#define TRANSPOSE_8x4(a, ret) \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||||
CONCAT(ret, 0).value.val[0] = \ | |||||
GiCombineFloat32(GiGetLowFloat32(b0.val[0]), GiGetLowFloat32(b1.val[0])); \ | |||||
CONCAT(ret, 1).value.val[0] = GiCombineFloat32( \ | |||||
GiGetHighFloat32(b0.val[0]), GiGetHighFloat32(b1.val[0])); \ | |||||
CONCAT(ret, 2).value.val[0] = \ | |||||
GiCombineFloat32(GiGetLowFloat32(b0.val[1]), GiGetLowFloat32(b1.val[1])); \ | |||||
CONCAT(ret, 3).value.val[0] = GiCombineFloat32( \ | |||||
GiGetHighFloat32(b0.val[1]), GiGetHighFloat32(b1.val[1])); \ | |||||
CONCAT(ret, 0).value.val[1] = \ | |||||
GiCombineFloat32(GiGetLowFloat32(b2.val[0]), GiGetLowFloat32(b3.val[0])); \ | |||||
CONCAT(ret, 1).value.val[1] = GiCombineFloat32( \ | |||||
GiGetHighFloat32(b2.val[0]), GiGetHighFloat32(b3.val[0])); \ | |||||
CONCAT(ret, 2).value.val[1] = \ | |||||
GiCombineFloat32(GiGetLowFloat32(b2.val[1]), GiGetLowFloat32(b3.val[1])); \ | |||||
CONCAT(ret, 3).value.val[1] = GiCombineFloat32( \ | |||||
GiGetHighFloat32(b2.val[1]), GiGetHighFloat32(b3.val[1])); | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy.h | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -11,14 +11,15 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, winograd_2x3_4x4_f) | |||||
MEGDNN_REG_WINOGRAD_STRATEGY( | |||||
float, float, float, float, 2, 3, 4, 4, winograd_gi_2x3_4x4_f) | |||||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f) | MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f) | ||||
@@ -37,7 +38,7 @@ MEGDNN_REG_WINOGRAD_STRATEGY( | |||||
MEGDNN_REG_WINOGRAD_STRATEGY( | MEGDNN_REG_WINOGRAD_STRATEGY( | ||||
float, float, float, float, 7, 3, 4, 4, winograd_F73_mk4_f_nchw44) | float, float, float, float, 7, 3, 4, 4, winograd_F73_mk4_f_nchw44) | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,22 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | #include "src/naive/matrix_mul/matrix_mul_helper.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F23) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
struct InputTransform2X3 { | struct InputTransform2X3 { | ||||
@@ -40,15 +39,15 @@ struct InputTransform2X3 { | |||||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | ||||
for (size_t ico = 0; ico < 4; ++ico) { | for (size_t ico = 0; ico < 4; ++ico) { | ||||
if (ic + ico < IC) { | if (ic + ico < IC) { | ||||
auto v0 = vld1q_f32(input_ptr); | |||||
auto v1 = vld1q_f32(input_ptr + IW); | |||||
auto v2 = vld1q_f32(input_ptr + IW * 2); | |||||
auto v3 = vld1q_f32(input_ptr + IW * 3); | |||||
vst1q_f32(patch + ico * 4 * alpha + 0 * 4, v0); | |||||
vst1q_f32(patch + ico * 4 * alpha + 1 * 4, v1); | |||||
vst1q_f32(patch + ico * 4 * alpha + 2 * 4, v2); | |||||
vst1q_f32(patch + ico * 4 * alpha + 3 * 4, v3); | |||||
auto v0 = GiLoadFloat32(input_ptr); | |||||
auto v1 = GiLoadFloat32(input_ptr + IW); | |||||
auto v2 = GiLoadFloat32(input_ptr + IW * 2); | |||||
auto v3 = GiLoadFloat32(input_ptr + IW * 3); | |||||
GiStoreFloat32(patch + ico * 4 * alpha + 0 * 4, v0); | |||||
GiStoreFloat32(patch + ico * 4 * alpha + 1 * 4, v1); | |||||
GiStoreFloat32(patch + ico * 4 * alpha + 2 * 4, v2); | |||||
GiStoreFloat32(patch + ico * 4 * alpha + 3 * 4, v3); | |||||
input_ptr += IH * IW; | input_ptr += IH * IW; | ||||
} | } | ||||
} | } | ||||
@@ -197,18 +196,18 @@ struct OutputTransform2X3 { | |||||
} // namespace | } // namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | |||||
void winograd_2x3_4x4_f::filter( | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_gi_2x3_4x4_f) | |||||
void winograd_gi_2x3_4x4_f::filter( | |||||
const float* filter, float* filter_transform_buf, float* transform_mid_buf, | const float* filter, float* filter_transform_buf, float* transform_mid_buf, | ||||
size_t OC, size_t IC, size_t oc_start, size_t oc_end) { | size_t OC, size_t IC, size_t oc_start, size_t oc_end) { | ||||
constexpr int alpha = 2 + 3 - 1; | constexpr int alpha = 2 + 3 - 1; | ||||
//! G * g * GT | //! G * g * GT | ||||
float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, | |||||
GI_FLOAT32_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, | |||||
g3{0, 0, 1, 0}; | g3{0, 0, 1, 0}; | ||||
float32x4_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1}, | |||||
GI_FLOAT32_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1}, | |||||
gt3{0, 0, 0, 0}; | gt3{0, 0, 0, 0}; | ||||
size_t OCB = OC / 4; | size_t OCB = OC / 4; | ||||
size_t ICB = IC / 4; | size_t ICB = IC / 4; | ||||
@@ -225,33 +224,33 @@ void winograd_2x3_4x4_f::filter( | |||||
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 | //! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 | ||||
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 | //! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 | ||||
//! 0 0 1 0 0 0 0 0 0 0 0 0 | //! 0 0 1 0 0 0 0 0 0 0 0 0 | ||||
float32x4_t vf0 = vld1q_f32(filter_ptr); | |||||
float32x4_t vf1 = vld1q_f32(filter_ptr + 4); | |||||
float32x4_t vf2 = vdupq_n_f32(filter_ptr[8]); | |||||
float32x4_t v3(vdupq_n_f32(0)); | |||||
auto vtmp = vextq_f32(vf1, vf2, 2); | |||||
vtmp = vsetq_lane_f32(0, vtmp, 3); | |||||
float32x4_t v2(vtmp); | |||||
vtmp = vextq_f32(vf0, vf1, 3); | |||||
vtmp = vsetq_lane_f32(0, vtmp, 3); | |||||
float32x4_t v1(vtmp); | |||||
vtmp = vsetq_lane_f32(0, vf0, 3); | |||||
float32x4_t v0(vtmp); | |||||
float32x4_t vsum0 = vdupq_n_f32(0), vsum1 = vdupq_n_f32(0), | |||||
vsum2 = vdupq_n_f32(0), vsum3 = vdupq_n_f32(0); | |||||
GI_FLOAT32_t vf0 = GiLoadFloat32(filter_ptr); | |||||
GI_FLOAT32_t vf1 = GiLoadFloat32(filter_ptr + 4); | |||||
GI_FLOAT32_t vf2 = GiBroadcastFloat32(filter_ptr[8]); | |||||
GI_FLOAT32_t v3(GiBroadcastFloat32(0)); | |||||
auto vtmp = GiExtqFloat32(vf1, vf2, 2); | |||||
vtmp = GiSetqLaneFloat32(0, vtmp, 3); | |||||
GI_FLOAT32_t v2(vtmp); | |||||
vtmp = GiExtqFloat32(vf0, vf1, 3); | |||||
vtmp = GiSetqLaneFloat32(0, vtmp, 3); | |||||
GI_FLOAT32_t v1(vtmp); | |||||
vtmp = GiSetqLaneFloat32(0, vf0, 3); | |||||
GI_FLOAT32_t v0(vtmp); | |||||
GI_FLOAT32_t vsum0 = GiBroadcastFloat32(0), vsum1 = GiBroadcastFloat32(0), | |||||
vsum2 = GiBroadcastFloat32(0), vsum3 = GiBroadcastFloat32(0); | |||||
MATRIX_MUL4x4(vsum, g, v); | MATRIX_MUL4x4(vsum, g, v); | ||||
float32x4_t vres0 = vdupq_n_f32(0), vres1 = vdupq_n_f32(0), | |||||
vres2 = vdupq_n_f32(0), vres3 = vdupq_n_f32(0); | |||||
GI_FLOAT32_t vres0 = GiBroadcastFloat32(0), vres1 = GiBroadcastFloat32(0), | |||||
vres2 = GiBroadcastFloat32(0), vres3 = GiBroadcastFloat32(0); | |||||
MATRIX_MUL4x4(vres, vsum, gt); | MATRIX_MUL4x4(vres, vsum, gt); | ||||
vst1q_f32(transform_mid_buf, vres0); | |||||
vst1q_f32(transform_mid_buf + 4, vres1); | |||||
vst1q_f32(transform_mid_buf + 8, vres2); | |||||
vst1q_f32(transform_mid_buf + 12, vres3); | |||||
GiStoreFloat32(transform_mid_buf, vres0); | |||||
GiStoreFloat32(transform_mid_buf + 4, vres1); | |||||
GiStoreFloat32(transform_mid_buf + 8, vres2); | |||||
GiStoreFloat32(transform_mid_buf + 12, vres3); | |||||
size_t ocb = oc / 4; | size_t ocb = oc / 4; | ||||
size_t oc4 = oc % 4; | size_t oc4 = oc % 4; | ||||
@@ -266,7 +265,7 @@ void winograd_2x3_4x4_f::filter( | |||||
} | } | ||||
} | } | ||||
void winograd_2x3_4x4_f::input( | |||||
void winograd_gi_2x3_4x4_f::input( | |||||
const float* input, float* input_transform_buf, float* transform_mid_buf, | const float* input, float* input_transform_buf, float* transform_mid_buf, | ||||
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, | size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, | ||||
size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
@@ -304,7 +303,7 @@ void winograd_2x3_4x4_f::input( | |||||
} | } | ||||
} | } | ||||
void winograd_2x3_4x4_f::output( | |||||
void winograd_gi_2x3_4x4_f::output( | |||||
const float* output_transform_buf, const float* bias, float* output, | const float* output_transform_buf, const float* bias, float* output, | ||||
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, | float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, | ||||
size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, | size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, | ||||
@@ -322,8 +321,8 @@ void winograd_2x3_4x4_f::output( | |||||
auto nw = index % units_w; | auto nw = index % units_w; | ||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | ||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F23, cb, float, float, bmode, | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_fp32_F23, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | ||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | ||||
nr_units_in_tile, src_dtype, dst_dtype); | nr_units_in_tile, src_dtype, dst_dtype); | ||||
@@ -333,7 +332,7 @@ void winograd_2x3_4x4_f::output( | |||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,22 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | #include "src/naive/matrix_mul/matrix_mul_helper.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F45) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F45) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
struct FilterTransform4X5 { | struct FilterTransform4X5 { | ||||
@@ -126,9 +125,9 @@ struct FilterTransform4X5 { | |||||
#undef cb | #undef cb | ||||
FILTER_TRANSFORM(g, Gg) | FILTER_TRANSFORM(g, Gg) | ||||
float32x4x2_t vgr; | |||||
float32x4_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3}; | |||||
float32x4_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7}; | |||||
GI_FLOAT32_V2_t vgr; | |||||
GI_FLOAT32_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3}; | |||||
GI_FLOAT32_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7}; | |||||
vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3}; | vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3}; | ||||
vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7}; | vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7}; | ||||
Vector<float, 8> Ggt4(vgr); | Vector<float, 8> Ggt4(vgr); | ||||
@@ -167,8 +166,10 @@ struct InputTransform4X5 { | |||||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | ||||
} while (0) | } while (0) | ||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) | |||||
template <bool inner> | template <bool inner> | ||||
static void transform( | static void transform( | ||||
@@ -345,22 +346,22 @@ struct OutputTransform4X5 { | |||||
#undef cb | #undef cb | ||||
if (oh_start + 4 <= OH && ow_start + 4 <= OW) { | if (oh_start + 4 <= OH && ow_start + 4 <= OW) { | ||||
float32x4_t bias0; | |||||
GI_FLOAT32_t bias0; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
bias0 = vdupq_n_f32(bias[oc]); | |||||
bias0 = GiBroadcastFloat32(bias[oc]); | |||||
} | } | ||||
rep(i, 4) { | rep(i, 4) { | ||||
size_t oh = oh_start + i; | size_t oh = oh_start + i; | ||||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||||
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
item0 = vaddq_f32(item0, bias0); | |||||
item0 = GiAddFloat32(item0, bias0); | |||||
} else if (bmode == BiasMode::BIAS) { | } else if (bmode == BiasMode::BIAS) { | ||||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||||
item0 = vaddq_f32(item0, bias0); | |||||
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | |||||
item0 = GiAddFloat32(item0, bias0); | |||||
} | } | ||||
item0 = op(item0); | item0 = op(item0); | ||||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||||
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||||
mid_buf1 += 4; | mid_buf1 += 4; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -388,7 +389,7 @@ struct OutputTransform4X5 { | |||||
} // namespace | } // namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f) | ||||
@@ -448,8 +449,8 @@ void winograd_4x5_1x1_f::output( | |||||
auto nw = index % units_w; | auto nw = index % units_w; | ||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | ||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F45, cb, float, float, bmode, | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_fp32_F45, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | ||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | ||||
nr_units_in_tile, src_dtype, dst_dtype); | nr_units_in_tile, src_dtype, dst_dtype); | ||||
@@ -459,7 +460,7 @@ void winograd_4x5_1x1_f::output( | |||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,22 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | #include "src/naive/matrix_mul/matrix_mul_helper.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F54) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F54) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
struct FilterTransform5X4 { | struct FilterTransform5X4 { | ||||
@@ -94,7 +93,6 @@ struct FilterTransform5X4 { | |||||
transform_mid_buf[j * alpha + i]; | transform_mid_buf[j * alpha + i]; | ||||
} | } | ||||
#else | #else | ||||
#define cb(i) \ | #define cb(i) \ | ||||
do { \ | do { \ | ||||
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ | mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ | ||||
@@ -117,7 +115,7 @@ struct FilterTransform5X4 { | |||||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ | mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ | ||||
mid_buf1 += 8; \ | mid_buf1 += 8; \ | ||||
} while (0); | } while (0); | ||||
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) | |||||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value) | |||||
float* mid_buf1 = transform_mid_buf; | float* mid_buf1 = transform_mid_buf; | ||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
@@ -154,8 +152,10 @@ struct InputTransform5X4 { | |||||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | ||||
} while (0) | } while (0) | ||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) | |||||
template <bool inner> | template <bool inner> | ||||
static void transform( | static void transform( | ||||
@@ -348,29 +348,29 @@ struct OutputTransform5X4 { | |||||
#undef cb | #undef cb | ||||
if (oh_start + 5 <= OH && ow_start + 5 <= OW) { | if (oh_start + 5 <= OH && ow_start + 5 <= OW) { | ||||
float32x4_t bias0; | |||||
GI_FLOAT32_t bias0; | |||||
float32_t bias1; | float32_t bias1; | ||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
bias0 = vdupq_n_f32(bias[oc]); | |||||
bias0 = GiBroadcastFloat32(bias[oc]); | |||||
bias1 = bias[oc]; | bias1 = bias[oc]; | ||||
} | } | ||||
rep(i, 5) { | rep(i, 5) { | ||||
size_t oh = oh_start + i; | size_t oh = oh_start + i; | ||||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||||
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | |||||
float32_t item1 = mid_buf1[4]; | float32_t item1 = mid_buf1[4]; | ||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
item0 = vaddq_f32(item0, bias0); | |||||
item0 = GiAddFloat32(item0, bias0); | |||||
item1 = item1 + bias1; | item1 = item1 + bias1; | ||||
} else if (bmode == BiasMode::BIAS) { | } else if (bmode == BiasMode::BIAS) { | ||||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||||
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | |||||
bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4]; | bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4]; | ||||
item0 = vaddq_f32(item0, bias0); | |||||
item0 = GiAddFloat32(item0, bias0); | |||||
item1 = item1 + bias1; | item1 = item1 + bias1; | ||||
} | } | ||||
item0 = op(item0); | item0 = op(item0); | ||||
item1 = op(item1); | item1 = op(item1); | ||||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||||
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||||
output[oc * OH * OW + oh * OW + ow_start + 4] = item1; | output[oc * OH * OW + oh * OW + ow_start + 4] = item1; | ||||
mid_buf1 += 5; | mid_buf1 += 5; | ||||
@@ -400,7 +400,7 @@ struct OutputTransform5X4 { | |||||
} // namespace | } // namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f) | ||||
@@ -461,8 +461,8 @@ void winograd_5x4_1x1_f::output( | |||||
auto nw = index % units_w; | auto nw = index % units_w; | ||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | ||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F54, cb, float, float, bmode, | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_fp32_F54, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | ||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | ||||
nr_units_in_tile, src_dtype, dst_dtype); | nr_units_in_tile, src_dtype, dst_dtype); | ||||
@@ -472,7 +472,7 @@ void winograd_5x4_1x1_f::output( | |||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,22 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | #include "src/naive/matrix_mul/matrix_mul_helper.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
/** | /** | ||||
@@ -57,8 +56,10 @@ namespace { | |||||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | ||||
} while (0); | } while (0); | ||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) | |||||
struct InputTransform6X3 { | struct InputTransform6X3 { | ||||
template <bool inner> | template <bool inner> | ||||
static void transform( | static void transform( | ||||
@@ -271,31 +272,31 @@ struct OutputTransform6X3 { | |||||
#undef cb | #undef cb | ||||
if (oh_start + 6 <= OH && ow_start + 6 <= OW) { | if (oh_start + 6 <= OH && ow_start + 6 <= OW) { | ||||
float32x4_t bias0; | |||||
GI_FLOAT32_t bias0; | |||||
float32x2_t bias1; | float32x2_t bias1; | ||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
bias0 = vdupq_n_f32(bias[oc]); | |||||
bias1 = vdup_n_f32(bias[oc]); | |||||
bias0 = GiBroadcastFloat32(bias[oc]); | |||||
bias1 = GiDupFloat32(bias[oc]); | |||||
} | } | ||||
rep(i, 6) { | rep(i, 6) { | ||||
size_t oh = oh_start + i; | size_t oh = oh_start + i; | ||||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||||
float32x2_t item1 = vld1_f32(mid_buf1 + 4); | |||||
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | |||||
float32x2_t item1 = GiLdFloat32(mid_buf1 + 4); | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
item0 = vaddq_f32(item0, bias0); | |||||
item1 = vadd_f32(item1, bias1); | |||||
item0 = GiAddFloat32(item0, bias0); | |||||
item1 = GiAddDFloat32(item1, bias1); | |||||
} else if (bmode == BiasMode::BIAS) { | } else if (bmode == BiasMode::BIAS) { | ||||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||||
bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start + 4); | |||||
item0 = vaddq_f32(item0, bias0); | |||||
item1 = vadd_f32(item1, bias1); | |||||
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | |||||
bias1 = GiLdFloat32(bias + oc * OH * OW + oh * OW + ow_start + 4); | |||||
item0 = GiAddFloat32(item0, bias0); | |||||
item1 = GiAddDFloat32(item1, bias1); | |||||
} | } | ||||
item0 = op(item0); | item0 = op(item0); | ||||
item1 = vset_lane_f32(op(vget_lane_f32(item1, 0)), item1, 0); | |||||
item1 = vset_lane_f32(op(vget_lane_f32(item1, 1)), item1, 1); | |||||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||||
vst1_f32(output + oc * OH * OW + oh * OW + ow_start + 4, item1); | |||||
item1 = GiSetLaneFloat32(op(GiGetLaneFloat32(item1, 0)), item1, 0); | |||||
item1 = GiSetLaneFloat32(op(GiGetLaneFloat32(item1, 1)), item1, 1); | |||||
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||||
GiSt1Float32(output + oc * OH * OW + oh * OW + ow_start + 4, item1); | |||||
mid_buf1 += 6; | mid_buf1 += 6; | ||||
} | } | ||||
@@ -325,7 +326,7 @@ struct OutputTransform6X3 { | |||||
} // namespace | } // namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f) | ||||
@@ -385,8 +386,8 @@ void winograd_6x3_1x1_f::output( | |||||
auto nw = index % units_w; | auto nw = index % units_w; | ||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | ||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F63, cb, float, float, bmode, | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_fp32_F63, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | ||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | ||||
nr_units_in_tile, src_dtype, dst_dtype); | nr_units_in_tile, src_dtype, dst_dtype); | ||||
@@ -396,7 +397,7 @@ void winograd_6x3_1x1_f::output( | |||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,22 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/common/winograd/winograd_helper.h" | #include "src/common/winograd/winograd_helper.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_4x4) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63_4x4) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
@@ -41,16 +40,16 @@ struct InputTransform6X3 { | |||||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | ||||
for (size_t ico = 0; ico < 4; ++ico) { | for (size_t ico = 0; ico < 4; ++ico) { | ||||
if (ic + ico < IC) { | if (ic + ico < IC) { | ||||
#define cb(i) \ | |||||
auto v##i##0 = vld1q_f32(input_ptr + IW * i); \ | |||||
auto v##i##1 = vld1q_f32(input_ptr + IW * i + 4); | |||||
#define cb(i) \ | |||||
auto v##i##0 = GiLoadFloat32(input_ptr + IW * i); \ | |||||
auto v##i##1 = GiLoadFloat32(input_ptr + IW * i + 4); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) \ | |||||
vst1q_f32(patch + ico * 8 * alpha + i * 8, v##i##0); \ | |||||
vst1q_f32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1); | |||||
#define cb(i) \ | |||||
GiStoreFloat32(patch + ico * 8 * alpha + i * 8, v##i##0); \ | |||||
GiStoreFloat32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
@@ -255,7 +254,7 @@ struct OutputTransform6X3 { | |||||
} // namespace | } // namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f) | ||||
@@ -323,8 +322,8 @@ void winograd_6x3_4x4_f::output( | |||||
auto nw = index % units_w; | auto nw = index % units_w; | ||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | ||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F63_4x4, cb, float, float, bmode, | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_fp32_F63_4x4, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | ||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | ||||
nr_units_in_tile, src_dtype, dst_dtype); | nr_units_in_tile, src_dtype, dst_dtype); | ||||
@@ -334,7 +333,7 @@ void winograd_6x3_4x4_f::output( | |||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,22 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | #include "src/naive/matrix_mul/matrix_mul_helper.h" | ||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_nchw44_fp32_F23_mk4) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
constexpr size_t alpha = 2 + 3 - 1; | constexpr size_t alpha = 2 + 3 - 1; | ||||
@@ -72,8 +71,9 @@ struct InputTransformF23_NCHW44 { | |||||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | for (int ih = ih0_act; ih < ih1_act; ++ih) { | ||||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | for (int iw = iw0_act; iw < iw1_act; ++iw) { | ||||
size_t iho = ih - ih_start, iwo = iw - iw_start; | size_t iho = ih - ih_start, iwo = iw - iw_start; | ||||
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); | |||||
vst1q_f32(patchT + iho * alpha * pack_size + iwo * pack_size, src); | |||||
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); | |||||
GiStoreFloat32( | |||||
patchT + iho * alpha * pack_size + iwo * pack_size, src); | |||||
} | } | ||||
} | } | ||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
@@ -190,7 +190,7 @@ struct OutputTransformF23_NCHW44 { | |||||
} // namespace | } // namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F23_mk4_f_nchw44) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F23_mk4_f_nchw44) | ||||
@@ -313,14 +313,14 @@ void winograd_F23_mk4_f_nchw44::output( | |||||
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | ||||
"NCHW44 Winograd filter transform requires OC is times of 4"); | "NCHW44 Winograd filter transform requires OC is times of 4"); | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode, | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode, | |||||
nonline_mode); | nonline_mode); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,22 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/common/winograd/winograd_helper.h" | #include "src/common/winograd/winograd_helper.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_mk4) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63_mk4) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
@@ -49,11 +48,11 @@ struct InputTransformF63_NCHW44 { | |||||
const float* input_ptr = | const float* input_ptr = | ||||
input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | ||||
for (size_t ih = 0; ih < alpha; ih++) { | for (size_t ih = 0; ih < alpha; ih++) { | ||||
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); | |||||
#define cb(i) auto v##i = GiLoadFloat32(input_ptr + pack_size * i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i); | |||||
#define cb(i) GiStoreFloat32(patchT + ih * pack_size * alpha + i * pack_size, v##i); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
#undef cb | #undef cb | ||||
input_ptr += IW4; | input_ptr += IW4; | ||||
@@ -68,8 +67,9 @@ struct InputTransformF63_NCHW44 { | |||||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | for (int ih = ih0_act; ih < ih1_act; ++ih) { | ||||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | for (int iw = iw0_act; iw < iw1_act; ++iw) { | ||||
size_t iho = ih - ih_start, iwo = iw - iw_start; | size_t iho = ih - ih_start, iwo = iw - iw_start; | ||||
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); | |||||
vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src); | |||||
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); | |||||
GiStoreFloat32( | |||||
patchT + iho * pack_size * alpha + iwo * pack_size, src); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -83,10 +83,10 @@ struct InputTransformF63_NCHW44 { | |||||
size_t ICB = IC / pack_size; | size_t ICB = IC / pack_size; | ||||
size_t icb = ic / pack_size; | size_t icb = ic / pack_size; | ||||
float32x4_t d0, d1, d2, d3, d4, d5, d6, d7; | |||||
float32x4_t v0 = vld1q_f32(input_parameters + 0); | |||||
float32x4_t v1 = vld1q_f32(input_parameters + 4); | |||||
float32x4_t v2 = vld1q_f32(input_parameters + 8); | |||||
GI_FLOAT32_t d0, d1, d2, d3, d4, d5, d6, d7; | |||||
GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters + 0); | |||||
GI_FLOAT32_t v1 = GiLoadFloat32(input_parameters + 4); | |||||
GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8); | |||||
//! B | //! B | ||||
//! 1 0 0 0 0 0 0 0 | //! 1 0 0 0 0 0 0 0 | ||||
@@ -98,57 +98,57 @@ struct InputTransformF63_NCHW44 { | |||||
//! -1 1 1 1 1 1 1 0 | //! -1 1 1 1 1 1 1 0 | ||||
//! 0 0 0 0 0 0 0 1 | //! 0 0 0 0 0 0 0 1 | ||||
#define cb(i) \ | |||||
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||||
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||||
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||||
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||||
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||||
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||||
auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||||
auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||||
auto t##i##1 = d6; \ | |||||
auto t##i##2 = d6; \ | |||||
auto t##i##3 = d6; \ | |||||
auto t##i##4 = d6; \ | |||||
auto t##i##5 = d6; \ | |||||
auto t##i##6 = d6; \ | |||||
t##i##0 = t##i##0 - d6; \ | |||||
t##i##1 = t##i##1 + d1; \ | |||||
t##i##2 = t##i##2 - d1; \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \ | |||||
t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \ | |||||
t##i##7 = t##i##7 - d1; \ | |||||
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \ | |||||
t##i##1 = t##i##1 + d2; \ | |||||
t##i##2 = t##i##2 + d2; \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \ | |||||
t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \ | |||||
t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \ | |||||
t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \ | |||||
t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \ | |||||
t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \ | |||||
t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \ | |||||
t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \ | |||||
t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \ | |||||
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \ | |||||
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \ | |||||
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \ | |||||
t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \ | |||||
t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \ | |||||
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \ | |||||
t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \ | |||||
t##i##1 = t##i##1 + d5; \ | |||||
t##i##2 = t##i##2 - d5; \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \ | |||||
t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \ | |||||
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0); | |||||
#define cb(i) \ | |||||
d1 = GiLoadFloat32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||||
d2 = GiLoadFloat32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||||
d3 = GiLoadFloat32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||||
d4 = GiLoadFloat32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||||
d5 = GiLoadFloat32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||||
d6 = GiLoadFloat32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||||
auto t##i##0 = GiLoadFloat32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||||
auto t##i##7 = GiLoadFloat32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||||
auto t##i##1 = d6; \ | |||||
auto t##i##2 = d6; \ | |||||
auto t##i##3 = d6; \ | |||||
auto t##i##4 = d6; \ | |||||
auto t##i##5 = d6; \ | |||||
auto t##i##6 = d6; \ | |||||
t##i##0 = t##i##0 - d6; \ | |||||
t##i##1 = t##i##1 + d1; \ | |||||
t##i##2 = t##i##2 - d1; \ | |||||
t##i##3 = GiSimdFmaLane(t##i##3, d1, v0, 2); \ | |||||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d1, v0, 2); \ | |||||
t##i##5 = GiSimdFmaLane(t##i##5, d1, v1, 2); \ | |||||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d1, v1, 2); \ | |||||
t##i##7 = t##i##7 - d1; \ | |||||
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 0); \ | |||||
t##i##1 = t##i##1 + d2; \ | |||||
t##i##2 = t##i##2 + d2; \ | |||||
t##i##3 = GiSimdFmaLane(t##i##3, d2, v0, 3); \ | |||||
t##i##4 = GiSimdFmaLane(t##i##4, d2, v0, 3); \ | |||||
t##i##5 = GiSimdFmaLane(t##i##5, d2, v1, 3); \ | |||||
t##i##6 = GiSimdFmaLane(t##i##6, d2, v1, 3); \ | |||||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d3, v0, 1); \ | |||||
t##i##2 = GiSimdFmaLane(t##i##2, d3, v0, 1); \ | |||||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d3, v1, 0); \ | |||||
t##i##4 = GiSimdFmaLane(t##i##4, d3, v1, 0); \ | |||||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d3, v1, 0); \ | |||||
t##i##6 = GiSimdFmaLane(t##i##6, d3, v1, 0); \ | |||||
t##i##7 = GiSimdFmaLane(t##i##7, d3, v0, 0); \ | |||||
t##i##0 = GiSimdFmaLane(t##i##0, d4, v0, 0); \ | |||||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d4, v0, 1); \ | |||||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d4, v0, 1); \ | |||||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d4, v1, 1); \ | |||||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d4, v1, 1); \ | |||||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d4, v2, 0); \ | |||||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d4, v2, 0); \ | |||||
t##i##1 = t##i##1 + d5; \ | |||||
t##i##2 = t##i##2 - d5; \ | |||||
t##i##3 = GiSimdFmaLane(t##i##3, d5, v1, 2); \ | |||||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d5, v1, 2); \ | |||||
t##i##5 = GiSimdFmaLane(t##i##5, d5, v0, 2); \ | |||||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d5, v0, 2); \ | |||||
t##i##7 = GiFmsqLaneQFloat32(t##i##7, d5, v0, 0); | |||||
UNROLL_CALL_RAW(8, cb); | UNROLL_CALL_RAW(8, cb); | ||||
#undef cb | #undef cb | ||||
@@ -164,75 +164,75 @@ struct InputTransformF63_NCHW44 { | |||||
d0 = d0 - t6##i; \ | d0 = d0 - t6##i; \ | ||||
d1 = d1 + t1##i; \ | d1 = d1 + t1##i; \ | ||||
d2 = d2 - t1##i; \ | d2 = d2 - t1##i; \ | ||||
d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \ | |||||
d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \ | |||||
d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \ | |||||
d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \ | |||||
d3 = GiSimdFmaLane(d3, t1##i, v0, 2); \ | |||||
d4 = GiFmsqLaneQFloat32(d4, t1##i, v0, 2); \ | |||||
d5 = GiSimdFmaLane(d5, t1##i, v1, 2); \ | |||||
d6 = GiFmsqLaneQFloat32(d6, t1##i, v1, 2); \ | |||||
d7 = d7 - t1##i; \ | d7 = d7 - t1##i; \ | ||||
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \ | |||||
d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 0); \ | |||||
d1 = d1 + t2##i; \ | d1 = d1 + t2##i; \ | ||||
d2 = d2 + t2##i; \ | d2 = d2 + t2##i; \ | ||||
d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \ | |||||
d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \ | |||||
d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \ | |||||
d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \ | |||||
d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \ | |||||
d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \ | |||||
d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \ | |||||
d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \ | |||||
d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \ | |||||
d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \ | |||||
d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \ | |||||
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \ | |||||
d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \ | |||||
d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \ | |||||
d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \ | |||||
d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \ | |||||
d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \ | |||||
d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \ | |||||
d3 = GiSimdFmaLane(d3, t2##i, v0, 3); \ | |||||
d4 = GiSimdFmaLane(d4, t2##i, v0, 3); \ | |||||
d5 = GiSimdFmaLane(d5, t2##i, v1, 3); \ | |||||
d6 = GiSimdFmaLane(d6, t2##i, v1, 3); \ | |||||
d1 = GiFmsqLaneQFloat32(d1, t3##i, v0, 1); \ | |||||
d2 = GiSimdFmaLane(d2, t3##i, v0, 1); \ | |||||
d3 = GiFmsqLaneQFloat32(d3, t3##i, v1, 0); \ | |||||
d4 = GiSimdFmaLane(d4, t3##i, v1, 0); \ | |||||
d5 = GiFmsqLaneQFloat32(d5, t3##i, v1, 0); \ | |||||
d6 = GiSimdFmaLane(d6, t3##i, v1, 0); \ | |||||
d7 = GiSimdFmaLane(d7, t3##i, v0, 0); \ | |||||
d0 = GiSimdFmaLane(d0, t4##i, v0, 0); \ | |||||
d1 = GiFmsqLaneQFloat32(d1, t4##i, v0, 1); \ | |||||
d2 = GiFmsqLaneQFloat32(d2, t4##i, v0, 1); \ | |||||
d3 = GiFmsqLaneQFloat32(d3, t4##i, v1, 1); \ | |||||
d4 = GiFmsqLaneQFloat32(d4, t4##i, v1, 1); \ | |||||
d5 = GiFmsqLaneQFloat32(d5, t4##i, v2, 0); \ | |||||
d6 = GiFmsqLaneQFloat32(d6, t4##i, v2, 0); \ | |||||
d1 = d1 + t5##i; \ | d1 = d1 + t5##i; \ | ||||
d2 = d2 - t5##i; \ | d2 = d2 - t5##i; \ | ||||
d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \ | |||||
d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \ | |||||
d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \ | |||||
d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \ | |||||
d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \ | |||||
vst1q_f32( \ | |||||
d3 = GiSimdFmaLane(d3, t5##i, v1, 2); \ | |||||
d4 = GiFmsqLaneQFloat32(d4, t5##i, v1, 2); \ | |||||
d5 = GiSimdFmaLane(d5, t5##i, v0, 2); \ | |||||
d6 = GiFmsqLaneQFloat32(d6, t5##i, v0, 2); \ | |||||
d7 = GiFmsqLaneQFloat32(d7, t5##i, v0, 0); \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d0); \ | d0); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d1); \ | d1); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d2); \ | d2); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d3); \ | d3); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d4); \ | d4); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d5); \ | d5); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d6); \ | d6); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
@@ -347,7 +347,7 @@ struct OutputTransformF63_NCHW44 { | |||||
} // namespace | } // namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk4_f_nchw44) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk4_f_nchw44) | ||||
@@ -488,14 +488,14 @@ void winograd_F63_mk4_f_nchw44::output( | |||||
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | ||||
"NCHW44 Winograd filter transform requires OC is times of 4"); | "NCHW44 Winograd filter transform requires OC is times of 4"); | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, bmode, | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_fp32_F63_mk4, cb, float, float, bmode, | |||||
nonline_mode); | nonline_mode); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp | |||||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -9,22 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/common/winograd/winograd_helper.h" | #include "src/common/winograd/winograd_helper.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/gi/utils.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | #include "src/fallback/conv_bias/winograd/winograd.h" | ||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F73_mk4) | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F73_mk4) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | |||||
using namespace fallback; | |||||
namespace { | namespace { | ||||
@@ -51,11 +50,11 @@ struct InputTransformF73_NCHW44 { | |||||
const float* input_ptr = | const float* input_ptr = | ||||
input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | ||||
for (size_t ih = 0; ih < alpha; ih++) { | for (size_t ih = 0; ih < alpha; ih++) { | ||||
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); | |||||
#define cb(i) auto v##i = GiLoadFloat32(input_ptr + pack_size * i); | |||||
UNROLL_CALL_NOWRAPPER(9, cb); | UNROLL_CALL_NOWRAPPER(9, cb); | ||||
#undef cb | #undef cb | ||||
#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i); | |||||
#define cb(i) GiStoreFloat32(patchT + ih * pack_size * alpha + i * pack_size, v##i); | |||||
UNROLL_CALL_NOWRAPPER(9, cb); | UNROLL_CALL_NOWRAPPER(9, cb); | ||||
#undef cb | #undef cb | ||||
input_ptr += IW4; | input_ptr += IW4; | ||||
@@ -70,8 +69,9 @@ struct InputTransformF73_NCHW44 { | |||||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | for (int ih = ih0_act; ih < ih1_act; ++ih) { | ||||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | for (int iw = iw0_act; iw < iw1_act; ++iw) { | ||||
size_t iho = ih - ih_start, iwo = iw - iw_start; | size_t iho = ih - ih_start, iwo = iw - iw_start; | ||||
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); | |||||
vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src); | |||||
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); | |||||
GiStoreFloat32( | |||||
patchT + iho * pack_size * alpha + iwo * pack_size, src); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -85,14 +85,14 @@ struct InputTransformF73_NCHW44 { | |||||
size_t ICB = IC / pack_size; | size_t ICB = IC / pack_size; | ||||
size_t icb = ic / pack_size; | size_t icb = ic / pack_size; | ||||
float32x4_t d0, d1, d2, d3, d4, d5, d6, d7, d8; | |||||
float32x4_t v0 = vld1q_f32(input_parameters + 0); | |||||
float32x4_t v1 = vld1q_f32(input_parameters + 4); | |||||
float32x4_t v2 = vld1q_f32(input_parameters + 8); | |||||
float32x4_t v3 = vld1q_f32(input_parameters + 12); | |||||
float32x4_t v4 = vld1q_f32(input_parameters + 16); | |||||
float32x4_t v5 = vld1q_f32(input_parameters + 20); | |||||
float32x4_t v6 = vld1q_f32(input_parameters + 24); | |||||
GI_FLOAT32_t d0, d1, d2, d3, d4, d5, d6, d7, d8; | |||||
GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters + 0); | |||||
GI_FLOAT32_t v1 = GiLoadFloat32(input_parameters + 4); | |||||
GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8); | |||||
GI_FLOAT32_t v3 = GiLoadFloat32(input_parameters + 12); | |||||
GI_FLOAT32_t v4 = GiLoadFloat32(input_parameters + 16); | |||||
GI_FLOAT32_t v5 = GiLoadFloat32(input_parameters + 20); | |||||
GI_FLOAT32_t v6 = GiLoadFloat32(input_parameters + 24); | |||||
//! B | //! B | ||||
//! 1.5 0 0 0 0 0 0 0 0 | //! 1.5 0 0 0 0 0 0 0 0 | ||||
@@ -113,77 +113,77 @@ struct InputTransformF73_NCHW44 { | |||||
// 5.0f, 10.0f, 5.75f, 2.75f, v5 | // 5.0f, 10.0f, 5.75f, 2.75f, v5 | ||||
// 4.25f, 1.75f, 2.0f, 0.0f, v6 | // 4.25f, 1.75f, 2.0f, 0.0f, v6 | ||||
#define cb(i) \ | |||||
d0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||||
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||||
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||||
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||||
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||||
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||||
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||||
d7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||||
auto t##i##8 = vld1q_f32(patchT + i * alpha * pack_size + 8 * pack_size); \ | |||||
auto t##i##0 = d7; \ | |||||
auto t##i##1 = d7; \ | |||||
auto t##i##2 = d7; \ | |||||
auto t##i##3 = d7; \ | |||||
auto t##i##4 = d7; \ | |||||
auto t##i##5 = d7; \ | |||||
auto t##i##6 = d7; \ | |||||
auto t##i##7 = d7; \ | |||||
t##i##8 = vfmsq_laneq_f32(t##i##8, d7, v0, 0); \ | |||||
t##i##0 = t##i##0 - d1; \ | |||||
t##i##1 = vfmsq_laneq_f32(t##i##1, d1, v0, 0); \ | |||||
t##i##2 = vfmaq_laneq_f32(t##i##2, d1, v0, 0); \ | |||||
t##i##3 = vfmsq_laneq_f32(t##i##3, d1, v0, 1); \ | |||||
t##i##4 = vfmaq_laneq_f32(t##i##4, d1, v0, 1); \ | |||||
t##i##5 = vfmsq_laneq_f32(t##i##5, d1, v0, 2); \ | |||||
t##i##6 = vfmaq_laneq_f32(t##i##6, d1, v0, 2); \ | |||||
t##i##7 = t##i##7 - d1; \ | |||||
t##i##8 = vfmaq_laneq_f32(t##i##8, d1, v0, 0); \ | |||||
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 3); \ | |||||
t##i##1 = vfmsq_laneq_f32(t##i##1, d2, v1, 0); \ | |||||
t##i##2 = vfmsq_laneq_f32(t##i##2, d2, v1, 1); \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v1, 2); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d2, v1, 3); \ | |||||
t##i##5 = vfmsq_laneq_f32(t##i##5, d2, v2, 0); \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d2, v2, 1); \ | |||||
t##i##8 = t##i##8 - d2; \ | |||||
t##i##0 = vfmaq_laneq_f32(t##i##0, d3, v2, 2); \ | |||||
t##i##1 = vfmaq_laneq_f32(t##i##1, d3, v2, 3); \ | |||||
t##i##2 = vfmsq_laneq_f32(t##i##2, d3, v3, 0); \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d3, v2, 0); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d3, v3, 1); \ | |||||
t##i##5 = vfmaq_laneq_f32(t##i##5, d3, v3, 2); \ | |||||
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v3, 3); \ | |||||
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v2, 2); \ | |||||
t##i##8 = vfmsq_laneq_f32(t##i##8, d3, v0, 3); \ | |||||
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 3); \ | |||||
t##i##1 = vfmaq_laneq_f32(t##i##1, d4, v4, 0); \ | |||||
t##i##2 = vfmaq_laneq_f32(t##i##2, d4, v4, 1); \ | |||||
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v4, 2); \ | |||||
t##i##4 = vfmaq_laneq_f32(t##i##4, d4, v4, 3); \ | |||||
t##i##5 = vfmaq_laneq_f32(t##i##5, d4, v5, 0); \ | |||||
t##i##6 = vfmaq_laneq_f32(t##i##6, d4, v5, 1); \ | |||||
t##i##8 = vfmaq_laneq_f32(t##i##8, d4, v2, 2); \ | |||||
t##i##0 = vfmsq_laneq_f32(t##i##0, d5, v2, 2); \ | |||||
t##i##1 = vfmsq_laneq_f32(t##i##1, d5, v5, 2); \ | |||||
t##i##2 = vfmsq_laneq_f32(t##i##2, d5, v5, 3); \ | |||||
t##i##3 = vfmsq_laneq_f32(t##i##3, d5, v6, 0); \ | |||||
t##i##4 = vfmaq_laneq_f32(t##i##4, d5, v6, 1); \ | |||||
t##i##5 = vfmsq_laneq_f32(t##i##5, d5, v5, 2); \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v6, 0); \ | |||||
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v2, 2); \ | |||||
t##i##8 = vfmaq_laneq_f32(t##i##8, d5, v0, 3); \ | |||||
t##i##0 = vfmsq_laneq_f32(t##i##0, d6, v0, 0); \ | |||||
t##i##1 = vfmsq_laneq_f32(t##i##1, d6, v1, 0); \ | |||||
t##i##2 = vfmsq_laneq_f32(t##i##2, d6, v1, 1); \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d6, v1, 0); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d6, v3, 1); \ | |||||
t##i##5 = t##i##5 - d6; \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d6, v6, 2); \ | |||||
t##i##8 = vfmsq_laneq_f32(t##i##8, d6, v2, 2); \ | |||||
t##i##0 = vfmaq_laneq_f32(t##i##0, d0, v0, 0); | |||||
#define cb(i) \ | |||||
d0 = GiLoadFloat32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||||
d1 = GiLoadFloat32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||||
d2 = GiLoadFloat32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||||
d3 = GiLoadFloat32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||||
d4 = GiLoadFloat32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||||
d5 = GiLoadFloat32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||||
d6 = GiLoadFloat32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||||
d7 = GiLoadFloat32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||||
auto t##i##8 = GiLoadFloat32(patchT + i * alpha * pack_size + 8 * pack_size); \ | |||||
auto t##i##0 = d7; \ | |||||
auto t##i##1 = d7; \ | |||||
auto t##i##2 = d7; \ | |||||
auto t##i##3 = d7; \ | |||||
auto t##i##4 = d7; \ | |||||
auto t##i##5 = d7; \ | |||||
auto t##i##6 = d7; \ | |||||
auto t##i##7 = d7; \ | |||||
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d7, v0, 0); \ | |||||
t##i##0 = t##i##0 - d1; \ | |||||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d1, v0, 0); \ | |||||
t##i##2 = GiSimdFmaLane(t##i##2, d1, v0, 0); \ | |||||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d1, v0, 1); \ | |||||
t##i##4 = GiSimdFmaLane(t##i##4, d1, v0, 1); \ | |||||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d1, v0, 2); \ | |||||
t##i##6 = GiSimdFmaLane(t##i##6, d1, v0, 2); \ | |||||
t##i##7 = t##i##7 - d1; \ | |||||
t##i##8 = GiSimdFmaLane(t##i##8, d1, v0, 0); \ | |||||
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 3); \ | |||||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d2, v1, 0); \ | |||||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d2, v1, 1); \ | |||||
t##i##3 = GiSimdFmaLane(t##i##3, d2, v1, 2); \ | |||||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d2, v1, 3); \ | |||||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d2, v2, 0); \ | |||||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d2, v2, 1); \ | |||||
t##i##8 = t##i##8 - d2; \ | |||||
t##i##0 = GiSimdFmaLane(t##i##0, d3, v2, 2); \ | |||||
t##i##1 = GiSimdFmaLane(t##i##1, d3, v2, 3); \ | |||||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d3, v3, 0); \ | |||||
t##i##3 = GiSimdFmaLane(t##i##3, d3, v2, 0); \ | |||||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d3, v3, 1); \ | |||||
t##i##5 = GiSimdFmaLane(t##i##5, d3, v3, 2); \ | |||||
t##i##6 = GiSimdFmaLane(t##i##6, d3, v3, 3); \ | |||||
t##i##7 = GiSimdFmaLane(t##i##7, d3, v2, 2); \ | |||||
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d3, v0, 3); \ | |||||
t##i##0 = GiSimdFmaLane(t##i##0, d4, v0, 3); \ | |||||
t##i##1 = GiSimdFmaLane(t##i##1, d4, v4, 0); \ | |||||
t##i##2 = GiSimdFmaLane(t##i##2, d4, v4, 1); \ | |||||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d4, v4, 2); \ | |||||
t##i##4 = GiSimdFmaLane(t##i##4, d4, v4, 3); \ | |||||
t##i##5 = GiSimdFmaLane(t##i##5, d4, v5, 0); \ | |||||
t##i##6 = GiSimdFmaLane(t##i##6, d4, v5, 1); \ | |||||
t##i##8 = GiSimdFmaLane(t##i##8, d4, v2, 2); \ | |||||
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d5, v2, 2); \ | |||||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d5, v5, 2); \ | |||||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d5, v5, 3); \ | |||||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d5, v6, 0); \ | |||||
t##i##4 = GiSimdFmaLane(t##i##4, d5, v6, 1); \ | |||||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d5, v5, 2); \ | |||||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d5, v6, 0); \ | |||||
t##i##7 = GiFmsqLaneQFloat32(t##i##7, d5, v2, 2); \ | |||||
t##i##8 = GiSimdFmaLane(t##i##8, d5, v0, 3); \ | |||||
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d6, v0, 0); \ | |||||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d6, v1, 0); \ | |||||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d6, v1, 1); \ | |||||
t##i##3 = GiSimdFmaLane(t##i##3, d6, v1, 0); \ | |||||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d6, v3, 1); \ | |||||
t##i##5 = t##i##5 - d6; \ | |||||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d6, v6, 2); \ | |||||
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d6, v2, 2); \ | |||||
t##i##0 = GiSimdFmaLane(t##i##0, d0, v0, 0); | |||||
UNROLL_CALL_RAW(9, cb); | UNROLL_CALL_RAW(9, cb); | ||||
#undef cb | #undef cb | ||||
@@ -198,100 +198,100 @@ struct InputTransformF73_NCHW44 { | |||||
d5 = t7##i; \ | d5 = t7##i; \ | ||||
d6 = t7##i; \ | d6 = t7##i; \ | ||||
d7 = t7##i; \ | d7 = t7##i; \ | ||||
d8 = vfmsq_laneq_f32(d8, t7##i, v0, 0); \ | |||||
d8 = GiFmsqLaneQFloat32(d8, t7##i, v0, 0); \ | |||||
d0 = d0 - t1##i; \ | d0 = d0 - t1##i; \ | ||||
d1 = vfmsq_laneq_f32(d1, t1##i, v0, 0); \ | |||||
d2 = vfmaq_laneq_f32(d2, t1##i, v0, 0); \ | |||||
d3 = vfmsq_laneq_f32(d3, t1##i, v0, 1); \ | |||||
d4 = vfmaq_laneq_f32(d4, t1##i, v0, 1); \ | |||||
d5 = vfmsq_laneq_f32(d5, t1##i, v0, 2); \ | |||||
d6 = vfmaq_laneq_f32(d6, t1##i, v0, 2); \ | |||||
d1 = GiFmsqLaneQFloat32(d1, t1##i, v0, 0); \ | |||||
d2 = GiSimdFmaLane(d2, t1##i, v0, 0); \ | |||||
d3 = GiFmsqLaneQFloat32(d3, t1##i, v0, 1); \ | |||||
d4 = GiSimdFmaLane(d4, t1##i, v0, 1); \ | |||||
d5 = GiFmsqLaneQFloat32(d5, t1##i, v0, 2); \ | |||||
d6 = GiSimdFmaLane(d6, t1##i, v0, 2); \ | |||||
d7 = d7 - t1##i; \ | d7 = d7 - t1##i; \ | ||||
d8 = vfmaq_laneq_f32(d8, t1##i, v0, 0); \ | |||||
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 3); \ | |||||
d1 = vfmsq_laneq_f32(d1, t2##i, v1, 0); \ | |||||
d2 = vfmsq_laneq_f32(d2, t2##i, v1, 1); \ | |||||
d3 = vfmaq_laneq_f32(d3, t2##i, v1, 2); \ | |||||
d4 = vfmsq_laneq_f32(d4, t2##i, v1, 3); \ | |||||
d5 = vfmsq_laneq_f32(d5, t2##i, v2, 0); \ | |||||
d6 = vfmsq_laneq_f32(d6, t2##i, v2, 1); \ | |||||
d8 = GiSimdFmaLane(d8, t1##i, v0, 0); \ | |||||
d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 3); \ | |||||
d1 = GiFmsqLaneQFloat32(d1, t2##i, v1, 0); \ | |||||
d2 = GiFmsqLaneQFloat32(d2, t2##i, v1, 1); \ | |||||
d3 = GiSimdFmaLane(d3, t2##i, v1, 2); \ | |||||
d4 = GiFmsqLaneQFloat32(d4, t2##i, v1, 3); \ | |||||
d5 = GiFmsqLaneQFloat32(d5, t2##i, v2, 0); \ | |||||
d6 = GiFmsqLaneQFloat32(d6, t2##i, v2, 1); \ | |||||
d8 = d8 - t2##i; \ | d8 = d8 - t2##i; \ | ||||
d0 = vfmaq_laneq_f32(d0, t3##i, v2, 2); \ | |||||
d1 = vfmaq_laneq_f32(d1, t3##i, v2, 3); \ | |||||
d2 = vfmsq_laneq_f32(d2, t3##i, v3, 0); \ | |||||
d3 = vfmaq_laneq_f32(d3, t3##i, v2, 0); \ | |||||
d4 = vfmsq_laneq_f32(d4, t3##i, v3, 1); \ | |||||
d5 = vfmaq_laneq_f32(d5, t3##i, v3, 2); \ | |||||
d6 = vfmaq_laneq_f32(d6, t3##i, v3, 3); \ | |||||
d7 = vfmaq_laneq_f32(d7, t3##i, v2, 2); \ | |||||
d8 = vfmsq_laneq_f32(d8, t3##i, v0, 3); \ | |||||
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 3); \ | |||||
d1 = vfmaq_laneq_f32(d1, t4##i, v4, 0); \ | |||||
d2 = vfmaq_laneq_f32(d2, t4##i, v4, 1); \ | |||||
d3 = vfmsq_laneq_f32(d3, t4##i, v4, 2); \ | |||||
d4 = vfmaq_laneq_f32(d4, t4##i, v4, 3); \ | |||||
d5 = vfmaq_laneq_f32(d5, t4##i, v5, 0); \ | |||||
d6 = vfmaq_laneq_f32(d6, t4##i, v5, 1); \ | |||||
d8 = vfmaq_laneq_f32(d8, t4##i, v2, 2); \ | |||||
d0 = vfmsq_laneq_f32(d0, t5##i, v2, 2); \ | |||||
d1 = vfmsq_laneq_f32(d1, t5##i, v5, 2); \ | |||||
d2 = vfmsq_laneq_f32(d2, t5##i, v5, 3); \ | |||||
d3 = vfmsq_laneq_f32(d3, t5##i, v6, 0); \ | |||||
d4 = vfmaq_laneq_f32(d4, t5##i, v6, 1); \ | |||||
d5 = vfmsq_laneq_f32(d5, t5##i, v5, 2); \ | |||||
d6 = vfmsq_laneq_f32(d6, t5##i, v6, 0); \ | |||||
d7 = vfmsq_laneq_f32(d7, t5##i, v2, 2); \ | |||||
d8 = vfmaq_laneq_f32(d8, t5##i, v0, 3); \ | |||||
d0 = vfmsq_laneq_f32(d0, t6##i, v0, 0); \ | |||||
d1 = vfmsq_laneq_f32(d1, t6##i, v1, 0); \ | |||||
d2 = vfmsq_laneq_f32(d2, t6##i, v1, 1); \ | |||||
d3 = vfmaq_laneq_f32(d3, t6##i, v1, 0); \ | |||||
d4 = vfmsq_laneq_f32(d4, t6##i, v3, 1); \ | |||||
d0 = GiSimdFmaLane(d0, t3##i, v2, 2); \ | |||||
d1 = GiSimdFmaLane(d1, t3##i, v2, 3); \ | |||||
d2 = GiFmsqLaneQFloat32(d2, t3##i, v3, 0); \ | |||||
d3 = GiSimdFmaLane(d3, t3##i, v2, 0); \ | |||||
d4 = GiFmsqLaneQFloat32(d4, t3##i, v3, 1); \ | |||||
d5 = GiSimdFmaLane(d5, t3##i, v3, 2); \ | |||||
d6 = GiSimdFmaLane(d6, t3##i, v3, 3); \ | |||||
d7 = GiSimdFmaLane(d7, t3##i, v2, 2); \ | |||||
d8 = GiFmsqLaneQFloat32(d8, t3##i, v0, 3); \ | |||||
d0 = GiSimdFmaLane(d0, t4##i, v0, 3); \ | |||||
d1 = GiSimdFmaLane(d1, t4##i, v4, 0); \ | |||||
d2 = GiSimdFmaLane(d2, t4##i, v4, 1); \ | |||||
d3 = GiFmsqLaneQFloat32(d3, t4##i, v4, 2); \ | |||||
d4 = GiSimdFmaLane(d4, t4##i, v4, 3); \ | |||||
d5 = GiSimdFmaLane(d5, t4##i, v5, 0); \ | |||||
d6 = GiSimdFmaLane(d6, t4##i, v5, 1); \ | |||||
d8 = GiSimdFmaLane(d8, t4##i, v2, 2); \ | |||||
d0 = GiFmsqLaneQFloat32(d0, t5##i, v2, 2); \ | |||||
d1 = GiFmsqLaneQFloat32(d1, t5##i, v5, 2); \ | |||||
d2 = GiFmsqLaneQFloat32(d2, t5##i, v5, 3); \ | |||||
d3 = GiFmsqLaneQFloat32(d3, t5##i, v6, 0); \ | |||||
d4 = GiSimdFmaLane(d4, t5##i, v6, 1); \ | |||||
d5 = GiFmsqLaneQFloat32(d5, t5##i, v5, 2); \ | |||||
d6 = GiFmsqLaneQFloat32(d6, t5##i, v6, 0); \ | |||||
d7 = GiFmsqLaneQFloat32(d7, t5##i, v2, 2); \ | |||||
d8 = GiSimdFmaLane(d8, t5##i, v0, 3); \ | |||||
d0 = GiFmsqLaneQFloat32(d0, t6##i, v0, 0); \ | |||||
d1 = GiFmsqLaneQFloat32(d1, t6##i, v1, 0); \ | |||||
d2 = GiFmsqLaneQFloat32(d2, t6##i, v1, 1); \ | |||||
d3 = GiSimdFmaLane(d3, t6##i, v1, 0); \ | |||||
d4 = GiFmsqLaneQFloat32(d4, t6##i, v3, 1); \ | |||||
d5 = d5 - t6##i; \ | d5 = d5 - t6##i; \ | ||||
d6 = vfmsq_laneq_f32(d6, t6##i, v6, 2); \ | |||||
d8 = vfmsq_laneq_f32(d8, t6##i, v2, 2); \ | |||||
d0 = vfmaq_laneq_f32(d0, t0##i, v0, 0); \ | |||||
vst1q_f32( \ | |||||
d6 = GiFmsqLaneQFloat32(d6, t6##i, v6, 2); \ | |||||
d8 = GiFmsqLaneQFloat32(d8, t6##i, v2, 2); \ | |||||
d0 = GiSimdFmaLane(d0, t0##i, v0, 0); \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d0); \ | d0); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d1); \ | d1); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d2); \ | d2); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d3); \ | d3); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d4); \ | d4); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d5); \ | d5); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d6); \ | d6); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
d7); \ | d7); \ | ||||
vst1q_f32( \ | |||||
GiStoreFloat32( \ | |||||
input_transform_buf + \ | input_transform_buf + \ | ||||
(8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | (8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | ||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | ||||
@@ -413,7 +413,7 @@ struct OutputTransformF73_NCHW44 { | |||||
} // namespace | } // namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | |||||
namespace fallback { | |||||
namespace winograd { | namespace winograd { | ||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F73_mk4_f_nchw44) | MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F73_mk4_f_nchw44) | ||||
@@ -554,14 +554,14 @@ void winograd_F73_mk4_f_nchw44::output( | |||||
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | ||||
"NCHW44 Winograd filter transform requires OC is times of 4"); | "NCHW44 Winograd filter transform requires OC is times of 4"); | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F73_mk4, cb, float, float, bmode, | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_fp32_F73_mk4, cb, float, float, bmode, | |||||
nonline_mode); | nonline_mode); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace fallback | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,413 @@ | |||||
/** | |||||
* \file dnn/src/fallback/conv_bias/gi/intrinsic_helper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
#include "src/fallback/general_intrinsic/gi_int.h" | |||||
namespace megdnn { | |||||
namespace { | |||||
struct Vld1qF32S { | |||||
static GI_FORCEINLINE GI_FLOAT32_t impl(const float32_t* ptr) { | |||||
return GiLoadFloat32(ptr); | |||||
} | |||||
}; | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wuninitialized" | |||||
#ifdef __GNUC__ | |||||
#ifndef __has_warning | |||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||||
#else | |||||
#if __has_warning("-Wmaybe-uninitialized") | |||||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||||
#endif | |||||
#endif | |||||
#endif | |||||
template < | |||||
int weight_number, int base_offset, int ptr_step, int oc_block, typename Func, | |||||
typename T, typename T2, typename... XT> | |||||
struct LoadHelper { | |||||
static GI_FORCEINLINE void impl(T& weight, T2 ptr, int oc_offset, XT... args); | |||||
}; | |||||
#define WEIGHT_CB(step) \ | |||||
src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); | |||||
#define LOAD_HELPER(step) \ | |||||
template < \ | |||||
int base_offset, int ptr_step, typename Func, typename T, typename T2, \ | |||||
typename... XT> \ | |||||
struct LoadHelper<step, base_offset, ptr_step, 0, Func, T, T2, XT...> { \ | |||||
static GI_FORCEINLINE void impl(T& src, T2 ptr, int, XT... args) { \ | |||||
UNROLL_CALL_RAW(step, WEIGHT_CB); \ | |||||
} \ | |||||
} | |||||
LOAD_HELPER(1); | |||||
LOAD_HELPER(2); | |||||
LOAD_HELPER(3); | |||||
LOAD_HELPER(4); | |||||
LOAD_HELPER(5); | |||||
LOAD_HELPER(6); | |||||
LOAD_HELPER(7); | |||||
LOAD_HELPER(8); | |||||
LOAD_HELPER(9); | |||||
LOAD_HELPER(10); | |||||
LOAD_HELPER(11); | |||||
LOAD_HELPER(12); | |||||
LOAD_HELPER(13); | |||||
LOAD_HELPER(14); | |||||
LOAD_HELPER(15); | |||||
LOAD_HELPER(16); | |||||
#undef LOAD_HELPER | |||||
#undef WEIGHT_CB | |||||
///////////////////////////c_dim = 1///////////////////////// | |||||
#define WEIGHT_CB(step) src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); | |||||
#define LOAD_HELPER(step) \ | |||||
template <int base_offset, int ptr_step, typename Func, typename T, typename T2> \ | |||||
struct LoadHelper<step, base_offset, ptr_step, 1, Func, T, T2> { \ | |||||
static GI_FORCEINLINE void impl(T& src, T2 ptr, int) { \ | |||||
UNROLL_CALL_RAW(step, WEIGHT_CB); \ | |||||
} \ | |||||
} | |||||
LOAD_HELPER(1); | |||||
LOAD_HELPER(2); | |||||
LOAD_HELPER(3); | |||||
LOAD_HELPER(4); | |||||
LOAD_HELPER(5); | |||||
LOAD_HELPER(6); | |||||
LOAD_HELPER(7); | |||||
LOAD_HELPER(8); | |||||
LOAD_HELPER(9); | |||||
#undef LOAD_HELPER | |||||
#undef WEIGHT_CB | |||||
/////////////////////////c_dim = 2/////////////////////////////// | |||||
#define WEIGHT_CB(step) \ | |||||
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ | |||||
src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); | |||||
#define LOAD_HELPER(step) \ | |||||
template <int base_offset, int ptr_step, typename Func, typename T, typename T2> \ | |||||
struct LoadHelper<step, base_offset, ptr_step, 2, Func, T, T2> { \ | |||||
static GI_FORCEINLINE void impl(T& src, T2 ptr, int oc_offset) { \ | |||||
UNROLL_CALL_RAW(step, WEIGHT_CB); \ | |||||
} \ | |||||
} | |||||
LOAD_HELPER(1); | |||||
LOAD_HELPER(2); | |||||
LOAD_HELPER(3); | |||||
LOAD_HELPER(4); | |||||
LOAD_HELPER(5); | |||||
LOAD_HELPER(6); | |||||
LOAD_HELPER(7); | |||||
LOAD_HELPER(8); | |||||
#undef LOAD_HELPER | |||||
#undef WEIGHT_CB | |||||
template < | |||||
int weight_number, int base_offset, int ptr_step, int c_dim, typename Func, | |||||
typename T, typename T2> | |||||
GI_FORCEINLINE void load_helper(T& weight, T2 ptr, int oc_offset) { | |||||
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2>::impl( | |||||
weight, ptr, oc_offset); | |||||
} | |||||
////////////////////Store_OCX_OW8_Remain///////////////////////// | |||||
template <int c_dim, int ow_remain, typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc); | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||||
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||||
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||||
op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||||
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||||
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||||
op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||||
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24)); | |||||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||||
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||||
op(c[1][6], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||||
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16)); | |||||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||||
op(c[1][4], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
op(c[1][2], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
op(c[0][0], reinterpret_cast<T3>(dst_ptr)); | |||||
op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||||
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||||
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||||
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||||
} | |||||
}; | |||||
template <typename Op, typename T, typename T2, typename T3> | |||||
struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> { | |||||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||||
op(c[0][0], reinterpret_cast<T3>(dst_ptr)); | |||||
} | |||||
}; | |||||
template <int c_dim, int ow_remain, typename Op, typename T, typename T2> | |||||
GI_FORCEINLINE void store_ocx_ow8_remain_static( | |||||
T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||||
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2>::impl(c, op, dst_ptr, ld_dst_oc); | |||||
} | |||||
#undef cb | |||||
#undef cb2 | |||||
#undef cb_case | |||||
#undef cb_case2 | |||||
#pragma GCC diagnostic pop | |||||
/////////////////////////init_ocx_ow8//////////////////// | |||||
template <typename T> | |||||
struct GiLdqSimd; | |||||
template <> | |||||
struct GiLdqSimd<float> { | |||||
static constexpr int simd_len = 4; | |||||
}; | |||||
template <int c_dim, BiasMode bias_mode, int ow_remain, typename T, typename T2> | |||||
struct InitOcxOw8 { | |||||
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step); | |||||
}; | |||||
template <int c_dim, BiasMode bias_mode, typename T, typename T2> | |||||
struct InitOcxOw8<c_dim, bias_mode, 0, T, T2> { | |||||
static GI_FORCEINLINE void impl(T&, const T2*, int) {} | |||||
}; | |||||
#define BAIS_INIT_NO_BIAS_C2(step) \ | |||||
c[0][step] = GiBroadcastFloat32(static_cast<T2>(0)); \ | |||||
c[1][step] = GiBroadcastFloat32(static_cast<T2>(0)); | |||||
#define BAIS_INIT_NO_BIAS_C1(step) c[0][step] = GiBroadcastFloat32(static_cast<T2>(0)); | |||||
#define BAIS_INIT_BROADCAST_C2(step) \ | |||||
c[0][step] = GiLoadFloat32(bias_ptr); \ | |||||
c[1][step] = GiLoadFloat32(bias_ptr + oc_step); | |||||
#define BAIS_INIT_BROADCAST_C1(step) c[0][step] = GiLoadFloat32(bias_ptr); | |||||
#define BAIS_INIT_BIAS_C2(step) \ | |||||
c[0][step] = GiLoadFloat32(bias_ptr + step * simd_len); \ | |||||
c[1][step] = GiLoadFloat32(bias_ptr + oc_step + step * simd_len); | |||||
#define BAIS_INIT_BIAS_C1(step) c[0][step] = GiLoadFloat32(bias_ptr + step * simd_len); | |||||
#define INSTANCE_InitOcxOw8(ow_remain, cdim) \ | |||||
template <typename T, typename T2> \ | |||||
struct InitOcxOw8<cdim, BiasMode::NO_BIAS, ow_remain, T, T2> { \ | |||||
static GI_FORCEINLINE void impl(T& c, const T2*, int) { \ | |||||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_NO_BIAS_C##cdim); \ | |||||
} \ | |||||
}; \ | |||||
template <typename T, typename T2> \ | |||||
struct InitOcxOw8<cdim, BiasMode::BROADCAST_CHANNEL_BIAS, ow_remain, T, T2> { \ | |||||
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step) { \ | |||||
(void)oc_step; \ | |||||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BROADCAST_C##cdim); \ | |||||
} \ | |||||
}; \ | |||||
template <typename T, typename T2> \ | |||||
struct InitOcxOw8<cdim, BiasMode::BIAS, ow_remain, T, T2> { \ | |||||
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step) { \ | |||||
constexpr int simd_len = GiLdqSimd<T2>::simd_len; \ | |||||
(void)oc_step; \ | |||||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BIAS_C##cdim); \ | |||||
} \ | |||||
}; | |||||
#define INSTANCE_InitOcxOw8_C(ow_remain) \ | |||||
INSTANCE_InitOcxOw8(ow_remain, 2); \ | |||||
INSTANCE_InitOcxOw8(ow_remain, 1); | |||||
INSTANCE_InitOcxOw8_C(1); | |||||
INSTANCE_InitOcxOw8_C(2); | |||||
INSTANCE_InitOcxOw8_C(3); | |||||
INSTANCE_InitOcxOw8_C(4); | |||||
INSTANCE_InitOcxOw8_C(5); | |||||
INSTANCE_InitOcxOw8_C(6); | |||||
INSTANCE_InitOcxOw8_C(7); | |||||
INSTANCE_InitOcxOw8_C(8); | |||||
#undef INSTANCE_InitOcxOw8 | |||||
#undef INSTANCE_InitOcxOw8_C | |||||
#undef BAIS_INIT_BIAS_C1 | |||||
#undef BAIS_INIT_BIAS_C2 | |||||
#undef BAIS_INIT_BROADCAST_C1 | |||||
#undef BAIS_INIT_BROADCAST_C2 | |||||
#undef BAIS_INIT_NO_BIAS_C1 | |||||
#undef BAIS_INIT_NO_BIAS_C2 | |||||
template <int c_dim, BiasMode bias_mode, int ow_remain, typename T, typename T2> | |||||
GI_FORCEINLINE void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) { | |||||
InitOcxOw8<c_dim, bias_mode, ow_remain, T, T2>::impl(c, bias_ptr, oc_step); | |||||
} | |||||
} // namespace | |||||
} // namespace megdnn | |||||
#undef GI_FORCEINLINE | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,86 @@ | |||||
/** | |||||
* \file dnn/src/fallback/conv_bias/gi/postprocess_helper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/basic_types.h" | |||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(fallback_gi_conv_bias_postprocess_helper) | |||||
namespace { | |||||
#define GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ | |||||
_midout_tag, cb, _bias_id, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ | |||||
switch (_nonline_mode) { \ | |||||
case param::ConvBias::NonlineMode::IDENTITY: { \ | |||||
MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ | |||||
cb(_bmode, NoneOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
break; \ | |||||
} \ | |||||
case param::ConvBias::NonlineMode::RELU: { \ | |||||
MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ | |||||
cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
break; \ | |||||
} \ | |||||
case param::ConvBias::NonlineMode::SIGMOID: { \ | |||||
MIDOUT_BEGIN(_midout_tag, _bias_id, 2) { \ | |||||
cb(_bmode, SigmoidOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
break; \ | |||||
} \ | |||||
case param::ConvBias::NonlineMode::H_SWISH: { \ | |||||
MIDOUT_BEGIN(_midout_tag, _bias_id, 3) { \ | |||||
cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0); \ | |||||
break; \ | |||||
} | |||||
#define GI_DISPATCH_CONV_WINOGRAD_BIAS( \ | |||||
_midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ | |||||
switch (_bmode) { \ | |||||
case BiasMode::BIAS: { \ | |||||
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ | |||||
_midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \ | |||||
_nonline_mode, __VA_ARGS__) \ | |||||
break; \ | |||||
} \ | |||||
case BiasMode::NO_BIAS: { \ | |||||
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ | |||||
_midout_tag, cb, 1, _src_type, _dst_type, BiasMode::NO_BIAS, \ | |||||
_nonline_mode, __VA_ARGS__) \ | |||||
break; \ | |||||
} \ | |||||
case BiasMode::BROADCAST_CHANNEL_BIAS: { \ | |||||
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ | |||||
_midout_tag, cb, 2, _src_type, _dst_type, \ | |||||
BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, __VA_ARGS__) \ | |||||
break; \ | |||||
} \ | |||||
default: \ | |||||
megdnn_assert(0); \ | |||||
break; \ | |||||
} | |||||
} // namespace |
@@ -0,0 +1,193 @@ | |||||
/** | |||||
* \file dnn/src/fallback/conv_bias/gi/utils.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <cstring> | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
namespace megdnn { | |||||
namespace fallback { | |||||
template <typename ctype, size_t len> | |||||
struct Vector; | |||||
template <> | |||||
struct Vector<float, 4> { | |||||
GI_FLOAT32_t value; | |||||
Vector() {} | |||||
Vector(const float v) { value = GiBroadcastFloat32(v); } | |||||
Vector(const Vector& lr) { value = lr.value; } | |||||
Vector(const Vector&& lr) { value = std::move(lr.value); } | |||||
Vector(const GI_FLOAT32_t& v) { value = v; } | |||||
static Vector load(const float* addr) { | |||||
Vector v; | |||||
v.value = GiLoadFloat32(addr); | |||||
return v; | |||||
} | |||||
static void save(float* addr, const Vector& v) { GiStoreFloat32(addr, v.value); } | |||||
void save(float* addr) { save(addr, *this); } | |||||
Vector operator+(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value = GiAddFloat32(value, lr.value); | |||||
return dst; | |||||
} | |||||
Vector& operator+=(const Vector& lr) { | |||||
value = GiAddFloat32(value, lr.value); | |||||
return *this; | |||||
} | |||||
Vector operator-(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value = GiSubtractFloat32(value, lr.value); | |||||
return dst; | |||||
} | |||||
Vector& operator-=(const Vector& lr) { | |||||
value = GiSubtractFloat32(value, lr.value); | |||||
return *this; | |||||
} | |||||
Vector operator*(float lr) { | |||||
Vector dst; | |||||
dst.value = GiMultiplyScalerFloat32(value, lr); | |||||
return dst; | |||||
} | |||||
Vector operator*(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value = GiMultiplyFloat32(value, lr.value); | |||||
return dst; | |||||
} | |||||
Vector& operator*=(const Vector& lr) { | |||||
value = GiMultiplyFloat32(value, lr.value); | |||||
return *this; | |||||
} | |||||
Vector& operator=(const Vector& lr) { | |||||
value = lr.value; | |||||
return *this; | |||||
} | |||||
Vector& operator=(const Vector&& lr) { | |||||
value = std::move(lr.value); | |||||
return *this; | |||||
} | |||||
Vector operator-() { | |||||
Vector dst; | |||||
dst.value = -value; | |||||
return dst; | |||||
} | |||||
}; | |||||
template <> | |||||
struct Vector<float, 8> { | |||||
GI_FLOAT32_V2_t value; | |||||
Vector() {} | |||||
Vector(const float v) { | |||||
value.val[0] = GiBroadcastFloat32(v); | |||||
value.val[1] = GiBroadcastFloat32(v); | |||||
} | |||||
Vector(const Vector& lr) { value = lr.value; } | |||||
Vector(const Vector&& lr) { value = std::move(lr.value); } | |||||
Vector(const GI_FLOAT32_V2_t& v) { value = v; } | |||||
static Vector load(const float* addr) { | |||||
Vector v; | |||||
#if defined(GI_TEST_NAIVE) | |||||
v.value.val[0] = GiLoadFloat32(addr); | |||||
v.value.val[1] = GiLoadFloat32(addr + 4); | |||||
#elif defined(__arm__) || defined(__aarch64__) | |||||
v.value = vld1q_f32_x2(addr); | |||||
#else | |||||
v.value.val[0] = GiLoadFloat32(addr); | |||||
v.value.val[1] = GiLoadFloat32(addr + 4); | |||||
#endif | |||||
return v; | |||||
} | |||||
static void save(float* addr, const Vector& v) { | |||||
#if defined(GI_TEST_NAIVE) | |||||
GiStoreFloat32(addr, v.value.val[0]); | |||||
GiStoreFloat32(addr + 4, v.value.val[1]); | |||||
#elif defined(__arm__) || defined(__aarch64__) | |||||
vst1q_f32_x2(addr, v.value); | |||||
#else | |||||
GiStoreFloat32(addr, v.value.val[0]); | |||||
GiStoreFloat32(addr + 4, v.value.val[1]); | |||||
#endif | |||||
} | |||||
void save(float* addr) { save(addr, *this); } | |||||
Vector operator+(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); | |||||
dst.value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); | |||||
return dst; | |||||
} | |||||
Vector& operator+=(const Vector& lr) { | |||||
value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); | |||||
value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); | |||||
return *this; | |||||
} | |||||
Vector& add(const Vector& lr) { | |||||
value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); | |||||
value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); | |||||
return *this; | |||||
} | |||||
Vector operator-(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]); | |||||
dst.value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]); | |||||
return dst; | |||||
} | |||||
Vector& operator-=(const Vector& lr) { | |||||
value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]); | |||||
value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]); | |||||
return *this; | |||||
} | |||||
Vector operator*(float lr) { | |||||
Vector dst; | |||||
dst.value.val[0] = GiMultiplyScalerFloat32(value.val[0], lr); | |||||
dst.value.val[1] = GiMultiplyScalerFloat32(value.val[1], lr); | |||||
return dst; | |||||
} | |||||
//! val + lr * n | |||||
Vector& mla(const Vector& lr, float n) { | |||||
value.val[0] = GiMultiplyAddScalarFloat32(value.val[0], lr.value.val[0], n); | |||||
value.val[1] = GiMultiplyAddScalarFloat32(value.val[1], lr.value.val[1], n); | |||||
return *this; | |||||
} | |||||
Vector operator*(const Vector& lr) { | |||||
Vector dst; | |||||
dst.value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]); | |||||
dst.value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]); | |||||
return dst; | |||||
} | |||||
Vector& operator*=(const Vector& lr) { | |||||
value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]); | |||||
value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]); | |||||
return *this; | |||||
} | |||||
Vector& operator=(const Vector& lr) { | |||||
value = lr.value; | |||||
return *this; | |||||
} | |||||
Vector& operator=(const Vector&& lr) { | |||||
value = std::move(lr.value); | |||||
return *this; | |||||
} | |||||
Vector operator-() { | |||||
Vector dst; | |||||
dst.value.val[0] = -value.val[0]; | |||||
dst.value.val[1] = -value.val[1]; | |||||
return dst; | |||||
} | |||||
}; | |||||
} // namespace fallback | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -16,6 +16,7 @@ | |||||
#include "src/fallback/conv_bias/algos.h" | #include "src/fallback/conv_bias/algos.h" | ||||
#include "src/fallback/conv_bias/conv1x1/algos.h" | #include "src/fallback/conv_bias/conv1x1/algos.h" | ||||
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" | #include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" | ||||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||||
#include "src/fallback/conv_bias/im2col/algos.h" | #include "src/fallback/conv_bias/im2col/algos.h" | ||||
#include "src/fallback/convolution/opr_impl.h" | #include "src/fallback/convolution/opr_impl.h" | ||||
#include "src/naive/convolution/algorithms.h" | #include "src/naive/convolution/algorithms.h" | ||||
@@ -34,6 +35,14 @@ | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace fallback; | using namespace fallback; | ||||
namespace { | |||||
//! TODO: imp is_fallback_exclude_gi_or_naive | |||||
bool is_naive(const detail::Algorithm* algo) { | |||||
return algo->handle_type() == Handle::HandleType::NAIVE; | |||||
} | |||||
} // anonymous namespace | |||||
size_t megdnn::fallback::pack_size(param::ConvBias::Format format) { | size_t megdnn::fallback::pack_size(param::ConvBias::Format format) { | ||||
switch (format) { | switch (format) { | ||||
case param::ConvBias::Format::NCHW44: | case param::ConvBias::Format::NCHW44: | ||||
@@ -73,16 +82,95 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | SmallVector<std::unique_ptr<AlgoBase>> refhold; | ||||
SmallVector<AlgoBase*> m_all_algos; | SmallVector<AlgoBase*> m_all_algos; | ||||
AlgoBase::Mapper m_all_algos_map; | AlgoBase::Mapper m_all_algos_map; | ||||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_gi_winograd_algos; | |||||
AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; | |||||
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; | |||||
AlgoF32DirectNCHW44 f32_direct_nchw44; | |||||
AlgoF32Direct f32_direct; | |||||
AlgoF32DirectStride2 f32_direct_stride2; | |||||
AlgoF32DirectStride1 f32_direct_stride1; | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
// fallback gi fp32 algo | |||||
m_all_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||||
m_all_algos.emplace_back(&f32_chanel_wise_nchw44); | |||||
m_all_algos.emplace_back(&f32_direct_nchw44); | |||||
m_all_algos.emplace_back(&f32_direct_stride1); | |||||
m_all_algos.emplace_back(&f32_direct_stride2); | |||||
m_all_algos.emplace_back(&f32_direct); | |||||
static CpuOprDelegationStorage<2> storage; | |||||
auto matmul_opr = storage.get<MatrixMul, 0>(); | |||||
using MatmulFormat = param::MatrixMul::Format; | |||||
auto&& matmul_algos = | |||||
static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||||
->select_algo_type({AlgoDataType::FLOAT32, MatmulFormat::MK4}); | |||||
for (auto&& algo : matmul_algos) { | |||||
if (is_naive(algo)) | |||||
continue; | |||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
//! uncomment this when low precision mode is done | |||||
#if 0 | |||||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
#endif | |||||
} | |||||
} | |||||
//! TODO: move arm_v7 MatrixMulImpl::AlgoF32 matmul to gi fallback, for nchw | |||||
//! prefetch algo, also need update dnn/test/common/conv_bias.cpp:check_winograd | |||||
matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||||
->select_algo_type( | |||||
{AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); | |||||
for (auto&& algo : matmul_algos) { | |||||
if (is_naive(algo)) | |||||
continue; | |||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
} | |||||
} | |||||
for (auto&& algo : m_gi_winograd_algos) { | |||||
m_all_algos.emplace_back(algo); | |||||
} | |||||
// end fallback gi fp32 algo | |||||
refhold.emplace_back(new AlgoConv1x1Gemv()); | refhold.emplace_back(new AlgoConv1x1Gemv()); | ||||
m_all_algos.emplace_back(refhold.back().get()); | m_all_algos.emplace_back(refhold.back().get()); | ||||
static CpuOprDelegationStorage<> storage; | |||||
auto matmul_opr = storage.get<MatrixMul>(); | |||||
auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||||
->get_all_packed_algo(); | |||||
matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||||
->get_all_packed_algo(); | |||||
for (auto&& algo : matmul_algos) { | for (auto&& algo : matmul_algos) { | ||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
//! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may | //! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may | ||||
@@ -226,6 +226,20 @@ public: | |||||
FB_CONV1x1, | FB_CONV1x1, | ||||
FB_CONV1x1_GEMV, | FB_CONV1x1_GEMV, | ||||
FB_IM2COL, | FB_IM2COL, | ||||
GI_COMMON_WINOGRAD_F23_4X4_FP32, | |||||
GI_COMMON_WINOGRAD_F63_FP32, | |||||
GI_COMMON_WINOGRAD_F63_4X4_FP32, | |||||
GI_COMMON_WINOGRAD_F54_FP32, | |||||
GI_COMMON_WINOGRAD_F45_FP32, | |||||
GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, | |||||
GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, | |||||
GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, | |||||
GI_COMMON_DIRECT_FP32, | |||||
GI_COMMON_DIRECT_STRD1_FP32, | |||||
GI_COMMON_DIRECT_STRD2_FP32, | |||||
GI_COMMON_DIRECT_NCHW44_FP32, | |||||
GI_COMMON_DIRECT_NCHW_NCHW44_FP32, | |||||
GI_COMMON_CHWNWISE_NCHW44_F32, | |||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
X86_DIRECT = 1 << 8, | X86_DIRECT = 1 << 8, | ||||
@@ -248,20 +262,6 @@ public: | |||||
ARM_COMMON_DIRECT_STRD1_FP16, | ARM_COMMON_DIRECT_STRD1_FP16, | ||||
ARM_COMMON_CHWNWISE_NCHW88_F16, | ARM_COMMON_CHWNWISE_NCHW88_F16, | ||||
ARM_COMMON_DIRECT_NCHW88_FP16, | ARM_COMMON_DIRECT_NCHW88_FP16, | ||||
ARM_COMMON_WINOGRAD_F23_4X4_FP32, | |||||
ARM_COMMON_WINOGRAD_F63_FP32, | |||||
ARM_COMMON_WINOGRAD_F63_4X4_FP32, | |||||
ARM_COMMON_WINOGRAD_F54_FP32, | |||||
ARM_COMMON_WINOGRAD_F45_FP32, | |||||
ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, | |||||
ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, | |||||
ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, | |||||
ARM_COMMON_DIRECT_FP32, | |||||
ARM_COMMON_DIRECT_STRD1_FP32, | |||||
ARM_COMMON_DIRECT_STRD2_FP32, | |||||
ARM_COMMON_DIRECT_NCHW44_FP32, | |||||
ARM_COMMON_DIRECT_NCHW_NCHW44_FP32, | |||||
ARM_COMMON_CHWNWISE_NCHW44_F32, | |||||
ARM_COMMON_DIRECT_STRD1_S8, | ARM_COMMON_DIRECT_STRD1_S8, | ||||
ARM_COMMON_DIRECT_STRD2_S8, | ARM_COMMON_DIRECT_STRD2_S8, | ||||
ARM_COMMON_DIRECT_NCHW44, | ARM_COMMON_DIRECT_NCHW44, | ||||
@@ -383,6 +383,23 @@ private: | |||||
class AlgoWinogradF32_4x4; | class AlgoWinogradF32_4x4; | ||||
class AlgoWinogradQS8; | class AlgoWinogradQS8; | ||||
class AlgoWinogradQS8_8x8; | class AlgoWinogradQS8_8x8; | ||||
class AlgoFP32WinogradF23_4x4; | |||||
class AlgoFP32WinogradF63; | |||||
class AlgoFP32WinogradF63_4x4; | |||||
class AlgoFP32WinogradF54; | |||||
class AlgoFP32WinogradF45; | |||||
class AlgoFP32WinogradF23_4x4_NCHW44; | |||||
class AlgoFP32WinogradF63_4x4_NCHW44; | |||||
class AlgoFP32WinogradF73_4x4_NCHW44; | |||||
class AlgoF32Direct; | |||||
class AlgoF32DirectStride1; | |||||
class AlgoF32DirectStride2; | |||||
class AlgoF32DirectNCHWNCHW44; | |||||
class AlgoF32ChannelWiseNCHW44; | |||||
class AlgoF32DirectNCHW44; | |||||
class AlgoPack; | class AlgoPack; | ||||
NCBKernSizeParam m_prev_selected_algo_sizep; | NCBKernSizeParam m_prev_selected_algo_sizep; | ||||
@@ -81,23 +81,6 @@ TEST_F(ARM_COMMON, CONV_BIAS_RECORD) { | |||||
} | } | ||||
} | } | ||||
TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD_F63_4) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | |||||
} | |||||
TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD_F63_4_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | |||||
} | |||||
#define CONV_BIAS_MATMUL_QU8_MODE(MODE) \ | #define CONV_BIAS_MATMUL_QU8_MODE(MODE) \ | ||||
using namespace conv_bias; \ | using namespace conv_bias; \ | ||||
std::vector<TestArg> args = get_quantized_args_with_nlmode(MODE); \ | std::vector<TestArg> args = get_quantized_args_with_nlmode(MODE); \ | ||||
@@ -1015,14 +998,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { | |||||
#endif | #endif | ||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_4x4) { | |||||
#if MEGDNN_AARCH64 | |||||
benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4); | |||||
#else | |||||
benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4); | |||||
#endif | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); | benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); | ||||
@@ -1031,14 +1006,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { | |||||
#endif | #endif | ||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_4x4) { | |||||
#if MEGDNN_AARCH64 | |||||
benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4); | |||||
#else | |||||
benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4); | |||||
#endif | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) { | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:5", handle(), 4); | benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:5", handle(), 4); | ||||
@@ -1212,30 +1179,10 @@ void benchmark_winograd_nchw_vs_nchw44( | |||||
} | } | ||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) { | |||||
#if MEGDNN_AARCH64 | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"AARCH64_F32_MK4_4x16:4:2", "AARCH64_F32_MK4_4x16:4:2", handle()); | |||||
#else | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"ARMV7_F32_MK4_4x8:4:2", "ARMV7_F32_MK4_4x8:4:2", handle()); | |||||
#endif | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { | |||||
#if MEGDNN_AARCH64 | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"AARCH64_F32_MK4_4x16:4:6", "AARCH64_F32_MK4_4x16:4:6", handle()); | |||||
#else | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:6", handle()); | |||||
#endif | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F73_MK4_NCHW_VS_NCHW44) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F73_MK4_NCHW_VS_NCHW44) { | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
benchmark_winograd_nchw_vs_nchw44( | benchmark_winograd_nchw_vs_nchw44( | ||||
"AARCH64_F32_MK4_4x16:4:6", "ARM_COMMON_F32_GEMV_MK4:4:7", handle()); | |||||
"AARCH64_F32_MK4_4x16:4:6", "FB_GI_F32_GEMV_MK4:4:7", handle()); | |||||
#else | #else | ||||
benchmark_winograd_nchw_vs_nchw44( | benchmark_winograd_nchw_vs_nchw44( | ||||
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:7", handle()); | "ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:7", handle()); | ||||
@@ -1609,156 +1556,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2) { | |||||
computations / used0, used1, computations / used1, used1 / used0); | computations / used0, used1, computations / used1, used1 / used0); | ||||
} | } | ||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) { | |||||
// have to remove preferred restrict in usable func before run the benchmark | |||||
using namespace conv_bias; | |||||
param::ConvBias param; | |||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.nonlineMode = NonlineMode::RELU; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
constexpr size_t RUN = 50; | |||||
Benchmarker<ConvBias> benchmark0(handle()); | |||||
benchmark0.set_display(false); | |||||
benchmark0.set_param(param); | |||||
benchmark0.set_times(RUN); | |||||
benchmark0.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1")); | |||||
auto opr = handle()->create_operator<ConvBias>(); | |||||
opr->param() = param; | |||||
param.format = param::ConvBias::Format::NCHW44; | |||||
Benchmarker<ConvBias> benchmark1(handle()); | |||||
benchmark1.set_display(false); | |||||
benchmark1.set_param(param); | |||||
benchmark1.set_times(RUN); | |||||
benchmark1.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44")); | |||||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { | |||||
TensorLayout dst_layout; | |||||
opr->deduce_layout( | |||||
{{1, group * 4, h, w}, dtype::Int8()}, | |||||
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, | |||||
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); | |||||
//! dst.nr_elems * IC * FH * FW * 2 | |||||
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
auto used0 = benchmark0.exec( | |||||
{{1, group * 4, h, w}, | |||||
{group * 4, 1, 1, kernel, kernel}, | |||||
{1, group * 4, 1, 1}, | |||||
{}, | |||||
{}}) / | |||||
RUN; | |||||
auto used1 = benchmark1.exec( | |||||
{{1, group, h, w, 4}, | |||||
{group, 1, 1, kernel, kernel, 4}, | |||||
{1, group, 1, 1, 4}, | |||||
{}, | |||||
{}}) / | |||||
RUN; | |||||
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " | |||||
"nchw44: " | |||||
"%f ms %f GFlops " | |||||
"speedup: %f\n", | |||||
group, h, w, kernel, used0, computations / used0, used1, | |||||
computations / used1, used0 / used1); | |||||
}; | |||||
for (size_t group : {8, 16, 32, 64}) { | |||||
for (size_t kerenl : {2, 3, 5}) { | |||||
run(group, 112, 112, kerenl); | |||||
run(group, 56, 56, kerenl); | |||||
run(group, 48, 48, kerenl); | |||||
run(group, 28, 28, kerenl); | |||||
run(group, 14, 14, kerenl); | |||||
} | |||||
} | |||||
run(8, 112, 112, 3); | |||||
run(32, 56, 56, 3); | |||||
run(64, 28, 28, 3); | |||||
run(128, 14, 14, 3); | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) { | |||||
// have to remove preferred restrict in usable func before run the benchmark | |||||
using namespace conv_bias; | |||||
param::ConvBias param; | |||||
param.stride_h = 2; | |||||
param.stride_w = 2; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.nonlineMode = NonlineMode::RELU; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
constexpr size_t RUN = 50; | |||||
Benchmarker<ConvBias> benchmark0(handle()); | |||||
benchmark0.set_display(false); | |||||
benchmark0.set_param(param); | |||||
benchmark0.set_times(RUN); | |||||
benchmark0.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2")); | |||||
auto opr = handle()->create_operator<ConvBias>(); | |||||
opr->param() = param; | |||||
param.format = param::ConvBias::Format::NCHW44; | |||||
Benchmarker<ConvBias> benchmark1(handle()); | |||||
benchmark1.set_display(false); | |||||
benchmark1.set_param(param); | |||||
benchmark1.set_times(RUN); | |||||
benchmark1.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44")); | |||||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { | |||||
TensorLayout dst_layout; | |||||
opr->deduce_layout( | |||||
{{1, group * 4, h, w}, dtype::Int8()}, | |||||
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, | |||||
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); | |||||
//! dst.nr_elems * IC * FH * FW * 2 | |||||
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
auto used0 = benchmark0.exec( | |||||
{{1, group * 4, h, w}, | |||||
{group * 4, 1, 1, kernel, kernel}, | |||||
{1, group * 4, 1, 1}, | |||||
{}, | |||||
{}}) / | |||||
RUN; | |||||
auto used1 = benchmark1.exec( | |||||
{{1, group, h, w, 4}, | |||||
{group, 1, 1, kernel, kernel, 4}, | |||||
{1, group, 1, 1, 4}, | |||||
{}, | |||||
{}}) / | |||||
RUN; | |||||
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " | |||||
"nchw44: " | |||||
"%f ms %f GFlops " | |||||
"speedup: %f\n", | |||||
group, h, w, kernel, used0, computations / used0, used1, | |||||
computations / used1, used0 / used1); | |||||
}; | |||||
for (size_t group : {8, 16, 32, 64}) { | |||||
for (size_t kerenl : {2, 3, 5}) { | |||||
run(group, 112, 112, kerenl); | |||||
run(group, 56, 56, kerenl); | |||||
run(group, 48, 48, kerenl); | |||||
run(group, 28, 28, kerenl); | |||||
run(group, 14, 14, kerenl); | |||||
} | |||||
} | |||||
run(8, 112, 112, 3); | |||||
run(32, 56, 56, 3); | |||||
run(64, 28, 28, 3); | |||||
run(128, 14, 14, 3); | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { | TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { | ||||
// have to remove preferred restrict in usable func before run the benchmark | // have to remove preferred restrict in usable func before run the benchmark | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -303,84 +303,6 @@ void checker_conv_bias_int8x8x32_multi( | |||||
} | } | ||||
} | } | ||||
/**********************************F32 direct************************/ | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) { | |||||
check_conv_bias( | |||||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(), | |||||
"F32DIRECT"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { | |||||
//! k=7 s=1 | |||||
check_conv_bias( | |||||
get_nchw44_conv_bias_args({7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { | |||||
check_conv_bias( | |||||
get_nchw44_conv_bias_args({2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { | |||||
check_conv_bias( | |||||
get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), handle(), | |||||
"F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | |||||
check_conv_bias( | |||||
get_nchw44_conv_bias_args({2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) { | |||||
check_conv_bias( | |||||
get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), handle(), | |||||
"F32STRD1"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { | |||||
check_conv_bias( | |||||
get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), handle(), | |||||
"F32STRD2"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) { | |||||
check_conv_bias( | |||||
get_nchw44_conv_bias_args( | |||||
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false, | |||||
true), | |||||
handle(), "F32_CONV_NCHW_NCHW44"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) { | |||||
check_conv_bias( | |||||
get_nchw44_conv_bias_args( | |||||
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false, | |||||
true), | |||||
handle(), "F32_CONV_NCHW_NCHW44"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) { | |||||
check_conv_bias( | |||||
get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(), | |||||
"F32_CHANNEL_WISE_NCHW44"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) { | |||||
check_conv_bias( | |||||
get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(), | |||||
"F32_CHANNEL_WISE_NCHW44"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { | |||||
check_conv_bias( | |||||
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(), | |||||
"F32_CHANNEL_WISE_NCHW44"); | |||||
} | |||||
/**********************************F16 direct************************/ | /**********************************F16 direct************************/ | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) { | ||||
@@ -787,50 +709,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) { | |||||
#endif | #endif | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = | |||||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd( | |||||
"4:2:32", checker, args, param::MatrixMul::Format::MK4, | |||||
param::ConvBias::Format::NCHW44); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_args(3); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd("1:6:32", checker, args); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = | |||||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd( | |||||
"4:6:16", checker, args, param::MatrixMul::Format::MK4, | |||||
param::ConvBias::Format::NCHW44); | |||||
} | |||||
//! uncomment it when low precision mode is ok | //! uncomment it when low precision mode is ok | ||||
#if 0 | #if 0 | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { | ||||
@@ -853,22 +731,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCE | |||||
} | } | ||||
#endif | #endif | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_args(4); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd("1:5:32", checker, args); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_args(5); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd("1:4:32", checker, args); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -81,207 +81,6 @@ void benchmark_impl( | |||||
} | } | ||||
} // namespace | } // namespace | ||||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||||
constexpr size_t RUNS = 50; | |||||
param::ConvBias param; | |||||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||||
size_t group) { | |||||
SmallVector<TensorShape> shapes{ | |||||
{N, IC, H, W}, | |||||
{group, OC / group, IC / group, FS, FS}, | |||||
{1, OC, 1, 1}, | |||||
{}, | |||||
{N, OC, H, W}}; | |||||
TensorShape dst{N, OC, H, W}; | |||||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||||
dst.total_nr_elems()) * | |||||
1e-6; | |||||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||||
}; | |||||
bench_case(1, 32, 32, 200, 200, 3, 4); | |||||
bench_case(1, 32, 32, 200, 200, 3, 32); | |||||
bench_case(1, 32, 32, 128, 128, 3, 4); | |||||
bench_case(1, 32, 32, 128, 128, 3, 32); | |||||
bench_case(1, 32, 32, 100, 100, 3, 4); | |||||
bench_case(1, 32, 32, 100, 100, 3, 32); | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||||
std::string algo_name = "F32DIRECT"; | |||||
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); | |||||
std::vector<DType> data_type = { | |||||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
shapes_and_computation.clear(); | |||||
algo_name = "F32DIRECT"; | |||||
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); | |||||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
} | |||||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { | |||||
constexpr size_t RUNS = 50; | |||||
param::ConvBias param; | |||||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||||
size_t group) { | |||||
SmallVector<TensorShape> shapes{ | |||||
{N, IC, H, W}, | |||||
{group, OC / group, IC / group, FS, FS}, | |||||
{1, OC, 1, 1}, | |||||
{}, | |||||
{N, OC, H, W}}; | |||||
TensorShape dst{N, OC, H, W}; | |||||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||||
dst.total_nr_elems()) * | |||||
1e-6; | |||||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||||
}; | |||||
bench_case(1, 32, 32, 200, 200, 3, 4); | |||||
bench_case(1, 32, 32, 200, 200, 3, 32); | |||||
bench_case(1, 32, 32, 128, 128, 3, 4); | |||||
bench_case(1, 32, 32, 128, 128, 3, 32); | |||||
bench_case(1, 32, 32, 100, 100, 3, 4); | |||||
bench_case(1, 32, 32, 100, 100, 3, 32); | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||||
std::string algo_name = "F32STRD1"; | |||||
printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); | |||||
std::vector<DType> data_type = { | |||||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
shapes_and_computation.clear(); | |||||
algo_name = "F32STRD1"; | |||||
printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); | |||||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
} | |||||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { | |||||
constexpr size_t RUNS = 50; | |||||
param::ConvBias param; | |||||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.stride_h = 2; | |||||
param.stride_w = 2; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||||
size_t group, size_t P, size_t S) { | |||||
SmallVector<TensorShape> shapes{ | |||||
{N, IC, H, W}, | |||||
{group, OC / group, IC / group, FS, FS}, | |||||
{1, OC, 1, 1}, | |||||
{}, | |||||
{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; | |||||
TensorShape dst{N, OC, H, W}; | |||||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||||
dst.total_nr_elems()) * | |||||
1e-6; | |||||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||||
}; | |||||
bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); | |||||
bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); | |||||
bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); | |||||
bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); | |||||
bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); | |||||
bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||||
std::string algo_name = "F32STRD2"; | |||||
printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); | |||||
std::vector<DType> data_type = { | |||||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
shapes_and_computation.clear(); | |||||
algo_name = "F32STRD2"; | |||||
printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); | |||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||||
bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); | |||||
bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { | ||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
@@ -20,91 +20,7 @@ | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace test; | using namespace test; | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = | |||||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
check_winograd( | |||||
"4:2:32", checker, args, param::MatrixMul::Format::MK4, | |||||
param::ConvBias::Format::NCHW44); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_args(3); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
check_winograd("1:6:32", checker, args); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = | |||||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
check_winograd( | |||||
"4:6:16", checker, args, param::MatrixMul::Format::MK4, | |||||
param::ConvBias::Format::NCHW44); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_args(4); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
check_winograd("1:5:32", checker, args); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_args(5); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
check_winograd("1:4:32", checker, args); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> nchw44_args = | |||||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
auto run = [&checker]( | |||||
const std::vector<TestArg>& args, DType A_dtype, DType B_dtype, | |||||
DType C_dtype, DType D_dtype, const float eps) { | |||||
for (auto&& arg : args) { | |||||
checker.set_dtype(0, A_dtype) | |||||
.set_dtype(1, B_dtype) | |||||
.set_dtype(2, C_dtype) | |||||
.set_dtype(4, D_dtype) | |||||
.set_epsilon(eps) | |||||
.set_param(arg.param) | |||||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
} | |||||
}; | |||||
//! uncomment this when low precision mode is ok | |||||
// run(handle(), nchw44_args, {2, 6, 7}, dtype::Float32(), dtype::Float32(), | |||||
// dtype::Float32(), dtype::Float32(), 1e-2f); | |||||
//! remove this when low precision mode is ok | |||||
run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(), | |||||
dtype::Float32(), 1e-3f); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1_WEIGHT_PREPROCESS) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1_WEIGHT_PREPROCESS) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -286,30 +286,6 @@ TEST_F(ARM_COMMON, FP32_GEVM) { | |||||
run(M, K, N); | run(M, K, N); | ||||
} | } | ||||
TEST_F(ARM_COMMON, FP32_GEMV_MK4) { | |||||
Checker<MatrixMul> checker(handle()); | |||||
using Param = MatrixMul::Param; | |||||
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV_MK4")); | |||||
checker.set_epsilon(1e-2); | |||||
auto run = [&](size_t M, size_t K) { | |||||
Param param; | |||||
param.format = param::MatrixMul::Format::MK4; | |||||
param.transposeA = false; | |||||
param.transposeB = false; | |||||
TensorShape A, B; | |||||
A = TensorShape{M / 4, K / 4, 4, 4}; | |||||
B = TensorShape{K / 4, 1, 4}; | |||||
checker.set_param(param).execs({A, B, {}}); | |||||
}; | |||||
// N = 1 | |||||
for (size_t M : {4, 16, 128, 1024}) | |||||
for (size_t K : {4, 8, 12, 128, 256, 4096}) | |||||
run(M, K); | |||||
} | |||||
TEST_F(ARM_COMMON, MATRIX_MUL_RECORD) { | TEST_F(ARM_COMMON, MATRIX_MUL_RECORD) { | ||||
TaskRecordChecker<MatrixMul> checker(0); | TaskRecordChecker<MatrixMul> checker(0); | ||||
checker.set_epsilon(1e-2); | checker.set_epsilon(1e-2); | ||||
@@ -117,6 +117,30 @@ TEST_F(FALLBACK, CONV_BIAS_FORWARD_RECORD) { | |||||
} | } | ||||
} | } | ||||
TEST_F(FALLBACK, FP32_GEMV_MK4_GI) { | |||||
Checker<MatrixMul> checker(handle()); | |||||
using Param = MatrixMul::Param; | |||||
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_GI_F32_GEMV_MK4")); | |||||
checker.set_epsilon(1e-2); | |||||
auto run = [&](size_t M, size_t K) { | |||||
Param param; | |||||
param.format = param::MatrixMul::Format::MK4; | |||||
param.transposeA = false; | |||||
param.transposeB = false; | |||||
TensorShape A, B; | |||||
A = TensorShape{M / 4, K / 4, 4, 4}; | |||||
B = TensorShape{K / 4, 1, 4}; | |||||
checker.set_param(param).execs({A, B, {}}); | |||||
}; | |||||
// N = 1 | |||||
for (size_t M : {4, 16, 128, 1024}) | |||||
for (size_t K : {4, 8, 12, 128, 256, 4096}) | |||||
run(M, K); | |||||
} | |||||
std::vector<conv_bias::TestArg> get_conv_bias_args( | std::vector<conv_bias::TestArg> get_conv_bias_args( | ||||
std::vector<size_t> kernel, std::vector<size_t> padv, | std::vector<size_t> kernel, std::vector<size_t> padv, | ||||
std::vector<param::ConvBias::NonlineMode> nlmodev, std::vector<size_t> stridev, | std::vector<param::ConvBias::NonlineMode> nlmodev, std::vector<size_t> stridev, | ||||
@@ -257,6 +281,189 @@ TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { | |||||
dtype::Float32{}, dtype::Float32{}, "FALLBACK_NAIVE"); | dtype::Float32{}, dtype::Float32{}, "FALLBACK_NAIVE"); | ||||
} | } | ||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_NCHW_NCHW44_F32_S2) { | |||||
check_conv_bias( | |||||
conv_bias::get_nchw44_conv_bias_args( | |||||
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false, | |||||
true), | |||||
handle(), "F32_CONV_NCHW_NCHW44"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_NCHW_NCHW44_F32_S1) { | |||||
check_conv_bias( | |||||
conv_bias::get_nchw44_conv_bias_args( | |||||
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false, | |||||
true), | |||||
handle(), "F32_CONV_NCHW_NCHW44"); | |||||
} | |||||
std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args( | |||||
std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode, | |||||
bool no_full_bias) { | |||||
using namespace conv_bias; | |||||
using Param = param::ConvBias; | |||||
using NLMode = param::ConvBias::NonlineMode; | |||||
std::vector<TestArg> args; | |||||
auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel, | |||||
size_t stride, NLMode nlmode, bool pad) { | |||||
Param param; | |||||
param.stride_h = stride; | |||||
param.stride_w = stride; | |||||
if (pad) { | |||||
param.pad_h = kernel / 2; | |||||
param.pad_w = kernel / 2; | |||||
} else { | |||||
param.pad_h = 0; | |||||
param.pad_w = 0; | |||||
} | |||||
param.nonlineMode = nlmode; | |||||
param.format = param::ConvBias::Format::NCHW44; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
args.emplace_back( | |||||
param, TensorShape{n, group, h, w, 4}, | |||||
TensorShape{group, 1, 1, kernel, kernel, 4}, TensorShape{}); | |||||
if (!no_bias) { | |||||
args.emplace_back( | |||||
param, TensorShape{n, group, h, w, 4}, | |||||
TensorShape{group, 1, 1, kernel, kernel, 4}, | |||||
TensorShape{1, group, 1, 1, 4}); | |||||
} | |||||
if (!no_full_bias) { | |||||
args.emplace_back( | |||||
param, TensorShape{n, group, h, w, 4}, | |||||
TensorShape{group, 1, 1, kernel, kernel, 4}, | |||||
TensorShape{ | |||||
n, group, (h + 2 * param.pad_w - kernel) / stride + 1, | |||||
(w + 2 * param.pad_w - kernel) / stride + 1, 4}); | |||||
} | |||||
}; | |||||
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY}; | |||||
if (!no_nonlinemode) { | |||||
nonlinemode.emplace_back(NLMode::RELU); | |||||
nonlinemode.emplace_back(NLMode::H_SWISH); | |||||
} | |||||
for (size_t n : {1, 2}) { | |||||
for (auto nlmode : nonlinemode) { | |||||
for (bool pad : {true}) { | |||||
for (size_t group : {1, 2, 4, 7, 16}) { | |||||
for (size_t size : {4, 6, 7, 9, 20}) { | |||||
for (size_t kern : kernel) { | |||||
pack(n, group, size, size, kern, stride, nlmode, pad); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
for (bool pad : {false}) { | |||||
for (size_t group : {1, 2, 7, 16}) { | |||||
for (size_t size : {7, 9, 20}) { | |||||
for (size_t kern : kernel) { | |||||
pack(n, group, size, size, kern, stride, nlmode, pad); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return args; | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) { | |||||
check_conv_bias( | |||||
get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(), | |||||
"F32_CHANNEL_WISE_NCHW44"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) { | |||||
check_conv_bias( | |||||
get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(), | |||||
"F32_CHANNEL_WISE_NCHW44"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { | |||||
check_conv_bias( | |||||
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(), | |||||
"F32_CHANNEL_WISE_NCHW44"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K7) { | |||||
//! k=7 s=1 | |||||
check_conv_bias( | |||||
conv_bias::get_nchw44_conv_bias_args( | |||||
{7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K2K3) { | |||||
check_conv_bias( | |||||
conv_bias::get_nchw44_conv_bias_args( | |||||
{2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K5) { | |||||
check_conv_bias( | |||||
conv_bias::get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S2) { | |||||
check_conv_bias( | |||||
conv_bias::get_nchw44_conv_bias_args( | |||||
{2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2), | |||||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32) { | |||||
check_conv_bias( | |||||
conv_bias::get_conv_bias_args( | |||||
{1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | |||||
handle(), "F32DIRECT"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_STR2) { | |||||
check_conv_bias( | |||||
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||||
handle(), "F32STRD2"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_STR1) { | |||||
check_conv_bias( | |||||
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | |||||
handle(), "F32STRD1"); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_PREPROCESS_NCHW44) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> nchw44_args = conv_bias::get_nchw44_conv_bias_args( | |||||
{3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
auto run = [&checker]( | |||||
const std::vector<TestArg>& args, DType A_dtype, DType B_dtype, | |||||
DType C_dtype, DType D_dtype, const float eps) { | |||||
for (auto&& arg : args) { | |||||
checker.set_dtype(0, A_dtype) | |||||
.set_dtype(1, B_dtype) | |||||
.set_dtype(2, C_dtype) | |||||
.set_dtype(4, D_dtype) | |||||
.set_epsilon(eps) | |||||
.set_param(arg.param) | |||||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
} | |||||
}; | |||||
//! uncomment this when low precision mode is ok | |||||
// run(handle(), nchw44_args, {2, 6, 7}, dtype::Float32(), dtype::Float32(), | |||||
// dtype::Float32(), dtype::Float32(), 1e-2f); | |||||
//! remove this when low precision mode is ok | |||||
run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(), | |||||
dtype::Float32(), 1e-3f); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) { | TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
param::ConvBias cur_param; | param::ConvBias cur_param; | ||||
@@ -273,6 +480,422 @@ TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) { | |||||
} | } | ||||
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
namespace { | |||||
void benchmark_impl( | |||||
const param::ConvBias param, | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>>& shapes_and_computation, | |||||
const std::string algo_name, size_t RUNS, | |||||
TaskExecutorConfig&& multi_thread_config, | |||||
TaskExecutorConfig&& single_thread_config, std::vector<DType>& data_type) { | |||||
std::vector<float> multi_thread_times, single_thread_times; | |||||
{ | |||||
auto multi_thread_hanle = create_cpu_handle(0, true, &multi_thread_config); | |||||
auto benchmarker = Benchmarker<ConvBias>(multi_thread_hanle.get()); | |||||
benchmarker.set_times(RUNS) | |||||
.set_display(false) | |||||
.set_param(param) | |||||
.set_dtype(0, data_type[0]) | |||||
.set_dtype(1, data_type[1]) | |||||
.set_dtype(2, data_type[2]) | |||||
.set_dtype(4, data_type[3]) | |||||
.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str())); | |||||
for (auto shape : shapes_and_computation) { | |||||
multi_thread_times.push_back(benchmarker.exec(shape.first) / RUNS); | |||||
} | |||||
} | |||||
{ | |||||
auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config); | |||||
auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get()); | |||||
benchmarker.set_times(RUNS) | |||||
.set_display(false) | |||||
.set_param(param) | |||||
.set_dtype(0, data_type[0]) | |||||
.set_dtype(1, data_type[1]) | |||||
.set_dtype(2, data_type[2]) | |||||
.set_dtype(4, data_type[3]) | |||||
.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str())); | |||||
for (auto shape : shapes_and_computation) { | |||||
single_thread_times.push_back(benchmarker.exec(shape.first) / RUNS); | |||||
} | |||||
} | |||||
printf("Benchmark : Multi threads %zu, ", multi_thread_config.nr_thread); | |||||
printf("core_ids:"); | |||||
for (size_t i = 0; i < multi_thread_config.affinity_core_set.size(); i++) { | |||||
printf("%zu ", multi_thread_config.affinity_core_set[i]); | |||||
} | |||||
printf(", Single thread core_id %zu\n", single_thread_config.affinity_core_set[0]); | |||||
for (size_t i = 0; i < shapes_and_computation.size(); i++) { | |||||
auto shapes = shapes_and_computation[i]; | |||||
printf("Bench case: "); | |||||
for (auto&& shape : shapes.first) { | |||||
printf("%s ", shape.to_string().c_str()); | |||||
} | |||||
float computations = shapes.second; | |||||
printf("%zu threads gflops: %f,\n single thread gflops: " | |||||
"%f. spead up = %f, speedup/cores=%f\n", | |||||
multi_thread_config.nr_thread, computations / multi_thread_times[i], | |||||
computations / single_thread_times[i], | |||||
single_thread_times[i] / multi_thread_times[i], | |||||
single_thread_times[i] / multi_thread_times[i] / | |||||
multi_thread_config.nr_thread); | |||||
} | |||||
} | |||||
} // namespace | |||||
TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32) { | |||||
constexpr size_t RUNS = 50; | |||||
param::ConvBias param; | |||||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||||
size_t group) { | |||||
SmallVector<TensorShape> shapes{ | |||||
{N, IC, H, W}, | |||||
{group, OC / group, IC / group, FS, FS}, | |||||
{1, OC, 1, 1}, | |||||
{}, | |||||
{N, OC, H, W}}; | |||||
TensorShape dst{N, OC, H, W}; | |||||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||||
dst.total_nr_elems()) * | |||||
1e-6; | |||||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||||
}; | |||||
bench_case(1, 32, 32, 200, 200, 3, 4); | |||||
bench_case(1, 32, 32, 200, 200, 3, 32); | |||||
bench_case(1, 32, 32, 128, 128, 3, 4); | |||||
bench_case(1, 32, 32, 128, 128, 3, 32); | |||||
bench_case(1, 32, 32, 100, 100, 3, 4); | |||||
bench_case(1, 32, 32, 100, 100, 3, 32); | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||||
std::string algo_name = "F32DIRECT"; | |||||
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); | |||||
std::vector<DType> data_type = { | |||||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
shapes_and_computation.clear(); | |||||
algo_name = "F32DIRECT"; | |||||
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); | |||||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32_STR1) { | |||||
constexpr size_t RUNS = 50; | |||||
param::ConvBias param; | |||||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||||
size_t group) { | |||||
SmallVector<TensorShape> shapes{ | |||||
{N, IC, H, W}, | |||||
{group, OC / group, IC / group, FS, FS}, | |||||
{1, OC, 1, 1}, | |||||
{}, | |||||
{N, OC, H, W}}; | |||||
TensorShape dst{N, OC, H, W}; | |||||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||||
dst.total_nr_elems()) * | |||||
1e-6; | |||||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||||
}; | |||||
bench_case(1, 32, 32, 200, 200, 3, 4); | |||||
bench_case(1, 32, 32, 200, 200, 3, 32); | |||||
bench_case(1, 32, 32, 128, 128, 3, 4); | |||||
bench_case(1, 32, 32, 128, 128, 3, 32); | |||||
bench_case(1, 32, 32, 100, 100, 3, 4); | |||||
bench_case(1, 32, 32, 100, 100, 3, 32); | |||||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||||
std::string algo_name = "F32STRD1"; | |||||
printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); | |||||
std::vector<DType> data_type = { | |||||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
shapes_and_computation.clear(); | |||||
algo_name = "F32STRD1"; | |||||
printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); | |||||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32_STR2) { | |||||
constexpr size_t RUNS = 50; | |||||
param::ConvBias param; | |||||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.stride_h = 2; | |||||
param.stride_w = 2; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||||
size_t group, size_t P, size_t S) { | |||||
SmallVector<TensorShape> shapes{ | |||||
{N, IC, H, W}, | |||||
{group, OC / group, IC / group, FS, FS}, | |||||
{1, OC, 1, 1}, | |||||
{}, | |||||
{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; | |||||
TensorShape dst{N, OC, H, W}; | |||||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||||
dst.total_nr_elems()) * | |||||
1e-6; | |||||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||||
}; | |||||
bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); | |||||
bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); | |||||
bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); | |||||
bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); | |||||
bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); | |||||
bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); | |||||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||||
std::string algo_name = "F32STRD2"; | |||||
printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); | |||||
std::vector<DType> data_type = { | |||||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
shapes_and_computation.clear(); | |||||
algo_name = "F32STRD2"; | |||||
printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); | |||||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||||
bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); | |||||
bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||||
data_type); | |||||
benchmark_impl( | |||||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||||
data_type); | |||||
} | |||||
TEST_F(FALLBACK, BENCHMARK_GI_CHANNEL_WISE_F32_STRIDE1_NCHW44) { | |||||
// have to remove preferred restrict in usable func before run the benchmark | |||||
using namespace conv_bias; | |||||
param::ConvBias param; | |||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.nonlineMode = NonlineMode::RELU; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
constexpr size_t RUN = 50; | |||||
Benchmarker<ConvBias> benchmark0(handle()); | |||||
benchmark0.set_display(false); | |||||
benchmark0.set_param(param); | |||||
benchmark0.set_times(RUN); | |||||
benchmark0.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1")); | |||||
auto opr = handle()->create_operator<ConvBias>(); | |||||
opr->param() = param; | |||||
param.format = param::ConvBias::Format::NCHW44; | |||||
Benchmarker<ConvBias> benchmark1(handle()); | |||||
benchmark1.set_display(false); | |||||
benchmark1.set_param(param); | |||||
benchmark1.set_times(RUN); | |||||
benchmark1.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44")); | |||||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { | |||||
TensorLayout dst_layout; | |||||
opr->deduce_layout( | |||||
{{1, group * 4, h, w}, dtype::Int8()}, | |||||
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, | |||||
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); | |||||
//! dst.nr_elems * IC * FH * FW * 2 | |||||
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
auto used0 = benchmark0.exec( | |||||
{{1, group * 4, h, w}, | |||||
{group * 4, 1, 1, kernel, kernel}, | |||||
{1, group * 4, 1, 1}, | |||||
{}, | |||||
{}}) / | |||||
RUN; | |||||
auto used1 = benchmark1.exec( | |||||
{{1, group, h, w, 4}, | |||||
{group, 1, 1, kernel, kernel, 4}, | |||||
{1, group, 1, 1, 4}, | |||||
{}, | |||||
{}}) / | |||||
RUN; | |||||
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " | |||||
"nchw44: " | |||||
"%f ms %f GFlops " | |||||
"speedup: %f\n", | |||||
group, h, w, kernel, used0, computations / used0, used1, | |||||
computations / used1, used0 / used1); | |||||
}; | |||||
for (size_t group : {8, 16, 32, 64}) { | |||||
for (size_t kerenl : {2, 3, 5}) { | |||||
run(group, 112, 112, kerenl); | |||||
run(group, 56, 56, kerenl); | |||||
run(group, 48, 48, kerenl); | |||||
run(group, 28, 28, kerenl); | |||||
run(group, 14, 14, kerenl); | |||||
} | |||||
} | |||||
run(8, 112, 112, 3); | |||||
run(32, 56, 56, 3); | |||||
run(64, 28, 28, 3); | |||||
run(128, 14, 14, 3); | |||||
} | |||||
TEST_F(FALLBACK, BENCHMARK_GI_CHANNEL_WISE_F32_STRIDE2_NCHW44) { | |||||
// have to remove preferred restrict in usable func before run the benchmark | |||||
using namespace conv_bias; | |||||
param::ConvBias param; | |||||
param.stride_h = 2; | |||||
param.stride_w = 2; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.nonlineMode = NonlineMode::RELU; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
constexpr size_t RUN = 50; | |||||
Benchmarker<ConvBias> benchmark0(handle()); | |||||
benchmark0.set_display(false); | |||||
benchmark0.set_param(param); | |||||
benchmark0.set_times(RUN); | |||||
benchmark0.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2")); | |||||
auto opr = handle()->create_operator<ConvBias>(); | |||||
opr->param() = param; | |||||
param.format = param::ConvBias::Format::NCHW44; | |||||
Benchmarker<ConvBias> benchmark1(handle()); | |||||
benchmark1.set_display(false); | |||||
benchmark1.set_param(param); | |||||
benchmark1.set_times(RUN); | |||||
benchmark1.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44")); | |||||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { | |||||
TensorLayout dst_layout; | |||||
opr->deduce_layout( | |||||
{{1, group * 4, h, w}, dtype::Int8()}, | |||||
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, | |||||
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); | |||||
//! dst.nr_elems * IC * FH * FW * 2 | |||||
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
auto used0 = benchmark0.exec( | |||||
{{1, group * 4, h, w}, | |||||
{group * 4, 1, 1, kernel, kernel}, | |||||
{1, group * 4, 1, 1}, | |||||
{}, | |||||
{}}) / | |||||
RUN; | |||||
auto used1 = benchmark1.exec( | |||||
{{1, group, h, w, 4}, | |||||
{group, 1, 1, kernel, kernel, 4}, | |||||
{1, group, 1, 1, 4}, | |||||
{}, | |||||
{}}) / | |||||
RUN; | |||||
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " | |||||
"nchw44: " | |||||
"%f ms %f GFlops " | |||||
"speedup: %f\n", | |||||
group, h, w, kernel, used0, computations / used0, used1, | |||||
computations / used1, used0 / used1); | |||||
}; | |||||
for (size_t group : {8, 16, 32, 64}) { | |||||
for (size_t kerenl : {2, 3, 5}) { | |||||
run(group, 112, 112, kerenl); | |||||
run(group, 56, 56, kerenl); | |||||
run(group, 48, 48, kerenl); | |||||
run(group, 28, 28, kerenl); | |||||
run(group, 14, 14, kerenl); | |||||
} | |||||
} | |||||
run(8, 112, 112, 3); | |||||
run(32, 56, 56, 3); | |||||
run(64, 28, 28, 3); | |||||
run(128, 14, 14, 3); | |||||
} | |||||
TEST_F(FALLBACK, BENCHMARK_CONVBIAS) { | TEST_F(FALLBACK, BENCHMARK_CONVBIAS) { | ||||
constexpr size_t RUNS = 10; | constexpr size_t RUNS = 10; | ||||
param::ConvBias param; | param::ConvBias param; | ||||
@@ -320,6 +943,164 @@ TEST_F(FALLBACK, BENCHMARK_CONVBIAS) { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F23_4x4) { | |||||
#if MEGDNN_AARCH64 | |||||
conv_bias::benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4); | |||||
#elif MEGDNN_ARMV7 | |||||
conv_bias::benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4); | |||||
#else | |||||
conv_bias::benchmark_winograd("WINOGRAD:FB_GI_F32_MK4_4x8:4:2", handle(), 3, 4); | |||||
#endif | |||||
} | |||||
void benchmark_winograd_nchw_vs_nchw44( | |||||
const char* algo_name0, const char* algo_name1, Handle* handle) { | |||||
using namespace conv_bias; | |||||
using NLMode = param::ConvBias::NonlineMode; | |||||
std::vector<conv_bias::TestArg> args_nchw44; | |||||
std::vector<conv_bias::TestArg> args_nchw; | |||||
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t group, | |||||
NLMode nlmode) { | |||||
param::ConvBias param; | |||||
param.format = param::ConvBias::Format::NCHW44; | |||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
param.nonlineMode = nlmode; | |||||
if (group == 1) { | |||||
param.sparse = param::ConvBias::Sparse::DENSE; | |||||
args_nchw44.emplace_back( | |||||
param, TensorShape{n, ic / 4, h, w, 4}, | |||||
TensorShape{oc / 4, ic / 4, 3, 3, 4, 4}, TensorShape{}); | |||||
param.format = param::ConvBias::Format::NCHW; | |||||
args_nchw.emplace_back( | |||||
param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, 3, 3}, | |||||
TensorShape{}); | |||||
} else { | |||||
auto oc_per_group = oc / group; | |||||
auto ic_per_group = ic / group; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | |||||
args_nchw44.emplace_back( | |||||
param, TensorShape{n, ic_per_group / 4, h, w, 4}, | |||||
TensorShape{group, oc_per_group / 4, ic_per_group / 4, 3, 3, 4, 4}, | |||||
TensorShape{}); | |||||
param.format = param::ConvBias::Format::NCHW; | |||||
args_nchw.emplace_back( | |||||
param, TensorShape{n, ic, h, w}, | |||||
TensorShape{group, oc_per_group, ic_per_group, 3, 3}, | |||||
TensorShape{}); | |||||
} | |||||
}; | |||||
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY}; | |||||
for (auto nlmode : nonlinemode) | |||||
for (size_t n : {1}) | |||||
for (size_t group = 1; group <= 1; ++group) { | |||||
pack(n, 512, 512, 15, 15, group, nlmode); | |||||
pack(n, 512, 256, 15, 15, group, nlmode); | |||||
pack(n, 256, 256, 29, 29, group, nlmode); | |||||
pack(n, 256, 128, 29, 29, group, nlmode); | |||||
pack(n, 128, 128, 57, 57, group, nlmode); | |||||
pack(n, 128, 64, 57, 57, group, nlmode); | |||||
pack(n, 24, 24, 224, 224, group, nlmode); | |||||
pack(n, 64, 24, 123, 123, group, nlmode); | |||||
pack(n, 64, 64, 56, 56, group, nlmode); | |||||
pack(n, 128, 128, 28, 28, group, nlmode); | |||||
pack(n, 256, 256, 14, 14, group, nlmode); | |||||
pack(n, 512, 512, 7, 7, group, nlmode); | |||||
} | |||||
using namespace conv_bias; | |||||
constexpr size_t RUN = 10; | |||||
Benchmarker<ConvBias> benchmark_winograd_nchw(handle); | |||||
benchmark_winograd_nchw.set_display(false); | |||||
benchmark_winograd_nchw.set_times(RUN); | |||||
Benchmarker<ConvBias> benchmark_winograd_nchw44(handle); | |||||
benchmark_winograd_nchw44.set_display(false); | |||||
benchmark_winograd_nchw44.set_times(RUN); | |||||
std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name0); | |||||
std::string winograd_nchw44_algo_name = ssprintf("WINOGRAD_NCHW44:%s", algo_name1); | |||||
for (size_t i = 0; i < args_nchw.size(); ++i) { | |||||
auto arg_nchw = args_nchw[i]; | |||||
auto arg_nchw44 = args_nchw44[i]; | |||||
TensorLayout dst_layout; | |||||
auto opr = handle->create_operator<ConvBias>(); | |||||
opr->param() = arg_nchw.param; | |||||
opr->deduce_layout( | |||||
{arg_nchw.src, dtype::Float32()}, {arg_nchw.filter, dtype::Float32()}, | |||||
{arg_nchw.bias, dtype::Float32()}, {}, dst_layout); | |||||
//! dst.nr_elems * IC * FH * FW * 2 | |||||
float computations = dst_layout.total_nr_elems() * arg_nchw.filter[1] * | |||||
arg_nchw.filter[2] * arg_nchw.filter[3] * 2.0 / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
benchmark_winograd_nchw.set_param(arg_nchw.param); | |||||
auto nchw_used = algo_benchmark<ConvBias>( | |||||
benchmark_winograd_nchw, | |||||
{arg_nchw.src, arg_nchw.filter, {}, {}, {}}, | |||||
winograd_nchw_algo_name.c_str()) / | |||||
RUN; | |||||
benchmark_winograd_nchw44.set_param(arg_nchw44.param); | |||||
auto nchw44_used = algo_benchmark<ConvBias>( | |||||
benchmark_winograd_nchw44, | |||||
{arg_nchw44.src, arg_nchw44.filter, {}, {}, {}}, | |||||
winograd_nchw44_algo_name.c_str()) / | |||||
RUN; | |||||
printf("%s %s: nchw: %f ms %f Gflops nchw44: %f ms %f GFlops " | |||||
"speedup: " | |||||
"%f\n", | |||||
arg_nchw.src.to_string().c_str(), arg_nchw.filter.to_string().c_str(), | |||||
nchw_used, computations / nchw_used, nchw44_used, | |||||
computations / nchw44_used, nchw_used / nchw44_used); | |||||
} | |||||
} | |||||
TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) { | |||||
#if MEGDNN_AARCH64 | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"AARCH64_F32_MK4_4x16:4:2", "AARCH64_F32_MK4_4x16:4:2", handle()); | |||||
#elif MEGDNN_ARMV7 | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"ARMV7_F32_MK4_4x8:4:2", "ARMV7_F32_MK4_4x8:4:2", handle()); | |||||
#else | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"FB_GI_F32_MK4_4x8:4:2", "FB_GI_F32_MK4_4x8:4:2", handle()); | |||||
#endif | |||||
} | |||||
TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_4x4) { | |||||
#if MEGDNN_AARCH64 | |||||
conv_bias::benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4); | |||||
#elif MEGDNN_ARMV7 | |||||
conv_bias::benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4); | |||||
#else | |||||
conv_bias::benchmark_winograd("WINOGRAD:FB_GI_F32_MK4_4x8:4:6", handle(), 3, 4); | |||||
#endif | |||||
} | |||||
TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { | |||||
#if MEGDNN_AARCH64 | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"AARCH64_F32_MK4_4x16:4:6", "AARCH64_F32_MK4_4x16:4:6", handle()); | |||||
#elif MEGDNN_ARMV7 | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:6", handle()); | |||||
#else | |||||
benchmark_winograd_nchw_vs_nchw44( | |||||
"FB_GI_F32_MK4_4x8:4:6", "FB_GI_F32_MK4_4x8:4:6", handle()); | |||||
#endif | |||||
} | |||||
#endif | #endif | ||||
} // namespace test | } // namespace test | ||||