GitOrigin-RevId: ccf8b589be
release-1.10
@@ -1,725 +0,0 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include <algorithm> | |||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/arm_common/simd_macro/neon_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs1) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fp32; | |||
using namespace conv_stride1; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
void conv_stride1::do_conv_2x2_stride1( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
//! unroll of 2 | |||
size_t ic = 0; | |||
for (; ic + 1 < IC; ic += 2) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
const float* src_ptr1 = src_ptr + IW * IH; | |||
float* outptr = dst; | |||
const float* r00 = src_ptr; | |||
const float* r01 = src_ptr + IW; | |||
const float* r10 = src_ptr1; | |||
const float* r11 = src_ptr1 + IW; | |||
const float* k0 = filter + ic * 4; | |||
const float* k1 = k0 + 4; | |||
MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_LOADU(k1); | |||
rep(h, OH) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _r000 = MEGDNN_SIMD_LOADU(r00); | |||
MEGDNN_SIMD_TYPE _r010 = MEGDNN_SIMD_LOADU(r01); | |||
MEGDNN_SIMD_TYPE _r001 = MEGDNN_SIMD_LOADU(r00 + 1); | |||
MEGDNN_SIMD_TYPE _r011 = MEGDNN_SIMD_LOADU(r01 + 1); | |||
MEGDNN_SIMD_TYPE _r100 = MEGDNN_SIMD_LOADU(r10); | |||
MEGDNN_SIMD_TYPE _r110 = MEGDNN_SIMD_LOADU(r11); | |||
MEGDNN_SIMD_TYPE _r101 = MEGDNN_SIMD_LOADU(r10 + 1); | |||
MEGDNN_SIMD_TYPE _r111 = MEGDNN_SIMD_LOADU(r11 + 1); | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000, MEGDNN_SIMD_GET_LOW(_k0), 0); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001, MEGDNN_SIMD_GET_LOW(_k0), 1); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE( | |||
_sum, _r010, MEGDNN_SIMD_GET_HIGH(_k0), 0); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE( | |||
_sum, _r011, MEGDNN_SIMD_GET_HIGH(_k0), 1); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100, MEGDNN_SIMD_GET_LOW(_k1), 0); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101, MEGDNN_SIMD_GET_LOW(_k1), 1); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE( | |||
_sum, _r110, MEGDNN_SIMD_GET_HIGH(_k1), 0); | |||
_sum = MEGDNN_SIMD_VMLAQ_LANE( | |||
_sum, _r111, MEGDNN_SIMD_GET_HIGH(_k1), 1); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r00 += 4; | |||
r01 += 4; | |||
r10 += 4; | |||
r11 += 4; | |||
outptr += 4; | |||
} | |||
r00 += tail_step; | |||
r01 += tail_step; | |||
r10 += tail_step; | |||
r11 += tail_step; | |||
} | |||
} | |||
for (; ic < IC; ic++) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* k0 = filter + ic * 4; | |||
MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_SET1(k0[0]); | |||
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_SET1(k0[1]); | |||
MEGDNN_SIMD_TYPE _k2 = MEGDNN_SIMD_SET1(k0[2]); | |||
MEGDNN_SIMD_TYPE _k3 = MEGDNN_SIMD_SET1(k0[3]); | |||
rep(h, OH) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_LOADU(r0 + 1); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_LOADU(r1 + 1); | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _sum2; | |||
_sum = MEGDNN_SIMD_FMADD(_r00, _k0, _sum); | |||
_sum2 = MEGDNN_SIMD_MUL(_r01, _k1); | |||
_sum = MEGDNN_SIMD_FMADD(_r10, _k2, _sum); | |||
_sum2 = MEGDNN_SIMD_FMADD(_r11, _k3, _sum2); | |||
_sum = MEGDNN_SIMD_ADD(_sum, _sum2); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
} | |||
} | |||
} | |||
void conv_stride1::do_conv_3x3_stride1( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
float* outptr2 = outptr + OW; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 3; | |||
const float* k2 = filter + 5; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); | |||
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); | |||
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); | |||
size_t h = 0; | |||
for (; h + 1 < OH; h += 2) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); | |||
MEGDNN_SIMD_TYPE _sum3 = MEGDNN_SIMD_LOADU(outptr2); | |||
MEGDNN_SIMD_TYPE _sum4 = MEGDNN_SIMD_SET1(0.f); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); | |||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU_2(r3 + 4); | |||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1); | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r10, _k0123, 0); | |||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r11, _k0123, 1); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r12, _k0123, 2); | |||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r20, _k3456, 0); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r21, _k3456, 1); | |||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r22, _k3456, 2); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r30, _k6789, 0); | |||
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r31, _k6789, 1); | |||
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r32, _k6789, 2); | |||
_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); | |||
_sum3 = MEGDNN_SIMD_ADD(_sum3, _sum4); | |||
MEGDNN_SIMD_STOREU(outptr, _sum1); | |||
MEGDNN_SIMD_STOREU(outptr2, _sum3); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
outptr += 4; | |||
outptr2 += 4; | |||
} | |||
r0 += tail_step + IW; | |||
r1 += tail_step + IW; | |||
r2 += tail_step + IW; | |||
r3 += tail_step + IW; | |||
outptr += OW; | |||
outptr2 += OW; | |||
} | |||
for (; h < OH; h++) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1); | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1); | |||
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2); | |||
_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2); | |||
MEGDNN_SIMD_STOREU(outptr, _sum1); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
} | |||
filter += 9; | |||
} | |||
} | |||
void conv_stride1::do_conv_5x5_stride1( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
float* outptr2 = outptr + OW; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); | |||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); | |||
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); | |||
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); | |||
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); | |||
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); | |||
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); | |||
size_t h = 0; | |||
for (; h + 1 < OH; h += 2) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_LOADU(outptr2); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); | |||
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||
MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); | |||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); | |||
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); | |||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); | |||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k0123, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r11, _k0123, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k0123, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r13, _k0123, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r14, _k4567, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r20, _k4567, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k4567, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r22, _k4567, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r23, _k891011, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r24, _k891011, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r30, _k891011, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r31, _k891011, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r32, _k12131415, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r33, _k12131415, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r34, _k12131415, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r40, _k12131415, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r41, _k16171819, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r42, _k16171819, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r43, _k16171819, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r44, _k16171819, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r50, _k20212223, 0); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r51, _k20212223, 1); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r52, _k20212223, 2); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r53, _k20212223, 3); | |||
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r54, _k24242424, 0); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
MEGDNN_SIMD_STOREU(outptr2, _sum2); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
r5 += 4; | |||
outptr += 4; | |||
outptr2 += 4; | |||
} | |||
r0 += tail_step + IW; | |||
r1 += tail_step + IW; | |||
r2 += tail_step + IW; | |||
r3 += tail_step + IW; | |||
r4 += tail_step + IW; | |||
r5 += tail_step + IW; | |||
outptr += OW; | |||
outptr2 += OW; | |||
} | |||
for (; h < OH; h++) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); | |||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); | |||
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
} | |||
filter += 25; | |||
} | |||
} | |||
void conv_stride1::do_conv_7x7_stride1( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
const float* r6 = src_ptr + IW * 6; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 7; | |||
const float* k2 = filter + 14; | |||
const float* k3 = filter + 21; | |||
const float* k4 = filter + 28; | |||
const float* k5 = filter + 35; | |||
const float* k6 = filter + 42; | |||
for (size_t i = 0; i < OH; i++) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); | |||
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3 | |||
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7 | |||
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11 | |||
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4 | |||
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5 | |||
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6 | |||
MEGDNN_SIMD_TYPE _r05 = MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8 | |||
MEGDNN_SIMD_TYPE _r06 = MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9 | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); | |||
MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); | |||
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); | |||
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4); | |||
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 8); | |||
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1); | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3); | |||
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r14, _r10n, 1); | |||
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r14, _r10n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); | |||
MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); | |||
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); | |||
MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4); | |||
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 8); | |||
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1); | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3); | |||
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r24, _r20n, 1); | |||
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r24, _r20n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); | |||
MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); | |||
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); | |||
MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4); | |||
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 8); | |||
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1); | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3); | |||
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r34, _r30n, 1); | |||
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r34, _r30n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); | |||
MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); | |||
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); | |||
MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4); | |||
MEGDNN_SIMD_TYPE _r40n = MEGDNN_SIMD_LOADU(r4 + 8); | |||
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1); | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3); | |||
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r44, _r40n, 1); | |||
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r44, _r40n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); | |||
MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); | |||
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); | |||
MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5); | |||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4); | |||
MEGDNN_SIMD_TYPE _r50n = MEGDNN_SIMD_LOADU(r5 + 8); | |||
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1); | |||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2); | |||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3); | |||
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r54, _r50n, 1); | |||
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r54, _r50n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | |||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | |||
MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU_3(k6 + 4); | |||
MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); | |||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); | |||
MEGDNN_SIMD_TYPE _r60n = MEGDNN_SIMD_LOADU(r6 + 8); | |||
MEGDNN_SIMD_TYPE _r61 = MEGDNN_SIMD_EXT(_r60, _r64, 1); | |||
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r64, 2); | |||
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r60, _r64, 3); | |||
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r64, _r60n, 1); | |||
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r64, _r60n, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k46474849, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k46474849, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k46474849, 2); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
r5 += 4; | |||
r6 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
r5 += tail_step; | |||
r6 += tail_step; | |||
} | |||
filter += 49; | |||
} | |||
} | |||
#include "src/common/simd_macro/epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -1,512 +0,0 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <algorithm> | |||
#include "./do_conv_stride2.h" | |||
#include "midout.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/arm_common/simd_macro/neon_helper.h" | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fp32; | |||
using namespace conv_stride2; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
void conv_stride2::do_conv_2x2_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* k0 = filter; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
rep(h, OH) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); | |||
MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 | |||
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); | |||
MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); | |||
MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; | |||
MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k0123, 2); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k0123, 3); | |||
MEGDNN_SIMD_STOREU(outptr, _outp); | |||
r0 += 8; | |||
r1 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
} | |||
filter += 4; | |||
} | |||
} | |||
void conv_stride2::do_conv_3x3_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 3; | |||
const float* k2 = filter + 5; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1); | |||
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2); | |||
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1); | |||
rep(h, OH) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0); | |||
MEGDNN_SIMD_TYPE2 _r0n = MEGDNN_SIMD_LOAD2(r0 + 8); | |||
MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6 | |||
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7 | |||
MEGDNN_SIMD_TYPE _r02 = | |||
MEGDNN_SIMD_EXT(_r00, _r0n.val[0], 1); // 2 4 6 8 | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r02, _k0123, 2); | |||
MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1); | |||
MEGDNN_SIMD_TYPE2 _r1n = MEGDNN_SIMD_LOAD2(r1 + 8); | |||
MEGDNN_SIMD_TYPE _r10 = _r1.val[0]; | |||
MEGDNN_SIMD_TYPE _r11 = _r1.val[1]; | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1n.val[0], 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k3456, 0); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k3456, 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r12, _k3456, 2); | |||
MEGDNN_SIMD_TYPE2 _r2 = MEGDNN_SIMD_LOAD2(r2); | |||
MEGDNN_SIMD_TYPE2 _r2n = MEGDNN_SIMD_LOAD2(r2 + 8); | |||
MEGDNN_SIMD_TYPE _r20 = _r2.val[0]; | |||
MEGDNN_SIMD_TYPE _r21 = _r2.val[1]; | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2n.val[0], 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r20, _k6789, 0); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r21, _k6789, 1); | |||
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r22, _k6789, 2); | |||
MEGDNN_SIMD_STOREU(outptr, _outp); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
} | |||
filter += 9; | |||
} | |||
} | |||
void conv_stride2::do_conv_5x5_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter); | |||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4); | |||
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8); | |||
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12); | |||
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16); | |||
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20); | |||
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]); | |||
for (size_t i = 0; i < OH; i++) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); | |||
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); | |||
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||
MEGDNN_SIMD_TYPE _r02 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 | |||
MEGDNN_SIMD_TYPE _r03 = | |||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 | |||
MEGDNN_SIMD_TYPE _r04 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 | |||
MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); | |||
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); | |||
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); | |||
MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); | |||
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); | |||
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); | |||
MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); | |||
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); | |||
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); | |||
MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); | |||
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); | |||
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
r4 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
} | |||
filter += 25; | |||
} | |||
} | |||
void conv_stride2::do_conv_7x7_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
const float* r6 = src_ptr + IW * 6; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 7; | |||
const float* k2 = filter + 14; | |||
const float* k3 = filter + 21; | |||
const float* k4 = filter + 28; | |||
const float* k5 = filter + 35; | |||
const float* k6 = filter + 42; | |||
for (size_t i = 0; i < OH; i++) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); | |||
MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); | |||
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); | |||
MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0); | |||
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8); | |||
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||
MEGDNN_SIMD_TYPE _r02 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8 | |||
MEGDNN_SIMD_TYPE _r03 = | |||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9 | |||
MEGDNN_SIMD_TYPE _r04 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10 | |||
MEGDNN_SIMD_TYPE _r05 = | |||
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 2); // 5 7 9 11 | |||
MEGDNN_SIMD_TYPE _r06 = | |||
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 3); // 6 8 10 12 | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2); | |||
MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1); | |||
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4); | |||
MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1); | |||
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8); | |||
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2); | |||
MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2); | |||
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4); | |||
MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2); | |||
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8); | |||
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2); | |||
MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3); | |||
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4); | |||
MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3); | |||
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8); | |||
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2); | |||
MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4); | |||
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4); | |||
MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4); | |||
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8); | |||
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2); | |||
MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5); | |||
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4); | |||
MEGDNN_SIMD_TYPE2 _r50_02461357 = MEGDNN_SIMD_LOAD2(r5); | |||
MEGDNN_SIMD_TYPE2 _r50nx2 = MEGDNN_SIMD_LOAD2(r5 + 8); | |||
MEGDNN_SIMD_TYPE _r5_8101214 = _r50nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r5_9111315 = _r50nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r50 = _r50_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r51 = _r50_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); | |||
MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); | |||
MEGDNN_SIMD_TYPE _k45464748 = MEGDNN_SIMD_LOADU(k6 + 3); | |||
MEGDNN_SIMD_TYPE2 _r60_02461357 = MEGDNN_SIMD_LOAD2(r6); | |||
MEGDNN_SIMD_TYPE2 _r60nx2 = MEGDNN_SIMD_LOAD2(r6 + 8); | |||
MEGDNN_SIMD_TYPE _r6_8101214 = _r60nx2.val[0]; | |||
MEGDNN_SIMD_TYPE _r6_9111315 = _r60nx2.val[1]; | |||
MEGDNN_SIMD_TYPE _r60 = _r60_02461357.val[0]; | |||
MEGDNN_SIMD_TYPE _r61 = _r60_02461357.val[1]; | |||
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 1); | |||
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 1); | |||
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 2); | |||
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 2); | |||
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k45464748, 1); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k45464748, 2); | |||
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k45464748, 3); | |||
MEGDNN_SIMD_STOREU(outptr, _sum); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
r4 += 8; | |||
r5 += 8; | |||
r6 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
r5 += tail_step; | |||
r6 += tail_step; | |||
} | |||
filter += 49; | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -28,7 +28,6 @@ | |||
#include "include/megdnn/oprs/nn.h" | |||
#include "src/arm_common/conv_bias/f16/algos.h" | |||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||
#include "src/arm_common/conv_bias/int8/stride1.h" | |||
#include "src/arm_common/conv_bias/int8/stride2.h" | |||
#include "src/arm_common/conv_bias/quint8/stride1.h" | |||
@@ -69,14 +68,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | |||
#endif | |||
AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; | |||
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; | |||
AlgoF32DirectNCHW44 f32_direct_nchw44; | |||
AlgoF32Direct f32_direct; | |||
AlgoF32DirectStride2 f32_direct_stride2; | |||
AlgoF32DirectStride1 f32_direct_stride1; | |||
AlgoI8x8x16Direct i8x8x16_direct; | |||
AlgoI8x8x16Stride2 i8x8x16_stride2; | |||
AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2; | |||
@@ -127,14 +118,6 @@ public: | |||
m_direct_algos.emplace_back(&i8x8x16_stride2); | |||
m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44); | |||
m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||
m_direct_algos.emplace_back(&f32_chanel_wise_nchw44); | |||
m_direct_algos.emplace_back(&f32_direct_nchw44); | |||
m_direct_algos.emplace_back(&f32_direct_stride1); | |||
m_direct_algos.emplace_back(&f32_direct_stride2); | |||
m_direct_algos.emplace_back(&f32_direct); | |||
static CpuOprDelegationStorage<2> storage; | |||
auto matmul_opr = storage.get<MatrixMul, 0>(); | |||
using MatmulFormat = param::MatrixMul::Format; | |||
@@ -145,22 +128,6 @@ public: | |||
if (is_fallback_or_naive(algo)) | |||
continue; | |||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
//! uncomment this when low precision mode is done | |||
#if 0 | |||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||
@@ -175,27 +142,6 @@ public: | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
->select_algo_type( | |||
{AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); | |||
for (auto&& algo : matmul_algos) { | |||
if (is_fallback_or_naive(algo)) | |||
continue; | |||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
@@ -49,15 +49,6 @@ private: | |||
class AlgoS8DirectNCHWNCHW44; | |||
class AlgoQU8DirectStride1; | |||
class AlgoQU8DirectStride2; | |||
class AlgoFP32WinogradF23_4x4; | |||
class AlgoFP32WinogradF63; | |||
class AlgoFP32WinogradF63_4x4; | |||
class AlgoFP32WinogradF54; | |||
class AlgoFP32WinogradF45; | |||
class AlgoFP32WinogradF23_4x4_NCHW44; | |||
class AlgoFP32WinogradF63_4x4_NCHW44; | |||
class AlgoFP32WinogradF73_4x4_NCHW44; | |||
class AlgoS8ChanWiseStride1NCHW44; | |||
class AlgoS8ChanWiseStride2NCHW44; | |||
@@ -78,12 +69,6 @@ private: | |||
class AlgoDotS8Direct_NCHW44; | |||
#endif | |||
class AlgoF32Direct; | |||
class AlgoF32DirectStride1; | |||
class AlgoF32DirectStride2; | |||
class AlgoF32DirectNCHWNCHW44; | |||
class AlgoF32ChannelWiseNCHW44; | |||
class AlgoF32DirectNCHW44; | |||
class AlgoI8x8x16Direct; | |||
class AlgoI8x8x16Stride2; | |||
@@ -10,6 +10,8 @@ | |||
*/ | |||
#pragma once | |||
#include "megbrain_build_config.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
@@ -0,0 +1,37 @@ | |||
/** | |||
* \file dnn/src/fallback/conv_bias/gi/block_helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace { | |||
// block_helper is used to calculate oh block size | |||
static inline int l2_block_helper( | |||
const int nthread, const int amount, const int size_per_unit) { | |||
//! TODO: opt config or dynamic config l2_cache_size for different ARCH | |||
constexpr int l2_cache_size = 256 * 1024; | |||
const int block_per_thread = div_ceil(amount, nthread); | |||
const int best_block = | |||
std::min(amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); | |||
const int max_block_num = div_ceil(block_per_thread, best_block); | |||
const int min_block_num = std::max(max_block_num - 1, 1); | |||
const int max_block = div_ceil(block_per_thread, max_block_num); | |||
const int min_block = div_ceil(block_per_thread, min_block_num); | |||
const int max_loss = std::abs(max_block_num * max_block - block_per_thread); | |||
const int min_loss = std::abs(min_block_num * min_block - block_per_thread); | |||
int block = max_loss > min_loss ? min_block : max_block; | |||
return block; | |||
} | |||
} // namespace | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/algos.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/algos.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,23 +10,22 @@ | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||
#include "src/arm_common/conv_bias/fp32/direct.h" | |||
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" | |||
#include "src/arm_common/conv_bias/fp32/do_conv_stride2.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/conv_bias/img2col_helper.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/direct/multi_thread_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct.h" | |||
#include "src/fallback/conv_bias/gi/fp32/do_conv_stride1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/do_conv_stride2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32) | |||
MIDOUT_DECL(megdnn_fallback_winograd_fp32) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
/* ======================= AlgoFP32WinogradF23_4x4 ======================== */ | |||
@@ -34,10 +33,10 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 0) { | |||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 0, 0) { | |||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | |||
return false; | |||
using Strategy = winograd::winograd_2x3_4x4_f; | |||
using Strategy = winograd::winograd_gi_2x3_4x4_f; | |||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
@@ -62,8 +61,8 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||
AlgoFP32WinogradF23_4x4, winograd::winograd_2x3_4x4_f, | |||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||
AlgoFP32WinogradF23_4x4, winograd::winograd_gi_2x3_4x4_f, | |||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||
/* ======================= AlgoFP32WinogradF63 ======================== */ | |||
@@ -71,7 +70,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) { | |||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 1, 0) { | |||
using Strategy = winograd::winograd_6x3_1x1_f; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
@@ -95,7 +94,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||
AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f, | |||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||
/* ======================= AlgoFP32WinogradF54 ======================== */ | |||
@@ -103,7 +102,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) { | |||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 2, 0) { | |||
using Strategy = winograd::winograd_5x4_1x1_f; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
@@ -127,7 +126,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||
AlgoFP32WinogradF54, winograd::winograd_5x4_1x1_f, | |||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||
/* ======================= AlgoFP32WinogradF45 ======================== */ | |||
@@ -135,7 +134,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) { | |||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 3, 0) { | |||
using Strategy = winograd::winograd_4x5_1x1_f; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
@@ -159,7 +158,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||
AlgoFP32WinogradF45, winograd::winograd_4x5_1x1_f, | |||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||
/* ======================= AlgoFP32WinogradF63_4x4 ======================== */ | |||
@@ -167,7 +166,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 0) { | |||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 4, 0) { | |||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | |||
return false; | |||
using Strategy = winograd::winograd_6x3_4x4_f; | |||
@@ -197,7 +196,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||
AlgoFP32WinogradF63_4x4, winograd::winograd_6x3_4x4_f, | |||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||
/* =================== AlgoFP32WinogradF23_4x4_NCHW44 =================== */ | |||
@@ -206,7 +205,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_winograd_fp32, | |||
megdnn_fallback_winograd_fp32, | |||
midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { | |||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | |||
return false; | |||
@@ -236,7 +235,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||
AlgoFP32WinogradF23_4x4_NCHW44, winograd::winograd_F23_mk4_f_nchw44, | |||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||
/* =================== AlgoFP32WinogradF63_4x4_NCHW44 ===================== */ | |||
@@ -245,7 +244,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MEGDNN_MARK_USED_VAR(param); | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_winograd_fp32, | |||
megdnn_fallback_winograd_fp32, | |||
midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { | |||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | |||
return false; | |||
@@ -276,7 +275,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||
AlgoFP32WinogradF63_4x4_NCHW44, winograd::winograd_F63_mk4_f_nchw44, | |||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||
/* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */ | |||
@@ -284,7 +283,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_winograd_fp32, | |||
megdnn_fallback_winograd_fp32, | |||
midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) { | |||
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) | |||
return false; | |||
@@ -314,14 +313,14 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||
AlgoFP32WinogradF73_4x4_NCHW44, winograd::winograd_F73_mk4_f_nchw44, | |||
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); | |||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4); | |||
/* ===================== direct algo ===================== */ | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); | |||
MIDOUT_DECL(megdnn_fallback_conv_bias_f32_kimpl); | |||
bool ConvBiasImpl::AlgoF32Direct::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto SH = fm.stride[0], SW = fm.stride[1]; | |||
@@ -341,7 +340,7 @@ bool ConvBiasImpl::AlgoF32Direct::usable( | |||
return false; | |||
} | |||
size_t ConvBiasImpl::AlgoF32Direct::get_workspace(const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto wbundle = fallback::MultithreadDirectConvCommon<float, float>::get_bundle( | |||
param, large_group); | |||
@@ -426,7 +425,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls( | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
@@ -435,7 +434,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns( | |||
/* ===================== stride-1 algo ===================== */ | |||
bool ConvBiasImpl::AlgoF32DirectStride1::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 1) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
@@ -452,7 +451,7 @@ bool ConvBiasImpl::AlgoF32DirectStride1::usable( | |||
size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = | |||
fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
@@ -548,7 +547,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::get_kimpl | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 2) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
@@ -559,7 +558,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::dispatch_ | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
@@ -575,7 +574,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
} | |||
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 1) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = | |||
fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
@@ -670,7 +669,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpl | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) { | |||
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 2) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/algos.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/algos.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -12,11 +12,11 @@ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase { | |||
public: | |||
AlgoFP32WinogradF23_4x4( | |||
@@ -31,7 +31,7 @@ public: | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_4X4_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | |||
@@ -50,7 +50,7 @@ public: | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | |||
@@ -67,7 +67,7 @@ public: | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | |||
@@ -86,7 +86,7 @@ public: | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F54_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | |||
@@ -105,7 +105,7 @@ public: | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F45_FP32) | |||
}; | |||
//===================== NCHW44 Winograd Support =====================// | |||
@@ -124,7 +124,7 @@ public: | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | |||
@@ -142,7 +142,7 @@ public: | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | |||
@@ -160,7 +160,7 @@ public: | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | |||
}; | |||
// ================================================================= // | |||
@@ -180,7 +180,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | |||
@@ -199,7 +199,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_STRD1_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
@@ -218,7 +218,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_STRD2_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | |||
@@ -238,7 +238,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_NCHW44_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | |||
@@ -258,7 +258,7 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_NCHW_NCHW44_FP32) | |||
}; | |||
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | |||
@@ -277,10 +277,10 @@ public: | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32) | |||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_CHWNWISE_NCHW44_F32) | |||
}; | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
#undef MEGDNN_WINOGRAD_ALGO_FUN_DECLARE |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,29 +10,22 @@ | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
#if defined(__ARM_FEATURE_FMA) | |||
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) | |||
#else | |||
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) | |||
#endif | |||
template <int shift> | |||
static inline void shift_src(float32x4_t rsrc[3][4]) { | |||
float32x4_t t[4]; | |||
static inline void shift_src(GI_FLOAT32_t rsrc[3][4]) { | |||
GI_FLOAT32_t t[4]; | |||
t[0] = rsrc[0][(shift + 0) % 4]; | |||
t[1] = rsrc[0][(shift + 1) % 4]; | |||
@@ -63,9 +56,9 @@ static inline void shift_src(float32x4_t rsrc[3][4]) { | |||
} | |||
template <BiasMode bias_mode> | |||
static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) { | |||
static inline GI_FLOAT32_t load_bias(const float* bias, const GI_FLOAT32_t& init) { | |||
if (bias_mode == BiasMode::BIAS) { | |||
return vld1q_f32(bias); | |||
return GiLoadFloat32(bias); | |||
} else { | |||
return init; | |||
} | |||
@@ -76,35 +69,35 @@ struct compute_element { | |||
template <typename Op> | |||
static inline void call( | |||
const float*& src0, const float*& src1, const float*& src2, float*& dst, | |||
const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4], | |||
float32x4_t rfilter[3][3], const Op& op) { | |||
const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4], | |||
GI_FLOAT32_t rfilter[3][3], const Op& op) { | |||
#define RSRC(i, j) rsrc[i][((j) + bw) % 4] | |||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init); | |||
if (has_top) { | |||
RSRC(0, 3) = vld1q_f32(src0 + 8); | |||
RSRC(0, 3) = GiLoadFloat32(src0 + 8); | |||
} | |||
{ RSRC(1, 3) = vld1q_f32(src1 + 8); } | |||
{ RSRC(1, 3) = GiLoadFloat32(src1 + 8); } | |||
if (has_bottom) { | |||
RSRC(2, 3) = vld1q_f32(src2 + 8); | |||
RSRC(2, 3) = GiLoadFloat32(src2 + 8); | |||
} | |||
if (has_top) { | |||
rdst = Vfmaq_f32(rdst, RSRC(0, 0), rfilter[0][0]); | |||
rdst = Vfmaq_f32(rdst, RSRC(0, 1), rfilter[0][1]); | |||
rdst = Vfmaq_f32(rdst, RSRC(0, 2), rfilter[0][2]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(0, 0), rfilter[0][0]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(0, 1), rfilter[0][1]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(0, 2), rfilter[0][2]); | |||
} | |||
{ | |||
rdst = Vfmaq_f32(rdst, RSRC(1, 0), rfilter[1][0]); | |||
rdst = Vfmaq_f32(rdst, RSRC(1, 1), rfilter[1][1]); | |||
rdst = Vfmaq_f32(rdst, RSRC(1, 2), rfilter[1][2]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(1, 0), rfilter[1][0]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(1, 1), rfilter[1][1]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(1, 2), rfilter[1][2]); | |||
} | |||
if (has_bottom) { | |||
rdst = Vfmaq_f32(rdst, RSRC(2, 0), rfilter[2][0]); | |||
rdst = Vfmaq_f32(rdst, RSRC(2, 1), rfilter[2][1]); | |||
rdst = Vfmaq_f32(rdst, RSRC(2, 2), rfilter[2][2]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(2, 0), rfilter[2][0]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(2, 1), rfilter[2][1]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(2, 2), rfilter[2][2]); | |||
} | |||
vst1q_f32(dst, op(rdst)); | |||
GiStoreFloat32(dst, op(rdst)); | |||
if (has_top) { | |||
src0 += 4; | |||
@@ -131,27 +124,27 @@ template <bool has_top, bool has_bottom, BiasMode bias_mode> | |||
struct compute_element_right { | |||
template <typename Op> | |||
static inline void call( | |||
float*& dst, const float*& bias, const float32x4_t& init, | |||
float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) { | |||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||
float*& dst, const float*& bias, const GI_FLOAT32_t& init, | |||
GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) { | |||
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init); | |||
if (has_top) { | |||
rdst = Vfmaq_f32(rdst, rsrc[0][0], rfilter[0][0]); | |||
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][1]); | |||
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][2]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[0][0], rfilter[0][0]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][1]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][2]); | |||
} | |||
{ | |||
rdst = Vfmaq_f32(rdst, rsrc[1][0], rfilter[1][0]); | |||
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][1]); | |||
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][2]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[1][0], rfilter[1][0]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][1]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][2]); | |||
} | |||
if (has_bottom) { | |||
rdst = Vfmaq_f32(rdst, rsrc[2][0], rfilter[2][0]); | |||
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][1]); | |||
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][2]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[2][0], rfilter[2][0]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][1]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][2]); | |||
} | |||
vst1q_f32(dst, op(rdst)); | |||
GiStoreFloat32(dst, op(rdst)); | |||
dst += 4; | |||
bias += 4; | |||
@@ -162,24 +155,24 @@ template <bool has_top, bool has_bottom, BiasMode bias_mode> | |||
struct compute_element_right_pad { | |||
template <typename Op> | |||
static inline void call( | |||
float*& dst, const float*& bias, const float32x4_t& init, | |||
float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) { | |||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||
float*& dst, const float*& bias, const GI_FLOAT32_t& init, | |||
GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) { | |||
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init); | |||
if (has_top) { | |||
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][0]); | |||
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][1]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][0]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][1]); | |||
} | |||
{ | |||
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][0]); | |||
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][1]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][0]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][1]); | |||
} | |||
if (has_bottom) { | |||
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][0]); | |||
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][1]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][0]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][1]); | |||
} | |||
vst1q_f32(dst, op(rdst)); | |||
GiStoreFloat32(dst, op(rdst)); | |||
dst += 4; | |||
bias += 4; | |||
} | |||
@@ -190,22 +183,22 @@ struct compute_row { | |||
template <typename Op> | |||
static inline void call( | |||
const float*& src0, const float*& src1, const float*& src2, float*& dst, | |||
const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4], | |||
float32x4_t rfilter[3][3], int W, const Op& op) { | |||
const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4], | |||
GI_FLOAT32_t rfilter[3][3], int W, const Op& op) { | |||
if (has_top) { | |||
rsrc[0][0] = vdupq_n_f32(0); | |||
rsrc[0][1] = vld1q_f32(src0 + 0); | |||
rsrc[0][2] = vld1q_f32(src0 + 4); | |||
rsrc[0][0] = GiZeroFloat32(); | |||
rsrc[0][1] = GiLoadFloat32(src0 + 0); | |||
rsrc[0][2] = GiLoadFloat32(src0 + 4); | |||
} | |||
{ | |||
rsrc[1][0] = vdupq_n_f32(0); | |||
rsrc[1][1] = vld1q_f32(src1 + 0); | |||
rsrc[1][2] = vld1q_f32(src1 + 4); | |||
rsrc[1][0] = GiZeroFloat32(); | |||
rsrc[1][1] = GiLoadFloat32(src1 + 0); | |||
rsrc[1][2] = GiLoadFloat32(src1 + 4); | |||
} | |||
if (has_bottom) { | |||
rsrc[2][0] = vdupq_n_f32(0); | |||
rsrc[2][1] = vld1q_f32(src2 + 0); | |||
rsrc[2][2] = vld1q_f32(src2 + 4); | |||
rsrc[2][0] = GiZeroFloat32(); | |||
rsrc[2][1] = GiLoadFloat32(src2 + 0); | |||
rsrc[2][2] = GiLoadFloat32(src2 + 4); | |||
} | |||
int w = 0; | |||
@@ -256,27 +249,27 @@ void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1( | |||
int W) { | |||
Op op; | |||
float32x4_t init = vdupq_n_f32(0); | |||
GI_FLOAT32_t init = GiZeroFloat32(); | |||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
init = vld1q_f32(bias); | |||
init = GiLoadFloat32(bias); | |||
} | |||
const float* src0 = src - W * 4; | |||
const float* src1 = src; | |||
const float* src2 = src + W * 4; | |||
float32x4_t rfilter[3][3]; | |||
rfilter[0][0] = vld1q_f32(filter + 0); | |||
rfilter[0][1] = vld1q_f32(filter + 4); | |||
rfilter[0][2] = vld1q_f32(filter + 8); | |||
rfilter[1][0] = vld1q_f32(filter + 12); | |||
rfilter[1][1] = vld1q_f32(filter + 16); | |||
rfilter[1][2] = vld1q_f32(filter + 20); | |||
rfilter[2][0] = vld1q_f32(filter + 24); | |||
rfilter[2][1] = vld1q_f32(filter + 28); | |||
rfilter[2][2] = vld1q_f32(filter + 32); | |||
float32x4_t rsrc[3][4]; | |||
GI_FLOAT32_t rfilter[3][3]; | |||
rfilter[0][0] = GiLoadFloat32(filter + 0); | |||
rfilter[0][1] = GiLoadFloat32(filter + 4); | |||
rfilter[0][2] = GiLoadFloat32(filter + 8); | |||
rfilter[1][0] = GiLoadFloat32(filter + 12); | |||
rfilter[1][1] = GiLoadFloat32(filter + 16); | |||
rfilter[1][2] = GiLoadFloat32(filter + 20); | |||
rfilter[2][0] = GiLoadFloat32(filter + 24); | |||
rfilter[2][1] = GiLoadFloat32(filter + 28); | |||
rfilter[2][2] = GiLoadFloat32(filter + 32); | |||
GI_FLOAT32_t rsrc[3][4]; | |||
compute_row<false, true, bias_mode>::call( | |||
src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -12,11 +12,11 @@ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace channel_wise_nchw44_float { | |||
template <BiasMode bias_mode, typename Op> | |||
@@ -25,7 +25,7 @@ void do_conv_kern_3x3_stride1_padding1( | |||
int W); | |||
} // namespace channel_wise_nchw44_float | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,29 +10,22 @@ | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
#if defined(__ARM_FEATURE_FMA) | |||
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) | |||
#else | |||
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) | |||
#endif | |||
template <int shift> | |||
static inline void shift_src(float32x4_t rsrc[6]) { | |||
float32x4_t t[6]; | |||
static inline void shift_src(GI_FLOAT32_t rsrc[6]) { | |||
GI_FLOAT32_t t[6]; | |||
t[0] = rsrc[(shift + 0) % 6]; | |||
t[1] = rsrc[(shift + 1) % 6]; | |||
@@ -48,18 +41,18 @@ static inline void shift_src(float32x4_t rsrc[6]) { | |||
rsrc[5] = t[5]; | |||
} | |||
static inline void load_filter(const float* filter, float32x4_t rfilter[5]) { | |||
rfilter[0] = vld1q_f32(filter + 0); | |||
rfilter[1] = vld1q_f32(filter + 4); | |||
rfilter[2] = vld1q_f32(filter + 8); | |||
rfilter[3] = vld1q_f32(filter + 12); | |||
rfilter[4] = vld1q_f32(filter + 16); | |||
static inline void load_filter(const float* filter, GI_FLOAT32_t rfilter[5]) { | |||
rfilter[0] = GiLoadFloat32(filter + 0); | |||
rfilter[1] = GiLoadFloat32(filter + 4); | |||
rfilter[2] = GiLoadFloat32(filter + 8); | |||
rfilter[3] = GiLoadFloat32(filter + 12); | |||
rfilter[4] = GiLoadFloat32(filter + 16); | |||
} | |||
template <BiasMode bias_mode> | |||
static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) { | |||
static inline GI_FLOAT32_t load_bias(const float* bias, const GI_FLOAT32_t& init) { | |||
if (bias_mode == BiasMode::BIAS) { | |||
return vld1q_f32(bias); | |||
return GiLoadFloat32(bias); | |||
} else { | |||
return init; | |||
} | |||
@@ -69,27 +62,28 @@ template <int BW, int bw, BiasMode bias_mode, bool need_load_bias, bool need_do_ | |||
struct compute_element { | |||
template <typename Op> | |||
static inline void call( | |||
const float*& src, float*& dst, const float*& bias, const float32x4_t& init, | |||
float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) { | |||
const float*& src, float*& dst, const float*& bias, | |||
const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], | |||
const Op& op) { | |||
#define RSRC(i) rsrc[((i) + bw) % 6] | |||
float32x4_t rdst; | |||
GI_FLOAT32_t rdst; | |||
if (need_load_bias) { | |||
rdst = load_bias<bias_mode>(bias, init); | |||
} else { | |||
rdst = vld1q_f32(dst); | |||
rdst = GiLoadFloat32(dst); | |||
} | |||
RSRC(5) = vld1q_f32(src + 12); | |||
RSRC(5) = GiLoadFloat32(src + 12); | |||
rdst = Vfmaq_f32(rdst, RSRC(0), rfilter[0]); | |||
rdst = Vfmaq_f32(rdst, RSRC(1), rfilter[1]); | |||
rdst = Vfmaq_f32(rdst, RSRC(2), rfilter[2]); | |||
rdst = Vfmaq_f32(rdst, RSRC(3), rfilter[3]); | |||
rdst = Vfmaq_f32(rdst, RSRC(4), rfilter[4]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(0), rfilter[0]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(1), rfilter[1]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(2), rfilter[2]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(3), rfilter[3]); | |||
rdst = GiMlaqFloat32(rdst, RSRC(4), rfilter[4]); | |||
if (need_do_op) { | |||
rdst = op(rdst); | |||
} | |||
vst1q_f32(dst, rdst); | |||
GiStoreFloat32(dst, rdst); | |||
src += 4; | |||
dst += 4; | |||
@@ -110,29 +104,29 @@ template <size_t padding, BiasMode bias_mode, bool need_load_bias, bool need_do_ | |||
struct compute_element_right { | |||
template <typename Op> | |||
static inline void call( | |||
float*& dst, const float*& bias, const float32x4_t& init, | |||
float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) { | |||
float32x4_t rdst; | |||
float*& dst, const float*& bias, const GI_FLOAT32_t& init, | |||
GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], const Op& op) { | |||
GI_FLOAT32_t rdst; | |||
if (need_load_bias) { | |||
rdst = load_bias<bias_mode>(bias, init); | |||
} else { | |||
rdst = vld1q_f32(dst); | |||
rdst = GiLoadFloat32(dst); | |||
} | |||
rdst = Vfmaq_f32(rdst, rsrc[0 + padding], rfilter[0]); | |||
rdst = Vfmaq_f32(rdst, rsrc[1 + padding], rfilter[1]); | |||
rdst = Vfmaq_f32(rdst, rsrc[2 + padding], rfilter[2]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[0 + padding], rfilter[0]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[1 + padding], rfilter[1]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[2 + padding], rfilter[2]); | |||
if (padding < 2) { | |||
rdst = Vfmaq_f32(rdst, rsrc[3 + padding], rfilter[3]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[3 + padding], rfilter[3]); | |||
} | |||
if (padding < 1) { | |||
rdst = Vfmaq_f32(rdst, rsrc[4 + padding], rfilter[4]); | |||
rdst = GiMlaqFloat32(rdst, rsrc[4 + padding], rfilter[4]); | |||
} | |||
if (need_do_op) { | |||
rdst = op(rdst); | |||
} | |||
vst1q_f32(dst, rdst); | |||
GiStoreFloat32(dst, rdst); | |||
dst += 4; | |||
bias += 4; | |||
@@ -143,13 +137,13 @@ template <BiasMode bias_mode, bool need_load_bias, bool need_do_op> | |||
struct compute_row_src_1x5 { | |||
template <typename Op> | |||
static inline void call( | |||
const float* src, float* dst, const float* bias, const float32x4_t& init, | |||
float32x4_t rsrc[6], float32x4_t rfilter[5], int W, const Op& op) { | |||
rsrc[0] = vdupq_n_f32(0); | |||
rsrc[1] = vdupq_n_f32(0); | |||
rsrc[2] = vld1q_f32(src + 0); | |||
rsrc[3] = vld1q_f32(src + 4); | |||
rsrc[4] = vld1q_f32(src + 8); | |||
const float* src, float* dst, const float* bias, const GI_FLOAT32_t& init, | |||
GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], int W, const Op& op) { | |||
rsrc[0] = GiZeroFloat32(); | |||
rsrc[1] = GiZeroFloat32(); | |||
rsrc[2] = GiLoadFloat32(src + 0); | |||
rsrc[3] = GiLoadFloat32(src + 4); | |||
rsrc[4] = GiLoadFloat32(src + 8); | |||
int w = 0; | |||
@@ -190,8 +184,8 @@ struct compute_row { | |||
template <typename Op> | |||
static inline void call( | |||
const float*& src, float*& dst, const float* filter, const float*& bias, | |||
const float32x4_t& init, float32x4_t rsrc[6], float32x4_t rfilter[5], int W, | |||
const Op& op) { | |||
const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], | |||
int W, const Op& op) { | |||
if (top_padding < 1) { | |||
load_filter(filter + 0, rfilter); | |||
compute_row_src_1x5<bias_mode, top_padding == 0, false>::call( | |||
@@ -235,13 +229,13 @@ void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2( | |||
int W) { | |||
Op op; | |||
float32x4_t init = vdupq_n_f32(0); | |||
GI_FLOAT32_t init = GiZeroFloat32(); | |||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
init = vld1q_f32(bias); | |||
init = GiLoadFloat32(bias); | |||
} | |||
float32x4_t rsrc[6]; | |||
float32x4_t rfilter[5]; | |||
GI_FLOAT32_t rsrc[6]; | |||
GI_FLOAT32_t rfilter[5]; | |||
compute_row<2, 0, bias_mode>::call( | |||
src, dst, filter, bias, init, rsrc, rfilter, W, op); |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -12,11 +12,11 @@ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace channel_wise_nchw44_float { | |||
template <BiasMode bias_mode, typename Op> | |||
@@ -25,7 +25,7 @@ void do_conv_kern_5x5_stride1_padding2( | |||
int W); | |||
} // namespace channel_wise_nchw44_float | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,14 +10,14 @@ | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
using conv_fun = std::function<void( | |||
const float* src, const float* filter, const float* bias, float* dst, | |||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/direct.cpp | |||
* \file dnn/src/fallback/conv_bias/int8/direct.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,29 +10,28 @@ | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
template <int size> | |||
void load_vec(float32x4_t* dst, const float* src); | |||
void load_vec(GI_FLOAT32_t* dst, const float* src); | |||
#define cb(i) dst[i] = vld1q_f32(src + i * 4); | |||
#define LOAD_MACRO(n) \ | |||
template <> \ | |||
inline void load_vec<n>(float32x4_t * dst, const float* src) { \ | |||
UNROLL_CALL_NOWRAPPER(n, cb); \ | |||
#define cb(i) dst[i] = GiLoadFloat32(src + i * 4); | |||
#define LOAD_MACRO(n) \ | |||
template <> \ | |||
inline void load_vec<n>(GI_FLOAT32_t * dst, const float* src) { \ | |||
UNROLL_CALL_NOWRAPPER(n, cb); \ | |||
} | |||
LOAD_MACRO(2); | |||
LOAD_MACRO(3); | |||
@@ -46,14 +45,14 @@ LOAD_MACRO(9); | |||
#undef LOAD_MACRO | |||
template <int size> | |||
void compute_vec(float32x4_t& dst, float32x4_t* src, float32x4_t* filter); | |||
void compute_vec(GI_FLOAT32_t& dst, GI_FLOAT32_t* src, GI_FLOAT32_t* filter); | |||
#define cb(i) dst = vmlaq_f32(dst, src[i], filter[i]); | |||
#define COMPUTE_MACRO(n) \ | |||
template <> \ | |||
inline void compute_vec<n>( \ | |||
float32x4_t & dst, float32x4_t * src, float32x4_t * filter) { \ | |||
UNROLL_CALL_NOWRAPPER(n, cb); \ | |||
#define cb(i) dst = GiMlaqFloat32(dst, src[i], filter[i]); | |||
#define COMPUTE_MACRO(n) \ | |||
template <> \ | |||
inline void compute_vec<n>( \ | |||
GI_FLOAT32_t & dst, GI_FLOAT32_t * src, GI_FLOAT32_t * filter) { \ | |||
UNROLL_CALL_NOWRAPPER(n, cb); \ | |||
} | |||
COMPUTE_MACRO(2); | |||
COMPUTE_MACRO(3); | |||
@@ -64,20 +63,20 @@ COMPUTE_MACRO(5); | |||
template <BiasMode bias_mode, int size> | |||
struct load_bias_vec; | |||
#define cb_bias(i) dst[i] = vld1q_f32((bptr) + i * 4); | |||
#define cb_bias(i) dst[i] = GiLoadFloat32((bptr) + i * 4); | |||
#define cb_init(i) dst[i] = init; | |||
#define INIT_BIAS_MACRO(n) \ | |||
template <BiasMode bias_mode> \ | |||
struct load_bias_vec<bias_mode, n> { \ | |||
static void impl( \ | |||
float32x4_t* dst, const float32x4_t& init, const float* bptr) { \ | |||
if (bias_mode == BiasMode::BIAS) { \ | |||
UNROLL_CALL_NOWRAPPER(n, cb_bias); \ | |||
} else { \ | |||
UNROLL_CALL_NOWRAPPER(n, cb_init); \ | |||
} \ | |||
} \ | |||
#define INIT_BIAS_MACRO(n) \ | |||
template <BiasMode bias_mode> \ | |||
struct load_bias_vec<bias_mode, n> { \ | |||
static void impl( \ | |||
GI_FLOAT32_t* dst, const GI_FLOAT32_t& init, const float* bptr) { \ | |||
if (bias_mode == BiasMode::BIAS) { \ | |||
UNROLL_CALL_NOWRAPPER(n, cb_bias); \ | |||
} else { \ | |||
UNROLL_CALL_NOWRAPPER(n, cb_init); \ | |||
} \ | |||
} \ | |||
}; | |||
INIT_BIAS_MACRO(1); | |||
@@ -91,7 +90,7 @@ INIT_BIAS_MACRO(4); | |||
#define COMPUTE_PADDING_KERNEL() \ | |||
do { \ | |||
int iw = ow * stride - PW; \ | |||
float32x4_t result; \ | |||
GI_FLOAT32_t result; \ | |||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4 + ow * 4); \ | |||
for (int kh = 0; kh < fh; kh++) { \ | |||
if (kh + ih < 0 || kh + ih >= static_cast<int>(IH)) \ | |||
@@ -100,7 +99,8 @@ INIT_BIAS_MACRO(4); | |||
if (kw + iw < 0 || kw + iw >= static_cast<int>(IW)) \ | |||
continue; \ | |||
const float* sptr = src + (kh + ih) * IW * 4 + (kw + iw) * 4; \ | |||
result = vmlaq_f32(result, kernel[kh * fh + kw], vld1q_f32(sptr)); \ | |||
result = GiMlaqFloat32( \ | |||
result, kernel[kh * fh + kw], GiLoadFloat32(sptr)); \ | |||
} \ | |||
} \ | |||
float* output = dst + oh * OW * 4 + ow * 4; \ | |||
@@ -113,7 +113,7 @@ struct PaddingCompute { | |||
const float* src, const float* bias, float* dst, const int fh, | |||
const int stride, const size_t IH, const size_t IW, const size_t OH, | |||
const size_t OW, const size_t PH, const size_t PW, | |||
const float32x4_t* kernel, const float32x4_t& init) { | |||
const GI_FLOAT32_t* kernel, const GI_FLOAT32_t& init) { | |||
size_t oh_start = (PH + stride - 1) / stride; | |||
size_t ow_start = (PW + stride - 1) / stride; | |||
size_t oh_end = (IH + PH - fh) / stride + 1; | |||
@@ -148,7 +148,7 @@ struct PaddingComputeK3P1 { | |||
static void compute( | |||
const float* src, const float* bias, float* dst, const size_t stride, | |||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | |||
const float32x4_t* kernel, const float32x4_t& init) { | |||
const GI_FLOAT32_t* kernel, const GI_FLOAT32_t& init) { | |||
constexpr size_t PH = 1, PW = 1, FH = 3; | |||
size_t oh_start = (PH + stride - 1) / stride; | |||
size_t ow_start = (PW + stride - 1) / stride; | |||
@@ -162,39 +162,39 @@ struct PaddingComputeK3P1 { | |||
Op op; | |||
// line one left | |||
{ | |||
float32x4_t result; | |||
GI_FLOAT32_t result; | |||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias); | |||
result = vmlaq_f32(result, kernel[4], vld1q_f32(src)); | |||
result = vmlaq_f32(result, kernel[5], vld1q_f32(src + 4)); | |||
result = vmlaq_f32(result, kernel[7], vld1q_f32(src + IW * 4)); | |||
result = vmlaq_f32(result, kernel[8], vld1q_f32(src + IW * 4 + 4)); | |||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(src)); | |||
result = GiMlaqFloat32(result, kernel[5], GiLoadFloat32(src + 4)); | |||
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(src + IW * 4)); | |||
result = GiMlaqFloat32(result, kernel[8], GiLoadFloat32(src + IW * 4 + 4)); | |||
float* output = dst; | |||
op(result, output); | |||
} | |||
// line one mid | |||
for (size_t ow = ow_start; ow < ow_end; ow++) { | |||
int iw = ow * stride - PW; | |||
float32x4_t result; | |||
GI_FLOAT32_t result; | |||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + ow * 4); | |||
const float* sptr = src + iw * 4; | |||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr)); | |||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4)); | |||
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + 8)); | |||
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4)); | |||
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4)); | |||
result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + IW * 4 + 8)); | |||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr)); | |||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + 4)); | |||
result = GiMlaqFloat32(result, kernel[5], GiLoadFloat32(sptr + 8)); | |||
result = GiMlaqFloat32(result, kernel[6], GiLoadFloat32(sptr + IW * 4)); | |||
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(sptr + IW * 4 + 4)); | |||
result = GiMlaqFloat32(result, kernel[8], GiLoadFloat32(sptr + IW * 4 + 8)); | |||
float* output = dst + ow * 4; | |||
op(result, output); | |||
} | |||
// line one right | |||
if (OW != ow_end) { | |||
float32x4_t result; | |||
GI_FLOAT32_t result; | |||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + (OW - 1) * 4); | |||
const float* sptr = src + (ow_end * stride - PW) * 4; | |||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr)); | |||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4)); | |||
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4)); | |||
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4)); | |||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr)); | |||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + 4)); | |||
result = GiMlaqFloat32(result, kernel[6], GiLoadFloat32(sptr + IW * 4)); | |||
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(sptr + IW * 4 + 4)); | |||
float* output = dst + ow_end * 4; | |||
op(result, output); | |||
} | |||
@@ -203,30 +203,36 @@ struct PaddingComputeK3P1 { | |||
int ih = oh * stride - PH; | |||
// left | |||
{ | |||
float32x4_t result; | |||
GI_FLOAT32_t result; | |||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4); | |||
const float* sptr = src + ih * IW * 4; | |||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr)); | |||
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4)); | |||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4)); | |||
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4)); | |||
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4)); | |||
result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + 2 * IW * 4 + 4)); | |||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr)); | |||
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 4)); | |||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + IW * 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[7], GiLoadFloat32(sptr + 2 * IW * 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[8], GiLoadFloat32(sptr + 2 * IW * 4 + 4)); | |||
float* output = dst + oh * OW * 4; | |||
op(result, output); | |||
} | |||
// right | |||
if (OW != ow_end) { | |||
float32x4_t result; | |||
GI_FLOAT32_t result; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&result, init, bias + oh * OW * 4 + (OW - 1) * 4); | |||
const float* sptr = src + ih * IW * 4 + (ow_end * stride - PW) * 4; | |||
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); | |||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); | |||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); | |||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); | |||
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + 2 * IW * 4)); | |||
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4 + 4)); | |||
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); | |||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); | |||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[6], GiLoadFloat32(sptr + 2 * IW * 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[7], GiLoadFloat32(sptr + 2 * IW * 4 + 4)); | |||
float* output = dst + oh * OW * 4 + ow_end * 4; | |||
op(result, output); | |||
} | |||
@@ -235,43 +241,47 @@ struct PaddingComputeK3P1 { | |||
if (OH != oh_end) { | |||
size_t oh = OH - 1; | |||
{ | |||
float32x4_t result; | |||
GI_FLOAT32_t result; | |||
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4); | |||
const float* sptr = src + (oh_end * stride - PH) * IW * 4; | |||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr)); | |||
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4)); | |||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4)); | |||
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4)); | |||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr)); | |||
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 4)); | |||
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + IW * 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 4)); | |||
float* output = dst + oh_end * OW * 4; | |||
op(result, output); | |||
} | |||
// last line mid | |||
for (size_t ow = ow_start; ow < ow_end; ow++) { | |||
int iw = ow * stride - PW; | |||
float32x4_t result; | |||
GI_FLOAT32_t result; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&result, init, bias + oh * OW * 4 + ow * 4); | |||
const float* sptr = src + (oh_end * stride - PH) * IW * 4 + iw * 4; | |||
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); | |||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); | |||
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 8)); | |||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); | |||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); | |||
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 8)); | |||
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); | |||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); | |||
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 8)); | |||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 8)); | |||
float* output = dst + oh_end * OW * 4 + ow * 4; | |||
op(result, output); | |||
} | |||
// last line right | |||
if (OW != ow_end) { | |||
float32x4_t result; | |||
GI_FLOAT32_t result; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&result, init, bias + oh * OW * 4 + (OW - 1) * 4); | |||
const float* sptr = src + (oh_end * stride - PH) * IW * 4 + | |||
(ow_end * stride - PW) * 4; | |||
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); | |||
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); | |||
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); | |||
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); | |||
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr)); | |||
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4)); | |||
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4)); | |||
result = GiMlaqFloat32( | |||
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4)); | |||
float* output = dst + oh_end * OW * 4 + ow_end * 4; | |||
op(result, output); | |||
} | |||
@@ -286,12 +296,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||
const float* src, const float* filter, const float* bias, float* dst, | |||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | |||
const size_t PH, const size_t PW) { | |||
float32x4_t kernel[4]; | |||
GI_FLOAT32_t kernel[4]; | |||
load_vec<4>(kernel, filter); | |||
Op op; | |||
float32x4_t init = vdupq_n_f32(0.f); | |||
GI_FLOAT32_t init = GiZeroFloat32(); | |||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
init = vld1q_f32(bias); | |||
init = GiLoadFloat32(bias); | |||
} | |||
size_t oh_start = PH; | |||
size_t ow_start = PW; | |||
@@ -315,12 +325,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||
size_t iw = ow - ow_start; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2][4]; | |||
GI_FLOAT32_t dst_v[2][4]; | |||
load_bias_vec<bias_mode, 4>::impl( | |||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 4>::impl( | |||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t src_v[3][5]; | |||
GI_FLOAT32_t src_v[3][5]; | |||
load_vec<5>(src_v[0], input); | |||
COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); | |||
load_vec<5>(src_v[1], input + IW * 4); | |||
@@ -338,12 +348,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||
size_t iw = ow - ow_start; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2]; | |||
GI_FLOAT32_t dst_v[2]; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t src_v[3][2]; | |||
GI_FLOAT32_t src_v[3][2]; | |||
load_vec<2>(src_v[0], input); | |||
compute_vec<2>(dst_v[0], &src_v[0][0], &kernel[0]); | |||
load_vec<2>(src_v[1], input + IW * 4); | |||
@@ -363,10 +373,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||
size_t iw = ow - ow_start; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[1][4]; | |||
GI_FLOAT32_t dst_v[1][4]; | |||
load_bias_vec<bias_mode, 4>::impl( | |||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][5]; | |||
GI_FLOAT32_t src_v[2][5]; | |||
load_vec<5>(src_v[0], input); | |||
COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); | |||
load_vec<5>(src_v[1], input + IW * 4); | |||
@@ -379,10 +389,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( | |||
size_t iw = ow - ow_start; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v; | |||
GI_FLOAT32_t dst_v; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][2]; | |||
GI_FLOAT32_t src_v[2][2]; | |||
load_vec<2>(src_v[0], input); | |||
compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); | |||
load_vec<2>(src_v[1], input + IW * 4); | |||
@@ -405,12 +415,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||
return; | |||
} | |||
float32x4_t kernel[9]; | |||
GI_FLOAT32_t kernel[9]; | |||
load_vec<9>(kernel, filter); | |||
Op op; | |||
float32x4_t init = vdupq_n_f32(0.f); | |||
GI_FLOAT32_t init = GiZeroFloat32(); | |||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
init = vld1q_f32(bias); | |||
init = GiLoadFloat32(bias); | |||
} | |||
size_t oh_start = PH; | |||
size_t ow_start = PW; | |||
@@ -428,12 +438,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||
size_t iw = ow - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2][4]; | |||
GI_FLOAT32_t dst_v[2][4]; | |||
load_bias_vec<bias_mode, 4>::impl( | |||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 4>::impl( | |||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][6]; | |||
GI_FLOAT32_t src_v[2][6]; | |||
load_vec<6>(src_v[0], input); | |||
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); | |||
compute_vec<3>(dst_v[0][1], &src_v[0][1], &kernel[0]); | |||
@@ -472,12 +482,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||
size_t iw = ow - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2]; | |||
GI_FLOAT32_t dst_v[2]; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][3]; | |||
GI_FLOAT32_t src_v[2][3]; | |||
load_vec<3>(src_v[0], input); | |||
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | |||
load_vec<3>(src_v[1], input + IW * 4); | |||
@@ -500,10 +510,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||
size_t iw = ow - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[4]; | |||
GI_FLOAT32_t dst_v[4]; | |||
load_bias_vec<bias_mode, 4>::impl( | |||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][6]; | |||
GI_FLOAT32_t src_v[2][6]; | |||
load_vec<6>(src_v[0], input); | |||
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | |||
compute_vec<3>(dst_v[1], &src_v[0][1], &kernel[0]); | |||
@@ -526,10 +536,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||
size_t iw = ow - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v; | |||
GI_FLOAT32_t dst_v; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t src_v[3][3]; | |||
GI_FLOAT32_t src_v[3][3]; | |||
load_vec<3>(src_v[0], input); | |||
compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); | |||
load_vec<3>(src_v[1], input + IW * 4); | |||
@@ -553,9 +563,9 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||
} | |||
Op op; | |||
float32x4_t init = vdupq_n_f32(0.f); | |||
GI_FLOAT32_t init = GiZeroFloat32(); | |||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
init = vld1q_f32(bias); | |||
init = GiLoadFloat32(bias); | |||
} | |||
size_t oh_start = PH; | |||
size_t ow_start = PW; | |||
@@ -564,7 +574,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||
if (PH || PW) { | |||
PaddingCompute<bias_mode, Op>::compute( | |||
src, bias, dst, 5, 1, IH, IW, OH, OW, PH, PW, | |||
reinterpret_cast<const float32x4_t*>(filter), init); | |||
reinterpret_cast<const GI_FLOAT32_t*>(filter), init); | |||
} | |||
size_t oh = oh_start; | |||
for (; oh + 1 < oh_end; oh += 2) { | |||
@@ -574,13 +584,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||
size_t iw = ow - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2][2]; | |||
GI_FLOAT32_t dst_v[2][2]; | |||
load_bias_vec<bias_mode, 2>::impl( | |||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 2>::impl( | |||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t kernel[2][5]; | |||
float32x4_t src_v[2][6]; | |||
GI_FLOAT32_t kernel[2][5]; | |||
GI_FLOAT32_t src_v[2][6]; | |||
#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ | |||
load_vec<5>(kernel0, filter + i * 5 * 4); \ | |||
load_vec<6>(src, input + i * IW * 4); \ | |||
@@ -613,13 +623,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||
size_t iw = ow - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2][1]; | |||
GI_FLOAT32_t dst_v[2][1]; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 1>::impl( | |||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t kernel[2][5]; | |||
float32x4_t src_v[2][5]; | |||
GI_FLOAT32_t kernel[2][5]; | |||
GI_FLOAT32_t src_v[2][5]; | |||
#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ | |||
load_vec<5>(kernel0, filter + i * 5 * 4); \ | |||
load_vec<6>(src, input + i * IW * 4); \ | |||
@@ -652,11 +662,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||
size_t iw = ow - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[1][2]; | |||
GI_FLOAT32_t dst_v[1][2]; | |||
load_bias_vec<bias_mode, 2>::impl( | |||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t kernel[2][5]; | |||
float32x4_t src_v[2][6]; | |||
GI_FLOAT32_t kernel[2][5]; | |||
GI_FLOAT32_t src_v[2][6]; | |||
#define COMPUTE_5X5_2(i, dst, src, kernel) \ | |||
load_vec<5>(kernel, filter + i * 5 * 4); \ | |||
load_vec<6>(src, input + i * IW * 4); \ | |||
@@ -679,11 +689,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||
size_t iw = ow - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v; | |||
GI_FLOAT32_t dst_v; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t kernel[2][5]; | |||
float32x4_t src_v[2][5]; | |||
GI_FLOAT32_t kernel[2][5]; | |||
GI_FLOAT32_t src_v[2][5]; | |||
#define COMPUTE_5X5_1(i, dst, src, kernel) \ | |||
load_vec<5>(kernel, filter + i * 5 * 4); \ | |||
load_vec<6>(src, input + i * IW * 4); \ | |||
@@ -709,12 +719,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( | |||
const float* src, const float* filter, const float* bias, float* dst, | |||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | |||
const size_t PH, const size_t PW) { | |||
float32x4_t kernel[4]; | |||
GI_FLOAT32_t kernel[4]; | |||
load_vec<4>(kernel, filter); | |||
Op op; | |||
float32x4_t init = vdupq_n_f32(0.f); | |||
GI_FLOAT32_t init = GiZeroFloat32(); | |||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
init = vld1q_f32(bias); | |||
init = GiLoadFloat32(bias); | |||
} | |||
size_t oh_start = (PH + 1) / 2; | |||
size_t ow_start = (PW + 1) / 2; | |||
@@ -737,10 +747,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( | |||
size_t iw = ow * 2 - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[4]; | |||
GI_FLOAT32_t dst_v[4]; | |||
load_bias_vec<bias_mode, 4>::impl( | |||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][8]; | |||
GI_FLOAT32_t src_v[2][8]; | |||
load_vec<8>(src_v[0], input); | |||
COMPUTE_2X2(dst_v, src_v[0], &kernel[0]); | |||
load_vec<8>(src_v[1], input + IW * 4); | |||
@@ -753,10 +763,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( | |||
size_t iw = ow * 2 - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v; | |||
GI_FLOAT32_t dst_v; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][2]; | |||
GI_FLOAT32_t src_v[2][2]; | |||
load_vec<2>(src_v[0], input); | |||
compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); | |||
load_vec<2>(src_v[1], input + IW * 4); | |||
@@ -773,12 +783,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||
const float* src, const float* filter, const float* bias, float* dst, | |||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | |||
const size_t PH, const size_t PW) { | |||
float32x4_t kernel[9]; | |||
GI_FLOAT32_t kernel[9]; | |||
load_vec<9>(kernel, filter); | |||
Op op; | |||
float32x4_t init = vdupq_n_f32(0.f); | |||
GI_FLOAT32_t init = GiZeroFloat32(); | |||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
init = vld1q_f32(bias); | |||
init = GiLoadFloat32(bias); | |||
} | |||
size_t oh_start = (PH + 1) / 2; | |||
size_t ow_start = (PW + 1) / 2; | |||
@@ -799,12 +809,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||
size_t iw = ow * 2 - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2][2]; | |||
GI_FLOAT32_t dst_v[2][2]; | |||
load_bias_vec<bias_mode, 2>::impl( | |||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 2>::impl( | |||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][5]; | |||
GI_FLOAT32_t src_v[2][5]; | |||
load_vec<5>(src_v[0], input); | |||
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); | |||
compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[0]); | |||
@@ -830,12 +840,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||
size_t iw = ow * 2 - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2]; | |||
GI_FLOAT32_t dst_v[2]; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t src_v[2][3]; | |||
GI_FLOAT32_t src_v[2][3]; | |||
load_vec<3>(src_v[0], input); | |||
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | |||
load_vec<3>(src_v[1], input + IW * 4); | |||
@@ -859,10 +869,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||
size_t iw = ow * 2 - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2]; | |||
GI_FLOAT32_t dst_v[2]; | |||
load_bias_vec<bias_mode, 2>::impl( | |||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t src_v[3][5]; | |||
GI_FLOAT32_t src_v[3][5]; | |||
load_vec<5>(src_v[0], input); | |||
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); | |||
compute_vec<3>(dst_v[1], &src_v[0][2], &kernel[0]); | |||
@@ -878,10 +888,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( | |||
size_t iw = ow * 2 - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v; | |||
GI_FLOAT32_t dst_v; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t src_v[3][3]; | |||
GI_FLOAT32_t src_v[3][3]; | |||
load_vec<3>(src_v[0], input); | |||
compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); | |||
load_vec<3>(src_v[1], input + IW * 4); | |||
@@ -899,9 +909,9 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | |||
const size_t PH, const size_t PW) { | |||
Op op; | |||
float32x4_t init = vdupq_n_f32(0.f); | |||
GI_FLOAT32_t init = GiZeroFloat32(); | |||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
init = vld1q_f32(bias); | |||
init = GiLoadFloat32(bias); | |||
} | |||
constexpr size_t stride = 2; | |||
size_t oh_start = (PH + stride - 1) / stride; | |||
@@ -911,7 +921,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||
if (PH || PW) { | |||
PaddingCompute<bias_mode, Op>::compute( | |||
src, bias, dst, 5, stride, IH, IW, OH, OW, PH, PW, | |||
reinterpret_cast<const float32x4_t*>(filter), init); | |||
reinterpret_cast<const GI_FLOAT32_t*>(filter), init); | |||
} | |||
size_t oh = oh_start; | |||
for (; oh + 1 < oh_end; oh += 2) { | |||
@@ -921,13 +931,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||
size_t iw = ow * stride - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2][2]; | |||
GI_FLOAT32_t dst_v[2][2]; | |||
load_bias_vec<bias_mode, 2>::impl( | |||
dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 2>::impl( | |||
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t kernel[3][5]; | |||
float32x4_t src_v[2][7]; | |||
GI_FLOAT32_t kernel[3][5]; | |||
GI_FLOAT32_t src_v[2][7]; | |||
#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ | |||
load_vec<5>(kernel0, filter + i * 5 * 4); \ | |||
load_vec<7>(src, input + i * IW * 4); \ | |||
@@ -965,13 +975,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||
size_t iw = ow * stride - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v[2]; | |||
GI_FLOAT32_t dst_v[2]; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v[0], init, bias + oh * OW * 4 + ow * 4); | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); | |||
float32x4_t kernel[3][5]; | |||
float32x4_t src_v[2][5]; | |||
GI_FLOAT32_t kernel[3][5]; | |||
GI_FLOAT32_t src_v[2][5]; | |||
#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ | |||
load_vec<5>(kernel0, filter + i * 5 * 4); \ | |||
load_vec<5>(src, input + i * IW * 4); \ | |||
@@ -1010,11 +1020,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( | |||
size_t iw = ow * stride - PW; | |||
const float* input = src + ih * IW * 4 + iw * 4; | |||
float* output = dst + oh * OW * 4 + ow * 4; | |||
float32x4_t dst_v; | |||
GI_FLOAT32_t dst_v; | |||
load_bias_vec<bias_mode, 1>::impl( | |||
&dst_v, init, bias + oh * OW * 4 + ow * 4); | |||
float32x4_t kernel[2][5]; | |||
float32x4_t src_v[2][5]; | |||
GI_FLOAT32_t kernel[2][5]; | |||
GI_FLOAT32_t src_v[2][5]; | |||
#define COMPUTE_5X5_1(i, dst, src, kernel) \ | |||
load_vec<5>(kernel, filter + i * 5 * 4); \ | |||
load_vec<6>(src, input + i * IW * 4); \ |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -12,11 +12,11 @@ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace channel_wise_nchw44_float { | |||
#define KERN(stride, i) \ | |||
@@ -37,7 +37,7 @@ KERN(stride2, 5) | |||
#undef KERN | |||
} // namespace channel_wise_nchw44_float | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/direct.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/direct.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,18 +9,18 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct.h" | |||
#include <cstring> | |||
#include "include/megdnn/oprs.h" | |||
#include "midout.h" | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
MIDOUT_DECL(megdnn_arm_conv_f32) | |||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
MIDOUT_DECL(megdnn_gi_conv_f32) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
using namespace fp32; | |||
using namespace conv_bias; | |||
@@ -34,65 +34,65 @@ struct do_pixel_proxy { | |||
const int ow); | |||
}; | |||
#define cb_load(i) data = vld1q_lane_f32(dst + i, data, i); | |||
#define LOAD_OUT \ | |||
if (width < 4) { \ | |||
auto load_less_4 = [](float* dst, float32x4_t& data) { \ | |||
if (width == 1u) { \ | |||
UNROLL_CALL_NOWRAPPER(1, cb_load); \ | |||
} else if (width == 2u) { \ | |||
UNROLL_CALL_NOWRAPPER(2, cb_load); \ | |||
} else if (width == 3u) { \ | |||
UNROLL_CALL_NOWRAPPER(3, cb_load); \ | |||
} \ | |||
}; \ | |||
if (height >= 1) \ | |||
load_less_4(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
load_less_4(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
load_less_4(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
load_less_4(dst + 3 * OW, out3); \ | |||
} else { \ | |||
if (height > 0) \ | |||
out0 = vld1q_f32(dst + 0 * OW); \ | |||
if (height > 1) \ | |||
out1 = vld1q_f32(dst + 1 * OW); \ | |||
if (height > 2) \ | |||
out2 = vld1q_f32(dst + 2 * OW); \ | |||
if (height > 3) \ | |||
out3 = vld1q_f32(dst + 3 * OW); \ | |||
} | |||
#define cb_store(i) vst1q_lane_f32(dst + i, data, i); | |||
#define STORE_OUT \ | |||
#define cb_load(i) data = GiLd1qLaneFloat32(dst + i, data, i); | |||
#define LOAD_OUT \ | |||
if (width < 4) { \ | |||
auto store_less_4 = [](float* dst, float32x4_t& data) { \ | |||
auto load_less_4 = [](float* dst, GI_FLOAT32_t& data) { \ | |||
if (width == 1u) { \ | |||
UNROLL_CALL_NOWRAPPER(1, cb_store); \ | |||
UNROLL_CALL_NOWRAPPER(1, cb_load); \ | |||
} else if (width == 2u) { \ | |||
UNROLL_CALL_NOWRAPPER(2, cb_store); \ | |||
UNROLL_CALL_NOWRAPPER(2, cb_load); \ | |||
} else if (width == 3u) { \ | |||
UNROLL_CALL_NOWRAPPER(3, cb_store); \ | |||
UNROLL_CALL_NOWRAPPER(3, cb_load); \ | |||
} \ | |||
}; \ | |||
if (height >= 1) \ | |||
store_less_4(dst + 0 * OW, out0); \ | |||
load_less_4(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
store_less_4(dst + 1 * OW, out1); \ | |||
load_less_4(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
store_less_4(dst + 2 * OW, out2); \ | |||
load_less_4(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
store_less_4(dst + 3 * OW, out3); \ | |||
load_less_4(dst + 3 * OW, out3); \ | |||
} else { \ | |||
if (height >= 1) \ | |||
vst1q_f32(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
vst1q_f32(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
vst1q_f32(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
vst1q_f32(dst + 3 * OW, out3); \ | |||
if (height > 0) \ | |||
out0 = GiLoadFloat32(dst + 0 * OW); \ | |||
if (height > 1) \ | |||
out1 = GiLoadFloat32(dst + 1 * OW); \ | |||
if (height > 2) \ | |||
out2 = GiLoadFloat32(dst + 2 * OW); \ | |||
if (height > 3) \ | |||
out3 = GiLoadFloat32(dst + 3 * OW); \ | |||
} | |||
#define cb_store(i) GiStoreLane##i##Float32(dst + i, data); | |||
#define STORE_OUT \ | |||
if (width < 4) { \ | |||
auto store_less_4 = [](float* dst, GI_FLOAT32_t& data) { \ | |||
if (width == 1u) { \ | |||
UNROLL_CALL_NOWRAPPER(1, cb_store); \ | |||
} else if (width == 2u) { \ | |||
UNROLL_CALL_NOWRAPPER(2, cb_store); \ | |||
} else if (width == 3u) { \ | |||
UNROLL_CALL_NOWRAPPER(3, cb_store); \ | |||
} \ | |||
}; \ | |||
if (height >= 1) \ | |||
store_less_4(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
store_less_4(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
store_less_4(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
store_less_4(dst + 3 * OW, out3); \ | |||
} else { \ | |||
if (height >= 1) \ | |||
GiStoreFloat32(dst + 0 * OW, out0); \ | |||
if (height >= 2) \ | |||
GiStoreFloat32(dst + 1 * OW, out1); \ | |||
if (height >= 3) \ | |||
GiStoreFloat32(dst + 2 * OW, out2); \ | |||
if (height >= 4) \ | |||
GiStoreFloat32(dst + 3 * OW, out3); \ | |||
} | |||
template <int height, int width> | |||
@@ -104,33 +104,33 @@ struct do_pixel_proxy<1, height, width> { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp; | |||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||
} | |||
STORE_OUT; | |||
} | |||
@@ -145,45 +145,45 @@ struct do_pixel_proxy<2, height, width> { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp; | |||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||
} | |||
STORE_OUT; | |||
} | |||
@@ -198,57 +198,57 @@ struct do_pixel_proxy<3, height, width> { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp; | |||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||
} | |||
STORE_OUT; | |||
} | |||
@@ -263,69 +263,69 @@ struct do_pixel_proxy<4, height, width> { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp; | |||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr3); | |||
out0 = GiMlaqFloat32(out0, inp, kr3); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr3); | |||
out1 = GiMlaqFloat32(out1, inp, kr3); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr3); | |||
out2 = GiMlaqFloat32(out2, inp, kr3); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 6 * IW); | |||
inp = GiLoadFloat32(src_dd + 6 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr3); | |||
out3 = GiMlaqFloat32(out3, inp, kr3); | |||
} | |||
STORE_OUT; | |||
} | |||
@@ -340,81 +340,81 @@ struct do_pixel_proxy<5, height, width> { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp; | |||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); | |||
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr3); | |||
out0 = GiMlaqFloat32(out0, inp, kr3); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr4); | |||
out0 = GiMlaqFloat32(out0, inp, kr4); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr3); | |||
out1 = GiMlaqFloat32(out1, inp, kr3); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr4); | |||
out1 = GiMlaqFloat32(out1, inp, kr4); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr3); | |||
out2 = GiMlaqFloat32(out2, inp, kr3); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 6 * IW); | |||
inp = GiLoadFloat32(src_dd + 6 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr4); | |||
out2 = GiMlaqFloat32(out2, inp, kr4); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr3); | |||
out3 = GiMlaqFloat32(out3, inp, kr3); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 7 * IW); | |||
inp = GiLoadFloat32(src_dd + 7 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr4); | |||
out3 = GiMlaqFloat32(out3, inp, kr4); | |||
} | |||
STORE_OUT; | |||
} | |||
@@ -429,94 +429,94 @@ struct do_pixel_proxy<6, height, width> { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, | |||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, | |||
inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||
kr5 = vdupq_n_f32(filter[5 * FW + fw]); | |||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); | |||
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); | |||
kr5 = GiBroadcastFloat32(filter[5 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr3); | |||
out0 = GiMlaqFloat32(out0, inp, kr3); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr4); | |||
out0 = GiMlaqFloat32(out0, inp, kr4); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr3); | |||
out1 = GiMlaqFloat32(out1, inp, kr3); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr5); | |||
out0 = GiMlaqFloat32(out0, inp, kr5); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr4); | |||
out1 = GiMlaqFloat32(out1, inp, kr4); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr3); | |||
out2 = GiMlaqFloat32(out2, inp, kr3); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 6 * IW); | |||
inp = GiLoadFloat32(src_dd + 6 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr5); | |||
out1 = GiMlaqFloat32(out1, inp, kr5); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr4); | |||
out2 = GiMlaqFloat32(out2, inp, kr4); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr3); | |||
out3 = GiMlaqFloat32(out3, inp, kr3); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 7 * IW); | |||
inp = GiLoadFloat32(src_dd + 7 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr5); | |||
out2 = GiMlaqFloat32(out2, inp, kr5); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr4); | |||
out3 = GiMlaqFloat32(out3, inp, kr4); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 8 * IW); | |||
inp = GiLoadFloat32(src_dd + 8 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr5); | |||
out3 = GiMlaqFloat32(out3, inp, kr5); | |||
} | |||
STORE_OUT; | |||
} | |||
@@ -531,106 +531,106 @@ struct do_pixel_proxy<7, height, width> { | |||
(void)IH; | |||
(void)OH; | |||
const int ih = oh, iw = ow; | |||
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, | |||
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, | |||
kr6, inp; | |||
src += ih * IW + iw; | |||
dst += oh * OW + ow; | |||
LOAD_OUT; | |||
for (int fw = 0; fw < FW; ++fw) { | |||
const float* src_dd = src + fw; | |||
kr0 = vdupq_n_f32(filter[0 * FW + fw]); | |||
kr1 = vdupq_n_f32(filter[1 * FW + fw]); | |||
kr2 = vdupq_n_f32(filter[2 * FW + fw]); | |||
kr3 = vdupq_n_f32(filter[3 * FW + fw]); | |||
kr4 = vdupq_n_f32(filter[4 * FW + fw]); | |||
kr5 = vdupq_n_f32(filter[5 * FW + fw]); | |||
kr6 = vdupq_n_f32(filter[6 * FW + fw]); | |||
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]); | |||
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]); | |||
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]); | |||
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]); | |||
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]); | |||
kr5 = GiBroadcastFloat32(filter[5 * FW + fw]); | |||
kr6 = GiBroadcastFloat32(filter[6 * FW + fw]); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 0 * IW); | |||
inp = GiLoadFloat32(src_dd + 0 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr0); | |||
out0 = GiMlaqFloat32(out0, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 1 * IW); | |||
inp = GiLoadFloat32(src_dd + 1 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr1); | |||
out0 = GiMlaqFloat32(out0, inp, kr1); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr0); | |||
out1 = GiMlaqFloat32(out1, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 2 * IW); | |||
inp = GiLoadFloat32(src_dd + 2 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr2); | |||
out0 = GiMlaqFloat32(out0, inp, kr2); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr1); | |||
out1 = GiMlaqFloat32(out1, inp, kr1); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr0); | |||
out2 = GiMlaqFloat32(out2, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 3 * IW); | |||
inp = GiLoadFloat32(src_dd + 3 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr3); | |||
out0 = GiMlaqFloat32(out0, inp, kr3); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr2); | |||
out1 = GiMlaqFloat32(out1, inp, kr2); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr1); | |||
out2 = GiMlaqFloat32(out2, inp, kr1); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr0); | |||
out3 = GiMlaqFloat32(out3, inp, kr0); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 4 * IW); | |||
inp = GiLoadFloat32(src_dd + 4 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr4); | |||
out0 = GiMlaqFloat32(out0, inp, kr4); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr3); | |||
out1 = GiMlaqFloat32(out1, inp, kr3); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr2); | |||
out2 = GiMlaqFloat32(out2, inp, kr2); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr1); | |||
out3 = GiMlaqFloat32(out3, inp, kr1); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 5 * IW); | |||
inp = GiLoadFloat32(src_dd + 5 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr5); | |||
out0 = GiMlaqFloat32(out0, inp, kr5); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr4); | |||
out1 = GiMlaqFloat32(out1, inp, kr4); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr3); | |||
out2 = GiMlaqFloat32(out2, inp, kr3); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr2); | |||
out3 = GiMlaqFloat32(out3, inp, kr2); | |||
if (height > 0) | |||
inp = vld1q_f32(src_dd + 6 * IW); | |||
inp = GiLoadFloat32(src_dd + 6 * IW); | |||
if (height > 0) | |||
out0 = vmlaq_f32(out0, inp, kr6); | |||
out0 = GiMlaqFloat32(out0, inp, kr6); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr5); | |||
out1 = GiMlaqFloat32(out1, inp, kr5); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr4); | |||
out2 = GiMlaqFloat32(out2, inp, kr4); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr3); | |||
out3 = GiMlaqFloat32(out3, inp, kr3); | |||
if (height > 1) | |||
inp = vld1q_f32(src_dd + 7 * IW); | |||
inp = GiLoadFloat32(src_dd + 7 * IW); | |||
if (height > 1) | |||
out1 = vmlaq_f32(out1, inp, kr6); | |||
out1 = GiMlaqFloat32(out1, inp, kr6); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr5); | |||
out2 = GiMlaqFloat32(out2, inp, kr5); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr4); | |||
out3 = GiMlaqFloat32(out3, inp, kr4); | |||
if (height > 2) | |||
inp = vld1q_f32(src_dd + 8 * IW); | |||
inp = GiLoadFloat32(src_dd + 8 * IW); | |||
if (height > 2) | |||
out2 = vmlaq_f32(out2, inp, kr6); | |||
out2 = GiMlaqFloat32(out2, inp, kr6); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr5); | |||
out3 = GiMlaqFloat32(out3, inp, kr5); | |||
if (height > 3) | |||
inp = vld1q_f32(src_dd + 9 * IW); | |||
inp = GiLoadFloat32(src_dd + 9 * IW); | |||
if (height > 3) | |||
out3 = vmlaq_f32(out3, inp, kr6); | |||
out3 = GiMlaqFloat32(out3, inp, kr6); | |||
} | |||
STORE_OUT; | |||
} | |||
@@ -836,31 +836,31 @@ void conv_bias::kern_direct( | |||
} while (0) | |||
switch (FH) { | |||
case 1: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(0)) { GAO(1); } | |||
MIDOUT_END(); | |||
break; | |||
case 2: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(1)) { GAO(2); } | |||
MIDOUT_END(); | |||
break; | |||
case 3: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(2)) { GAO(3); } | |||
MIDOUT_END(); | |||
break; | |||
case 4: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(3)) { GAO(4); } | |||
MIDOUT_END(); | |||
break; | |||
case 5: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(4)) { GAO(5); } | |||
MIDOUT_END(); | |||
break; | |||
case 6: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(5)) { GAO(6); } | |||
MIDOUT_END(); | |||
break; | |||
case 7: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(6)) { GAO(7); } | |||
MIDOUT_END(); | |||
break; | |||
} | |||
@@ -872,31 +872,31 @@ void conv_bias::kern_direct( | |||
} while (0) | |||
switch (FH) { | |||
case 1: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(0)) { GAO(1); } | |||
MIDOUT_END(); | |||
break; | |||
case 2: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(1)) { GAO(2); } | |||
MIDOUT_END(); | |||
break; | |||
case 3: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(2)) { GAO(3); } | |||
MIDOUT_END(); | |||
break; | |||
case 4: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(3)) { GAO(4); } | |||
MIDOUT_END(); | |||
break; | |||
case 5: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(4)) { GAO(5); } | |||
MIDOUT_END(); | |||
break; | |||
case 6: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(5)) { GAO(6); } | |||
MIDOUT_END(); | |||
break; | |||
case 7: | |||
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); } | |||
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(6)) { GAO(7); } | |||
MIDOUT_END(); | |||
break; | |||
} |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/direct.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/direct.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -13,7 +13,7 @@ | |||
#include <cstddef> | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace fp32 { | |||
namespace conv_bias { | |||
@@ -23,7 +23,7 @@ void kern_direct( | |||
} // namespace conv_bias | |||
} // namespace fp32 | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -11,12 +11,12 @@ | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace conv_bias { | |||
template <> | |||
void pack_src_fp32_nchw44<1>( | |||
@@ -51,23 +51,23 @@ static inline void odd_even_split_iw8_even( | |||
const int src_offset = src_idx * ic_step; | |||
const int even_offset = iw_idx / 2 * ic_step; | |||
const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | |||
float32x4_t temp[8]; | |||
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); | |||
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); | |||
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); | |||
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); | |||
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); | |||
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); | |||
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); | |||
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); | |||
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]); | |||
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]); | |||
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]); | |||
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]); | |||
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]); | |||
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]); | |||
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); | |||
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); | |||
GI_FLOAT32_t temp[8]; | |||
temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step); | |||
temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step); | |||
temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step); | |||
temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step); | |||
temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step); | |||
temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step); | |||
temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step); | |||
temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step); | |||
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[0]); | |||
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[2]); | |||
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[4]); | |||
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[6]); | |||
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[1]); | |||
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[3]); | |||
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[5]); | |||
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[7]); | |||
} | |||
static inline void odd_even_split_iw8_odd( | |||
@@ -77,23 +77,23 @@ static inline void odd_even_split_iw8_odd( | |||
const int src_offset = src_idx * ic_step; | |||
const int even_offset = (iw_idx + 1) / 2 * ic_step; | |||
const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | |||
float32x4_t temp[8]; | |||
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); | |||
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); | |||
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); | |||
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); | |||
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); | |||
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); | |||
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); | |||
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); | |||
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]); | |||
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]); | |||
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]); | |||
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]); | |||
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]); | |||
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]); | |||
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]); | |||
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]); | |||
GI_FLOAT32_t temp[8]; | |||
temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step); | |||
temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step); | |||
temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step); | |||
temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step); | |||
temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step); | |||
temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step); | |||
temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step); | |||
temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step); | |||
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[0]); | |||
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[2]); | |||
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[4]); | |||
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[6]); | |||
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[1]); | |||
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[3]); | |||
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[5]); | |||
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[7]); | |||
} | |||
} // namespace | |||
@@ -104,7 +104,7 @@ void pack_src_fp32_nchw44<2>( | |||
const int pad_top, const int pad_bottom, const int ic, const int ic_stride) { | |||
constexpr int ic_step = 4; | |||
int odd_start = megdnn::div_ceil(iw2, 2); | |||
float32x4_t zero_v = vdupq_n_f32(0.f); | |||
GI_FLOAT32_t zero_v = GiZeroFloat32(); | |||
MEGDNN_MARK_USED_VAR(ph); | |||
bool even_start = pw % 2 == 0; | |||
rep_step(ic_idx, ic, ic_step) { | |||
@@ -115,9 +115,10 @@ void pack_src_fp32_nchw44<2>( | |||
int iw_idx = 0; | |||
rep(idx, pw) { | |||
if (iw_idx % 2 == 0) { | |||
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||
GiStoreFloat32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||
} else { | |||
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); | |||
GiStoreFloat32( | |||
sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); | |||
} | |||
++iw_idx; | |||
} | |||
@@ -136,21 +137,22 @@ void pack_src_fp32_nchw44<2>( | |||
} | |||
for (; src_idx < iw; ++src_idx) { | |||
if (iw_idx % 2 == 0) { | |||
vst1q_f32( | |||
GiStoreFloat32( | |||
sptr_base + iw_idx / 2 * ic_step, | |||
vld1q_f32(sptr + src_idx * ic_step)); | |||
GiLoadFloat32(sptr + src_idx * ic_step)); | |||
} else { | |||
vst1q_f32( | |||
GiStoreFloat32( | |||
sptr_base + (odd_start + iw_idx / 2) * ic_step, | |||
vld1q_f32(sptr + src_idx * ic_step)); | |||
GiLoadFloat32(sptr + src_idx * ic_step)); | |||
} | |||
++iw_idx; | |||
} | |||
rep(idx, pad_right) { | |||
if (iw_idx % 2 == 0) { | |||
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||
GiStoreFloat32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||
} else { | |||
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); | |||
GiStoreFloat32( | |||
sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); | |||
} | |||
++iw_idx; | |||
} | |||
@@ -163,7 +165,7 @@ void pack_src_fp32_nchw44<2>( | |||
} | |||
} // namespace conv_bias | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_BIAS(2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_NO_BIAS(2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_BIAS(2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_NO_BIAS(2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_BIAS(3); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(3); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_NO_BIAS(3); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_BIAS(3); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(3); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_NO_BIAS(3); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_BIAS(5); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(5); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_NO_BIAS(5); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_BIAS(5); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(5); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_NO_BIAS(5); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_BIAS(7); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(7); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||
INSTANTIATION_CONV_S1_NO_BIAS(7); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_BIAS(7); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(7); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||
INSTANTIATION_CONV_S2_NO_BIAS(7); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -12,16 +12,15 @@ | |||
*/ | |||
#include "megdnn/arch.h" | |||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/intrinsic_helper.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
template < | |||
@@ -39,13 +38,13 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> { | |||
}; | |||
#define cb2(step, lane, ow_block) \ | |||
c[0][step] = vfmaq_laneq_f32( \ | |||
c[0][step] = GiSimdFmaLane( \ | |||
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ | |||
c[1][step] = vfmaq_laneq_f32( \ | |||
c[1][step] = GiSimdFmaLane( \ | |||
c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); | |||
#define cb(step, lane, ow_block) \ | |||
c[0][step] = vfmaq_laneq_f32( \ | |||
#define cb(step, lane, ow_block) \ | |||
c[0][step] = GiSimdFmaLane( \ | |||
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); | |||
#define SHIFT_CAL_HELPER(ow_block, remain_w) \ | |||
@@ -122,7 +121,7 @@ public: | |||
template < | |||
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | |||
int ow_block> | |||
struct KerNeonXXs1Nchw44FP32 { | |||
struct KerGiXXs1Nchw44FP32 { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -130,7 +129,7 @@ struct KerNeonXXs1Nchw44FP32 { | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | |||
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -147,20 +146,20 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||
const int ld_src_ic = ih * iw; | |||
const int ld_src_iw = iw * oc_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][ow_block]; | |||
GI_FLOAT32_t c[c_dim][ow_block]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | |||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | |||
float32x4_t src[ow_block]; | |||
float32x4_t weight[c_dim][ic_step]; | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | |||
GI_FLOAT32_t src[ow_block]; | |||
GI_FLOAT32_t weight[c_dim][ic_step]; | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | |||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); | |||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
@@ -172,7 +171,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | |||
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -189,24 +188,24 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||
const int ld_src_ic = ih * iw; | |||
const int ld_src_iw = iw * oc_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][ow_block]; | |||
GI_FLOAT32_t c[c_dim][ow_block]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | |||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | |||
float32x4_t src[ow_block]; | |||
float32x4_t weight[c_dim][ic_step]; | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | |||
GI_FLOAT32_t src[ow_block]; | |||
GI_FLOAT32_t weight[c_dim][ic_step]; | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | |||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); | |||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | |||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); | |||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
@@ -217,7 +216,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||
} | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | |||
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -234,36 +233,36 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
const int ld_src_ic = ih * iw; | |||
const int ld_src_iw = iw * oc_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][ow_block]; | |||
GI_FLOAT32_t c[c_dim][ow_block]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | |||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | |||
float32x4_t src[ow_block]; | |||
float32x4_t weight[c_dim][ic_step]; | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | |||
GI_FLOAT32_t src[ow_block]; | |||
GI_FLOAT32_t weight[c_dim][ic_step]; | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | |||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); | |||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | |||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); | |||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | |||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step); | |||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | |||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step); | |||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
@@ -275,7 +274,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | |||
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -292,46 +291,46 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||
const int ld_src_ic = ih * iw; | |||
const int ld_src_iw = iw * oc_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][ow_block]; | |||
GI_FLOAT32_t c[c_dim][ow_block]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | |||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | |||
float32x4_t src[ow_block]; | |||
float32x4_t weight[c_dim][ic_step]; | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | |||
GI_FLOAT32_t src[ow_block]; | |||
GI_FLOAT32_t weight[c_dim][ic_step]; | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | |||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step); | |||
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | |||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step); | |||
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | |||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step); | |||
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | |||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step); | |||
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); | |||
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[4] = GiLoadFloat32(src_ptr + (ow_block + 4) * ic_step); | |||
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<5, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); | |||
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[5] = GiLoadFloat32(src_ptr + (ow_block + 5) * ic_step); | |||
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<6, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
@@ -352,10 +351,10 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
#if MEGDNN_ARMV7 | |||
constexpr int big_oc_step = 4; | |||
#else | |||
#if MEGDNN_AARCH64 | |||
constexpr int big_oc_step = 8; | |||
#else | |||
constexpr int big_oc_step = 4; | |||
#endif | |||
constexpr int oc_step = 4; | |||
constexpr int ih_step = 1; | |||
@@ -381,9 +380,9 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
switch (ow_remain) { | |||
#define cb(step) \ | |||
case step: \ | |||
kern_big_oc_remain = KerNeonXXs1Nchw44FP32< \ | |||
kern_big_oc_remain = KerGiXXs1Nchw44FP32< \ | |||
bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ | |||
kern_small_oc_remain = KerNeonXXs1Nchw44FP32< \ | |||
kern_small_oc_remain = KerGiXXs1Nchw44FP32< \ | |||
bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ | |||
break; | |||
@@ -402,7 +401,7 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
const int bias_offset = | |||
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | |||
KerNeonXXs1Nchw44FP32< | |||
KerGiXXs1Nchw44FP32< | |||
bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: | |||
impl(src + src_offset, filter + weight_offset, | |||
bias + bias_offset, dst + dst_offset, ic, ih, iw, | |||
@@ -434,7 +433,7 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
const int bias_offset = | |||
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | |||
KerNeonXXs1Nchw44FP32< | |||
KerGiXXs1Nchw44FP32< | |||
bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: | |||
impl(src + src_offset, filter + weight_offset, | |||
bias + bias_offset, dst + dst_offset, ic, ih, iw, |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -12,16 +12,15 @@ | |||
*/ | |||
#include "megdnn/arch.h" | |||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/intrinsic_helper.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
template < | |||
@@ -39,13 +38,13 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> { | |||
}; | |||
#define cb2(step, lane, ow_block) \ | |||
c[0][step] = vfmaq_laneq_f32( \ | |||
c[0][step] = GiSimdFmaLane( \ | |||
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ | |||
c[1][step] = vfmaq_laneq_f32( \ | |||
c[1][step] = GiSimdFmaLane( \ | |||
c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); | |||
#define cb(step, lane, ow_block) \ | |||
c[0][step] = vfmaq_laneq_f32( \ | |||
#define cb(step, lane, ow_block) \ | |||
c[0][step] = GiSimdFmaLane( \ | |||
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); | |||
#define SHIFT_CAL_HELPER(ow_block, remain_w) \ | |||
@@ -122,7 +121,7 @@ public: | |||
template < | |||
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | |||
int ow_block> | |||
struct KerNeonXXs2Nchw44FP32 { | |||
struct KerGiXXs2Nchw44FP32 { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -130,7 +129,7 @@ struct KerNeonXXs2Nchw44FP32 { | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | |||
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -147,36 +146,36 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||
const int ld_src_ic = ih * iw; | |||
const int ld_src_iw = iw * oc_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][ow_block]; | |||
GI_FLOAT32_t c[c_dim][ow_block]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | |||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | |||
float32x4_t src[ow_block]; | |||
float32x4_t weight[c_dim][4]; | |||
GI_FLOAT32_t src[ow_block]; | |||
GI_FLOAT32_t weight[c_dim][4]; | |||
/////////row 0///////////// | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
src_ptr_odd += ld_src_iw; | |||
weight_ptr += ld_weight_fh; | |||
/////////row 1///////////// | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
@@ -188,7 +187,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | |||
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -205,62 +204,62 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||
const int ld_src_ic = ih * iw; | |||
const int ld_src_iw = iw * oc_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][ow_block]; | |||
GI_FLOAT32_t c[c_dim][ow_block]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; | |||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | |||
float32x4_t src[ow_block]; | |||
float32x4_t weight[c_dim][4]; | |||
GI_FLOAT32_t src[ow_block]; | |||
GI_FLOAT32_t weight[c_dim][4]; | |||
/////////row 0///////////// | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
src_ptr_odd += ld_src_iw; | |||
weight_ptr += ld_weight_fh; | |||
/////////row 1///////////// | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
src_ptr_odd += ld_src_iw; | |||
weight_ptr += ld_weight_fh; | |||
//////////row 2///////////// | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src_ptr += ld_src_iw; | |||
@@ -272,7 +271,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | |||
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -289,7 +288,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
const int ld_src_ic = ih * iw; | |||
const int ld_src_iw = iw * oc_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][ow_block]; | |||
GI_FLOAT32_t c[c_dim][ow_block]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
@@ -297,28 +296,28 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | |||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | |||
float32x4_t src[ow_block]; | |||
float32x4_t weight[c_dim][4]; | |||
GI_FLOAT32_t src[ow_block]; | |||
GI_FLOAT32_t weight[c_dim][4]; | |||
// even element | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | |||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len); | |||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
// odd element | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | |||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len); | |||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
@@ -337,7 +336,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||
* src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9 | |||
**/ | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block> | |||
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr_origin, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -354,7 +353,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||
const int ld_src_ic = ih * iw; | |||
const int ld_src_iw = iw * oc_step; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][ow_block]; | |||
GI_FLOAT32_t c[c_dim][ow_block]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
@@ -362,36 +361,36 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; | |||
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { | |||
float32x4_t src[ow_block]; | |||
float32x4_t weight[c_dim][4]; | |||
GI_FLOAT32_t src[ow_block]; | |||
GI_FLOAT32_t weight[c_dim][4]; | |||
// even element | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len); | |||
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | |||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len); | |||
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); | |||
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * simd_len); | |||
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
// odd element | |||
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0); | |||
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | |||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len); | |||
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); | |||
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | |||
src[1] = GiLoadFloat32(src_ptr_odd + (ow_block + 1) * simd_len); | |||
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); | |||
@@ -414,10 +413,10 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
constexpr int fh = filter_size; | |||
constexpr int fw = filter_size; | |||
constexpr int ic_step = 4; | |||
#if MEGDNN_ARMV7 | |||
constexpr int big_oc_step = 4; | |||
#else | |||
#if MEGDNN_AARCH64 | |||
constexpr int big_oc_step = 8; | |||
#else | |||
constexpr int big_oc_step = 4; | |||
#endif | |||
constexpr int oc_step = 4; | |||
constexpr int ih_step = 1; | |||
@@ -444,9 +443,9 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
switch (ow_remain) { | |||
#define cb(step) \ | |||
case step: \ | |||
kern_big_oc_remain = KerNeonXXs2Nchw44FP32< \ | |||
kern_big_oc_remain = KerGiXXs2Nchw44FP32< \ | |||
bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ | |||
kern_small_oc_remain = KerNeonXXs2Nchw44FP32< \ | |||
kern_small_oc_remain = KerGiXXs2Nchw44FP32< \ | |||
bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ | |||
break; | |||
@@ -469,7 +468,7 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
const int bias_offset = | |||
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | |||
KerNeonXXs2Nchw44FP32< | |||
KerGiXXs2Nchw44FP32< | |||
bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: | |||
impl(src + src_offset, filter + weight_offset, | |||
bias + bias_offset, dst + dst_offset, ic, ih, iw, | |||
@@ -510,7 +509,7 @@ void conv_bias::conv_direct_fp32_nchw44( | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
const int bias_offset = | |||
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; | |||
KerNeonXXs2Nchw44FP32< | |||
KerGiXXs2Nchw44FP32< | |||
bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: | |||
impl(src + src_offset, filter + weight_offset, | |||
bias + bias_offset, dst + dst_offset, ic, ih, iw, |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BIAS(2, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_NO_BIAS(2, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BIAS(2, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_NO_BIAS(2, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BIAS(3, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_NO_BIAS(3, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BIAS(3, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_NO_BIAS(3, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BIAS(5, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_NO_BIAS(5, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BIAS(5, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_NO_BIAS(5, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BIAS(7, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_NO_BIAS(7, 1); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BIAS(7, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp | |||
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,6 +10,6 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||
INSTANCE_CONV_NO_BIAS(7, 2); | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -11,20 +11,19 @@ | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/intrinsic_helper.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
#if MEGDNN_ARMV7 | |||
#include "src/armv7/matrix_mul/asm/common.h" | |||
#endif | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
/** | |||
@@ -50,15 +49,15 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, stride, 0, T, T2, T3> { | |||
}; | |||
#define cb(step) \ | |||
c[0][step] = vfmaq_laneq_f32( \ | |||
c[0][step] = GiSimdFmaLane( \ | |||
c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ | |||
(step * stride + src_idx) % 4); \ | |||
c[1][step] = vfmaq_laneq_f32( \ | |||
c[1][step] = GiSimdFmaLane( \ | |||
c[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 4], \ | |||
(step * stride + src_idx) % 4); | |||
#define cb2(step) \ | |||
c[0][step] = vfmaq_laneq_f32( \ | |||
c[0][step] = GiSimdFmaLane( \ | |||
c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ | |||
(step * stride + src_idx) % 4); | |||
@@ -127,7 +126,7 @@ public: | |||
template < | |||
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, | |||
int stride, int ow_block, int tag = CpuTag::DEFAULT_CPU_TAG> | |||
struct KerNeonXXs2NchwNchw44FP32 { | |||
struct KerGiXXs2NchwNchw44FP32 { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -136,8 +135,7 @@ struct KerNeonXXs2NchwNchw44FP32 { | |||
template < | |||
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | |||
int ow_block> | |||
struct KerNeonXXs2NchwNchw44FP32< | |||
bias_mode, Op, remain_w, 7, oc_block, stride, ow_block> { | |||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -154,16 +152,16 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
const int ld_weight_ic = oc_step * filter_size * filter_size; | |||
const int ld_src_ic = ih * iw; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][8]; | |||
GI_FLOAT32_t c[c_dim][8]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
float32x4_t src[src_reg_size]; | |||
float32x4_t weight[c_dim][filter_size]; | |||
GI_FLOAT32_t src[src_reg_size]; | |||
GI_FLOAT32_t weight[c_dim][filter_size]; | |||
#define KERNEL_CB(step) \ | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + step * iw, 0); \ | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \ | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \ | |||
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | |||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ | |||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ | |||
@@ -186,8 +184,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
template < | |||
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | |||
int ow_block> | |||
struct KerNeonXXs2NchwNchw44FP32< | |||
bias_mode, Op, remain_w, 5, oc_block, stride, ow_block> { | |||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -204,16 +201,16 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
const int ld_weight_ic = oc_step * filter_size * filter_size; | |||
const int ld_src_ic = ih * iw; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][8]; | |||
GI_FLOAT32_t c[c_dim][8]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
float32x4_t src[src_reg_size]; | |||
float32x4_t weight[c_dim][filter_size]; | |||
GI_FLOAT32_t src[src_reg_size]; | |||
GI_FLOAT32_t weight[c_dim][filter_size]; | |||
#define KERNEL_CB(step) \ | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + step * iw, 0); \ | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \ | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \ | |||
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | |||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ | |||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ | |||
@@ -233,8 +230,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
template < | |||
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | |||
int ow_block> | |||
struct KerNeonXXs2NchwNchw44FP32< | |||
bias_mode, Op, remain_w, 3, oc_block, stride, ow_block> { | |||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -251,32 +247,32 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
const int ld_weight_ic = oc_step * filter_size * filter_size; | |||
const int ld_src_ic = ih * iw; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][8]; | |||
GI_FLOAT32_t c[c_dim][8]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
float32x4_t src[src_reg_size]; | |||
float32x4_t weight[c_dim][filter_size]; | |||
GI_FLOAT32_t src[src_reg_size]; | |||
GI_FLOAT32_t weight[c_dim][filter_size]; | |||
// row 0 | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | |||
// row 1 | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + iw, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); | |||
// row 2 | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>( | |||
src, src_ptr + 2 * iw, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||
@@ -292,7 +288,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
#if MEGDNN_ARMV7 | |||
template <BiasMode bias_mode, typename Op> | |||
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> { | |||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -310,7 +306,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> { | |||
const int ld_src_ic_skip_bytes = | |||
iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[1][8]; | |||
GI_FLOAT32_t c[1][8]; | |||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||
const int img_stride = ih * iw; | |||
constexpr int filter_stride = filter_size * filter_size * oc_step; | |||
@@ -464,8 +460,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> { | |||
}; | |||
template <BiasMode bias_mode, typename Op> | |||
struct KerNeonXXs2NchwNchw44FP32< | |||
bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU_TAG> { | |||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU_TAG> { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -483,7 +478,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
const int ld_src_ic_skip_bytes = | |||
iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[1][8]; | |||
GI_FLOAT32_t c[1][8]; | |||
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||
/** | |||
* c q8-q15 | |||
@@ -626,8 +621,7 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
template < | |||
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, | |||
int ow_block> | |||
struct KerNeonXXs2NchwNchw44FP32< | |||
bias_mode, Op, remain_w, 2, oc_block, stride, ow_block> { | |||
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, ow_block> { | |||
static void impl( | |||
const float32_t* src_ptr, const float32_t* weight_ptr, | |||
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, | |||
@@ -644,22 +638,22 @@ struct KerNeonXXs2NchwNchw44FP32< | |||
const int ld_weight_ic = oc_step * filter_size * filter_size; | |||
const int ld_src_ic = ih * iw; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
float32x4_t c[c_dim][8]; | |||
GI_FLOAT32_t c[c_dim][8]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
float32x4_t src[src_reg_size]; | |||
float32x4_t weight[c_dim][filter_size]; | |||
GI_FLOAT32_t src[src_reg_size]; | |||
GI_FLOAT32_t weight[c_dim][filter_size]; | |||
// row 0 | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||
// row 1 | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + iw, 0); | |||
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( | |||
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); | |||
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); | |||
@@ -711,9 +705,9 @@ struct ConvDirectFp32NchwNchw44 { | |||
switch (ow_remain) { | |||
#define cb(step) \ | |||
case step: \ | |||
kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \ | |||
kern_big_oc_remain = KerGiXXs2NchwNchw44FP32< \ | |||
bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ | |||
kern_small_oc_remain = KerNeonXXs2NchwNchw44FP32< \ | |||
kern_small_oc_remain = KerGiXXs2NchwNchw44FP32< \ | |||
bias_mode, Op, step, filter_size, oc_step, stride, ow_step>::impl; \ | |||
break; | |||
@@ -731,7 +725,7 @@ struct ConvDirectFp32NchwNchw44 { | |||
ic_step * pack_iw_len; | |||
const int dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
KerNeonXXs2NchwNchw44FP32< | |||
KerGiXXs2NchwNchw44FP32< | |||
bias_mode, Op, ow_step, filter_size, big_oc_step, stride, | |||
ow_step>:: | |||
impl(src + src_offset, filter + weight_offset, | |||
@@ -760,7 +754,7 @@ struct ConvDirectFp32NchwNchw44 { | |||
ic_step * pack_iw_len; | |||
const int dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
KerNeonXXs2NchwNchw44FP32< | |||
KerGiXXs2NchwNchw44FP32< | |||
bias_mode, Op, ow_step, filter_size, oc_step, stride, | |||
ow_step>:: | |||
impl(src + src_offset, filter + weight_offset, | |||
@@ -819,7 +813,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> { | |||
switch (ow_remain) { | |||
#define cb(step) \ | |||
case step: \ | |||
kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \ | |||
kern_big_oc_remain = KerGiXXs2NchwNchw44FP32< \ | |||
bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ | |||
break; | |||
@@ -849,7 +843,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> { | |||
ic_step * pack_iw_len; | |||
const int dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
KerNeonXXs2NchwNchw44FP32< | |||
KerGiXXs2NchwNchw44FP32< | |||
bias_mode, Op, ow_step, filter_size, big_oc_step, | |||
stride, ow_step, CpuTag::A7_TAG>:: | |||
impl(src + src_offset, filter + weight_offset, | |||
@@ -878,7 +872,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> { | |||
ic_step * pack_iw_len; | |||
const int dst_offset = | |||
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
KerNeonXXs2NchwNchw44FP32< | |||
KerGiXXs2NchwNchw44FP32< | |||
bias_mode, Op, ow_step, filter_size, big_oc_step, | |||
stride, ow_step>:: | |||
impl(src + src_offset, filter + weight_offset, |
@@ -0,0 +1,723 @@ | |||
/** | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include <algorithm> | |||
#include "src/fallback/conv_bias/gi/fp32/do_conv_stride1.h" | |||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_fallback_conv_bias_f32_convs1) | |||
using namespace megdnn; | |||
using namespace fallback; | |||
using namespace fp32; | |||
using namespace conv_stride1; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
void conv_stride1::do_conv_2x2_stride1( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
//! unroll of 2 | |||
size_t ic = 0; | |||
for (; ic + 1 < IC; ic += 2) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
const float* src_ptr1 = src_ptr + IW * IH; | |||
float* outptr = dst; | |||
const float* r00 = src_ptr; | |||
const float* r01 = src_ptr + IW; | |||
const float* r10 = src_ptr1; | |||
const float* r11 = src_ptr1 + IW; | |||
const float* k0 = filter + ic * 4; | |||
const float* k1 = k0 + 4; | |||
GI_FLOAT32_t _k0 = GiLoadFloat32(k0); | |||
GI_FLOAT32_t _k1 = GiLoadFloat32(k1); | |||
rep(h, OH) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
GI_FLOAT32_t _r000 = GiLoadFloat32(r00); | |||
GI_FLOAT32_t _r010 = GiLoadFloat32(r01); | |||
GI_FLOAT32_t _r001 = GiLoadFloat32(r00 + 1); | |||
GI_FLOAT32_t _r011 = GiLoadFloat32(r01 + 1); | |||
GI_FLOAT32_t _r100 = GiLoadFloat32(r10); | |||
GI_FLOAT32_t _r110 = GiLoadFloat32(r11); | |||
GI_FLOAT32_t _r101 = GiLoadFloat32(r10 + 1); | |||
GI_FLOAT32_t _r111 = GiLoadFloat32(r11 + 1); | |||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r000, _k0, 0); | |||
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r001, _k0, 1); | |||
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r010, _k0, 0); | |||
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r011, _k0, 1); | |||
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r100, _k1, 0); | |||
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r101, _k1, 1); | |||
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r110, _k1, 0); | |||
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r111, _k1, 1); | |||
GiStoreFloat32(outptr, _sum); | |||
r00 += 4; | |||
r01 += 4; | |||
r10 += 4; | |||
r11 += 4; | |||
outptr += 4; | |||
} | |||
r00 += tail_step; | |||
r01 += tail_step; | |||
r10 += tail_step; | |||
r11 += tail_step; | |||
} | |||
} | |||
for (; ic < IC; ic++) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* k0 = filter + ic * 4; | |||
GI_FLOAT32_t _k0 = GiBroadcastFloat32(k0[0]); | |||
GI_FLOAT32_t _k1 = GiBroadcastFloat32(k0[1]); | |||
GI_FLOAT32_t _k2 = GiBroadcastFloat32(k0[2]); | |||
GI_FLOAT32_t _k3 = GiBroadcastFloat32(k0[3]); | |||
rep(h, OH) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||
GI_FLOAT32_t _r01 = GiLoadFloat32(r0 + 1); | |||
GI_FLOAT32_t _r11 = GiLoadFloat32(r1 + 1); | |||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||
GI_FLOAT32_t _sum2; | |||
_sum = GiMlaqFloat32(_sum, _r00, _k0); | |||
_sum2 = GiMultiplyFloat32(_r01, _k1); | |||
_sum = GiMlaqFloat32(_sum, _r10, _k2); | |||
_sum2 = GiMlaqFloat32(_sum2, _r11, _k3); | |||
_sum = GiAddFloat32(_sum, _sum2); | |||
GiStoreFloat32(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
} | |||
} | |||
} | |||
void conv_stride1::do_conv_3x3_stride1( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
float* outptr2 = outptr + OW; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 3; | |||
const float* k2 = filter + 5; | |||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||
GI_FLOAT32_t _k3456 = GiLoadFloat32(k1); | |||
GI_FLOAT32_t _k5678 = GiLoadFloat32(k2); | |||
GI_FLOAT32_t _k6789 = GiExtqFloat32(_k5678, _k5678, 1); | |||
size_t h = 0; | |||
for (; h + 1 < OH; h += 2) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
GI_FLOAT32_t _sum1 = GiLoadFloat32(outptr); | |||
GI_FLOAT32_t _sum2 = GiBroadcastFloat32(0.f); | |||
GI_FLOAT32_t _sum3 = GiLoadFloat32(outptr2); | |||
GI_FLOAT32_t _sum4 = GiBroadcastFloat32(0.f); | |||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 4); | |||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r00n, 1); | |||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r00n, 2); | |||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 4); | |||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r10n, 1); | |||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r10n, 2); | |||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 4); | |||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r20n, 1); | |||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r20n, 2); | |||
GI_FLOAT32_t _r30 = GiLoadFloat32(r3); | |||
GI_FLOAT32_t _r30n = GiLoadFloat32LowHalf(r3 + 4); | |||
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r30n, 1); | |||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r30n, 2); | |||
_sum1 = GiSimdFmaLane(_sum1, _r00, _k0123, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r01, _k0123, 1); | |||
_sum1 = GiSimdFmaLane(_sum1, _r02, _k0123, 2); | |||
_sum2 = GiSimdFmaLane(_sum2, _r10, _k3456, 0); | |||
_sum1 = GiSimdFmaLane(_sum1, _r11, _k3456, 1); | |||
_sum2 = GiSimdFmaLane(_sum2, _r12, _k3456, 2); | |||
_sum1 = GiSimdFmaLane(_sum1, _r20, _k6789, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r21, _k6789, 1); | |||
_sum1 = GiSimdFmaLane(_sum1, _r22, _k6789, 2); | |||
_sum3 = GiSimdFmaLane(_sum3, _r10, _k0123, 0); | |||
_sum4 = GiSimdFmaLane(_sum4, _r11, _k0123, 1); | |||
_sum3 = GiSimdFmaLane(_sum3, _r12, _k0123, 2); | |||
_sum4 = GiSimdFmaLane(_sum4, _r20, _k3456, 0); | |||
_sum3 = GiSimdFmaLane(_sum3, _r21, _k3456, 1); | |||
_sum4 = GiSimdFmaLane(_sum4, _r22, _k3456, 2); | |||
_sum3 = GiSimdFmaLane(_sum3, _r30, _k6789, 0); | |||
_sum4 = GiSimdFmaLane(_sum4, _r31, _k6789, 1); | |||
_sum3 = GiSimdFmaLane(_sum3, _r32, _k6789, 2); | |||
_sum1 = GiAddFloat32(_sum1, _sum2); | |||
_sum3 = GiAddFloat32(_sum3, _sum4); | |||
GiStoreFloat32(outptr, _sum1); | |||
GiStoreFloat32(outptr2, _sum3); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
outptr += 4; | |||
outptr2 += 4; | |||
} | |||
r0 += tail_step + IW; | |||
r1 += tail_step + IW; | |||
r2 += tail_step + IW; | |||
r3 += tail_step + IW; | |||
outptr += OW; | |||
outptr2 += OW; | |||
} | |||
for (; h < OH; h++) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
GI_FLOAT32_t _sum1 = GiLoadFloat32(outptr); | |||
GI_FLOAT32_t _sum2 = GiBroadcastFloat32(0.f); | |||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 4); | |||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r00n, 1); | |||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r00n, 2); | |||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 4); | |||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r10n, 1); | |||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r10n, 2); | |||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 4); | |||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r20n, 1); | |||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r20n, 2); | |||
_sum1 = GiSimdFmaLane(_sum1, _r00, _k0123, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r01, _k0123, 1); | |||
_sum1 = GiSimdFmaLane(_sum1, _r02, _k0123, 2); | |||
_sum2 = GiSimdFmaLane(_sum2, _r10, _k3456, 0); | |||
_sum1 = GiSimdFmaLane(_sum1, _r11, _k3456, 1); | |||
_sum2 = GiSimdFmaLane(_sum2, _r12, _k3456, 2); | |||
_sum1 = GiSimdFmaLane(_sum1, _r20, _k6789, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r21, _k6789, 1); | |||
_sum1 = GiSimdFmaLane(_sum1, _r22, _k6789, 2); | |||
_sum1 = GiAddFloat32(_sum1, _sum2); | |||
GiStoreFloat32(outptr, _sum1); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
} | |||
filter += 9; | |||
} | |||
} | |||
void conv_stride1::do_conv_5x5_stride1( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
float* outptr2 = outptr + OW; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
GI_FLOAT32_t _k0123 = GiLoadFloat32(filter); | |||
GI_FLOAT32_t _k4567 = GiLoadFloat32(filter + 4); | |||
GI_FLOAT32_t _k891011 = GiLoadFloat32(filter + 8); | |||
GI_FLOAT32_t _k12131415 = GiLoadFloat32(filter + 12); | |||
GI_FLOAT32_t _k16171819 = GiLoadFloat32(filter + 16); | |||
GI_FLOAT32_t _k20212223 = GiLoadFloat32(filter + 20); | |||
GI_FLOAT32_t _k24242424 = GiBroadcastFloat32(filter[24]); | |||
size_t h = 0; | |||
for (; h + 1 < OH; h += 2) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||
GI_FLOAT32_t _sum2 = GiLoadFloat32(outptr2); | |||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); | |||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); | |||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); | |||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); | |||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); | |||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); | |||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); | |||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); | |||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); | |||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); | |||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); | |||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); | |||
GI_FLOAT32_t _r30 = GiLoadFloat32(r3); | |||
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); | |||
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); | |||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); | |||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); | |||
GI_FLOAT32_t _r40 = GiLoadFloat32(r4); | |||
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); | |||
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); | |||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); | |||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); | |||
GI_FLOAT32_t _r50 = GiLoadFloat32(r5); | |||
GI_FLOAT32_t _r54 = GiLoadFloat32(r5 + 4); | |||
GI_FLOAT32_t _r51 = GiExtqFloat32(_r50, _r54, 1); | |||
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r54, 2); | |||
GI_FLOAT32_t _r53 = GiExtqFloat32(_r50, _r54, 3); | |||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||
_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); | |||
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2); | |||
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); | |||
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); | |||
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); | |||
_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); | |||
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); | |||
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); | |||
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); | |||
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); | |||
_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); | |||
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); | |||
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); | |||
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); | |||
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); | |||
_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0); | |||
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); | |||
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); | |||
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); | |||
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r10, _k0123, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r11, _k0123, 1); | |||
_sum2 = GiSimdFmaLane(_sum2, _r12, _k0123, 2); | |||
_sum2 = GiSimdFmaLane(_sum2, _r13, _k0123, 3); | |||
_sum2 = GiSimdFmaLane(_sum2, _r14, _k4567, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r20, _k4567, 1); | |||
_sum2 = GiSimdFmaLane(_sum2, _r21, _k4567, 2); | |||
_sum2 = GiSimdFmaLane(_sum2, _r22, _k4567, 3); | |||
_sum2 = GiSimdFmaLane(_sum2, _r23, _k891011, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r24, _k891011, 1); | |||
_sum2 = GiSimdFmaLane(_sum2, _r30, _k891011, 2); | |||
_sum2 = GiSimdFmaLane(_sum2, _r31, _k891011, 3); | |||
_sum2 = GiSimdFmaLane(_sum2, _r32, _k12131415, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r33, _k12131415, 1); | |||
_sum2 = GiSimdFmaLane(_sum2, _r34, _k12131415, 2); | |||
_sum2 = GiSimdFmaLane(_sum2, _r40, _k12131415, 3); | |||
_sum2 = GiSimdFmaLane(_sum2, _r41, _k16171819, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r42, _k16171819, 1); | |||
_sum2 = GiSimdFmaLane(_sum2, _r43, _k16171819, 2); | |||
_sum2 = GiSimdFmaLane(_sum2, _r44, _k16171819, 3); | |||
_sum2 = GiSimdFmaLane(_sum2, _r50, _k20212223, 0); | |||
_sum2 = GiSimdFmaLane(_sum2, _r51, _k20212223, 1); | |||
_sum2 = GiSimdFmaLane(_sum2, _r52, _k20212223, 2); | |||
_sum2 = GiSimdFmaLane(_sum2, _r53, _k20212223, 3); | |||
_sum2 = GiSimdFmaLane(_sum2, _r54, _k24242424, 0); | |||
GiStoreFloat32(outptr, _sum); | |||
GiStoreFloat32(outptr2, _sum2); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
r5 += 4; | |||
outptr += 4; | |||
outptr2 += 4; | |||
} | |||
r0 += tail_step + IW; | |||
r1 += tail_step + IW; | |||
r2 += tail_step + IW; | |||
r3 += tail_step + IW; | |||
r4 += tail_step + IW; | |||
r5 += tail_step + IW; | |||
outptr += OW; | |||
outptr2 += OW; | |||
} | |||
for (; h < OH; h++) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); | |||
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); | |||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); | |||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); | |||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); | |||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); | |||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); | |||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); | |||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); | |||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); | |||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); | |||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); | |||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); | |||
GI_FLOAT32_t _r30 = GiLoadFloat32(r3); | |||
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); | |||
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); | |||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); | |||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); | |||
GI_FLOAT32_t _r40 = GiLoadFloat32(r4); | |||
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); | |||
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); | |||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); | |||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); | |||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||
_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); | |||
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2); | |||
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); | |||
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); | |||
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); | |||
_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); | |||
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); | |||
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); | |||
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); | |||
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); | |||
_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); | |||
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); | |||
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); | |||
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); | |||
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); | |||
_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0); | |||
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); | |||
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); | |||
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); | |||
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); | |||
GiStoreFloat32(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
} | |||
filter += 25; | |||
} | |||
} | |||
void conv_stride1::do_conv_7x7_stride1( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - OW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
const float* r6 = src_ptr + IW * 6; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 7; | |||
const float* k2 = filter + 14; | |||
const float* k3 = filter + 21; | |||
const float* k4 = filter + 28; | |||
const float* k5 = filter + 35; | |||
const float* k6 = filter + 42; | |||
for (size_t i = 0; i < OH; i++) { | |||
int width = OW >> 2; | |||
rep(i, width) { | |||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||
GI_FLOAT32_t _k4567 = GiLoadFloat32(k0 + 4); | |||
GI_FLOAT32_t _r00 = GiLoadFloat32(r0); // 0 1 2 3 | |||
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); // 4 5 6 7 | |||
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 8); // 8 9 10 11 | |||
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); // 1 2 3 4 | |||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); // 2 3 4 5 | |||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); // 3 4 5 6 | |||
GI_FLOAT32_t _r05 = GiExtqFloat32(_r04, _r00n, 1); // 5 6 7 8 | |||
GI_FLOAT32_t _r06 = GiExtqFloat32(_r04, _r00n, 2); // 6 7 8 9 | |||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||
_sum = GiSimdFmaLane(_sum, _r05, _k4567, 1); | |||
_sum = GiSimdFmaLane(_sum, _r06, _k4567, 2); | |||
GI_FLOAT32_t _k78910 = GiLoadFloat32(k1); | |||
GI_FLOAT32_t _k11121314 = GiLoadFloat32(k1 + 4); | |||
GI_FLOAT32_t _r10 = GiLoadFloat32(r1); | |||
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4); | |||
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 8); | |||
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1); | |||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2); | |||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3); | |||
GI_FLOAT32_t _r15 = GiExtqFloat32(_r14, _r10n, 1); | |||
GI_FLOAT32_t _r16 = GiExtqFloat32(_r14, _r10n, 2); | |||
_sum = GiSimdFmaLane(_sum, _r10, _k78910, 0); | |||
_sum = GiSimdFmaLane(_sum, _r11, _k78910, 1); | |||
_sum = GiSimdFmaLane(_sum, _r12, _k78910, 2); | |||
_sum = GiSimdFmaLane(_sum, _r13, _k78910, 3); | |||
_sum = GiSimdFmaLane(_sum, _r14, _k11121314, 0); | |||
_sum = GiSimdFmaLane(_sum, _r15, _k11121314, 1); | |||
_sum = GiSimdFmaLane(_sum, _r16, _k11121314, 2); | |||
GI_FLOAT32_t _k14151617 = GiLoadFloat32(k2); | |||
GI_FLOAT32_t _k18192021 = GiLoadFloat32(k2 + 4); | |||
GI_FLOAT32_t _r20 = GiLoadFloat32(r2); | |||
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4); | |||
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 8); | |||
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1); | |||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2); | |||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3); | |||
GI_FLOAT32_t _r25 = GiExtqFloat32(_r24, _r20n, 1); | |||
GI_FLOAT32_t _r26 = GiExtqFloat32(_r24, _r20n, 2); | |||
_sum = GiSimdFmaLane(_sum, _r20, _k14151617, 0); | |||
_sum = GiSimdFmaLane(_sum, _r21, _k14151617, 1); | |||
_sum = GiSimdFmaLane(_sum, _r22, _k14151617, 2); | |||
_sum = GiSimdFmaLane(_sum, _r23, _k14151617, 3); | |||
_sum = GiSimdFmaLane(_sum, _r24, _k18192021, 0); | |||
_sum = GiSimdFmaLane(_sum, _r25, _k18192021, 1); | |||
_sum = GiSimdFmaLane(_sum, _r26, _k18192021, 2); | |||
GI_FLOAT32_t _k21222324 = GiLoadFloat32(k3); | |||
GI_FLOAT32_t _k25262728 = GiLoadFloat32(k3 + 4); | |||
GI_FLOAT32_t _r30 = GiLoadFloat32(r3); | |||
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4); | |||
GI_FLOAT32_t _r30n = GiLoadFloat32(r3 + 8); | |||
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1); | |||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2); | |||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3); | |||
GI_FLOAT32_t _r35 = GiExtqFloat32(_r34, _r30n, 1); | |||
GI_FLOAT32_t _r36 = GiExtqFloat32(_r34, _r30n, 2); | |||
_sum = GiSimdFmaLane(_sum, _r30, _k21222324, 0); | |||
_sum = GiSimdFmaLane(_sum, _r31, _k21222324, 1); | |||
_sum = GiSimdFmaLane(_sum, _r32, _k21222324, 2); | |||
_sum = GiSimdFmaLane(_sum, _r33, _k21222324, 3); | |||
_sum = GiSimdFmaLane(_sum, _r34, _k25262728, 0); | |||
_sum = GiSimdFmaLane(_sum, _r35, _k25262728, 1); | |||
_sum = GiSimdFmaLane(_sum, _r36, _k25262728, 2); | |||
GI_FLOAT32_t _k28293031 = GiLoadFloat32(k4); | |||
GI_FLOAT32_t _k32333435 = GiLoadFloat32(k4 + 4); | |||
GI_FLOAT32_t _r40 = GiLoadFloat32(r4); | |||
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4); | |||
GI_FLOAT32_t _r40n = GiLoadFloat32(r4 + 8); | |||
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1); | |||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2); | |||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3); | |||
GI_FLOAT32_t _r45 = GiExtqFloat32(_r44, _r40n, 1); | |||
GI_FLOAT32_t _r46 = GiExtqFloat32(_r44, _r40n, 2); | |||
_sum = GiSimdFmaLane(_sum, _r40, _k28293031, 0); | |||
_sum = GiSimdFmaLane(_sum, _r41, _k28293031, 1); | |||
_sum = GiSimdFmaLane(_sum, _r42, _k28293031, 2); | |||
_sum = GiSimdFmaLane(_sum, _r43, _k28293031, 3); | |||
_sum = GiSimdFmaLane(_sum, _r44, _k32333435, 0); | |||
_sum = GiSimdFmaLane(_sum, _r45, _k32333435, 1); | |||
_sum = GiSimdFmaLane(_sum, _r46, _k32333435, 2); | |||
GI_FLOAT32_t _k35363738 = GiLoadFloat32(k5); | |||
GI_FLOAT32_t _k39404142 = GiLoadFloat32(k5 + 4); | |||
GI_FLOAT32_t _r50 = GiLoadFloat32(r5); | |||
GI_FLOAT32_t _r54 = GiLoadFloat32(r5 + 4); | |||
GI_FLOAT32_t _r50n = GiLoadFloat32(r5 + 8); | |||
GI_FLOAT32_t _r51 = GiExtqFloat32(_r50, _r54, 1); | |||
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r54, 2); | |||
GI_FLOAT32_t _r53 = GiExtqFloat32(_r50, _r54, 3); | |||
GI_FLOAT32_t _r55 = GiExtqFloat32(_r54, _r50n, 1); | |||
GI_FLOAT32_t _r56 = GiExtqFloat32(_r54, _r50n, 2); | |||
_sum = GiSimdFmaLane(_sum, _r50, _k35363738, 0); | |||
_sum = GiSimdFmaLane(_sum, _r51, _k35363738, 1); | |||
_sum = GiSimdFmaLane(_sum, _r52, _k35363738, 2); | |||
_sum = GiSimdFmaLane(_sum, _r53, _k35363738, 3); | |||
_sum = GiSimdFmaLane(_sum, _r54, _k39404142, 0); | |||
_sum = GiSimdFmaLane(_sum, _r55, _k39404142, 1); | |||
_sum = GiSimdFmaLane(_sum, _r56, _k39404142, 2); | |||
GI_FLOAT32_t _k42434445 = GiLoadFloat32(k6); | |||
GI_FLOAT32_t _k46474849 = | |||
GiLd1qLaneFloat32(k6 + 4 + 2, GiLoadFloat32LowHalf(k6 + 4), 2); | |||
GI_FLOAT32_t _r60 = GiLoadFloat32(r6); | |||
GI_FLOAT32_t _r64 = GiLoadFloat32(r6 + 4); | |||
GI_FLOAT32_t _r60n = GiLoadFloat32(r6 + 8); | |||
GI_FLOAT32_t _r61 = GiExtqFloat32(_r60, _r64, 1); | |||
GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r64, 2); | |||
GI_FLOAT32_t _r63 = GiExtqFloat32(_r60, _r64, 3); | |||
GI_FLOAT32_t _r65 = GiExtqFloat32(_r64, _r60n, 1); | |||
GI_FLOAT32_t _r66 = GiExtqFloat32(_r64, _r60n, 2); | |||
_sum = GiSimdFmaLane(_sum, _r60, _k42434445, 0); | |||
_sum = GiSimdFmaLane(_sum, _r61, _k42434445, 1); | |||
_sum = GiSimdFmaLane(_sum, _r62, _k42434445, 2); | |||
_sum = GiSimdFmaLane(_sum, _r63, _k42434445, 3); | |||
_sum = GiSimdFmaLane(_sum, _r64, _k46474849, 0); | |||
_sum = GiSimdFmaLane(_sum, _r65, _k46474849, 1); | |||
_sum = GiSimdFmaLane(_sum, _r66, _k46474849, 2); | |||
GiStoreFloat32(outptr, _sum); | |||
r0 += 4; | |||
r1 += 4; | |||
r2 += 4; | |||
r3 += 4; | |||
r4 += 4; | |||
r5 += 4; | |||
r6 += 4; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
r5 += tail_step; | |||
r6 += tail_step; | |||
} | |||
filter += 49; | |||
} | |||
} | |||
#include "src/common/simd_macro/epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -13,7 +13,7 @@ | |||
#include <cstddef> | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace fp32 { | |||
namespace conv_stride1 { | |||
@@ -31,7 +31,7 @@ void do_conv_7x7_stride1( | |||
size_t OH, size_t OW, size_t IC); | |||
} // namespace conv_stride1 | |||
} // namespace fp32 | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,503 @@ | |||
/** | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <algorithm> | |||
#include "./do_conv_stride2.h" | |||
#include "midout.h" | |||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
MIDOUT_DECL(megdnn_fallback_conv_bias_f32_convs2) | |||
using namespace megdnn; | |||
using namespace fallback; | |||
using namespace fp32; | |||
using namespace conv_stride2; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
void conv_stride2::do_conv_2x2_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* k0 = filter; | |||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||
rep(h, OH) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
GI_FLOAT32_t _outp = GiLoadFloat32(outptr); | |||
GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0); | |||
GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6 | |||
GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7 | |||
_outp = GiSimdFmaLane(_outp, _r00, _k0123, 0); | |||
_outp = GiSimdFmaLane(_outp, _r01, _k0123, 1); | |||
GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1); | |||
GI_FLOAT32_t _r10 = _r1.val[0]; | |||
GI_FLOAT32_t _r11 = _r1.val[1]; | |||
_outp = GiSimdFmaLane(_outp, _r10, _k0123, 2); | |||
_outp = GiSimdFmaLane(_outp, _r11, _k0123, 3); | |||
GiStoreFloat32(outptr, _outp); | |||
r0 += 8; | |||
r1 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
} | |||
filter += 4; | |||
} | |||
} | |||
void conv_stride2::do_conv_3x3_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 3; | |||
const float* k2 = filter + 5; | |||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||
GI_FLOAT32_t _k3456 = GiLoadFloat32(k1); | |||
GI_FLOAT32_t _k5678 = GiLoadFloat32(k2); | |||
GI_FLOAT32_t _k6789 = GiExtqFloat32(_k5678, _k5678, 1); | |||
rep(h, OH) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
GI_FLOAT32_t _outp = GiLoadFloat32(outptr); | |||
GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0); | |||
GI_FLOAT32_V2_t _r0n = GiLd2qFloat32(r0 + 8); | |||
GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6 | |||
GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7 | |||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0n.val[0], 1); // 2 4 6 8 | |||
_outp = GiSimdFmaLane(_outp, _r00, _k0123, 0); | |||
_outp = GiSimdFmaLane(_outp, _r01, _k0123, 1); | |||
_outp = GiSimdFmaLane(_outp, _r02, _k0123, 2); | |||
GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1); | |||
GI_FLOAT32_V2_t _r1n = GiLd2qFloat32(r1 + 8); | |||
GI_FLOAT32_t _r10 = _r1.val[0]; | |||
GI_FLOAT32_t _r11 = _r1.val[1]; | |||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1n.val[0], 1); | |||
_outp = GiSimdFmaLane(_outp, _r10, _k3456, 0); | |||
_outp = GiSimdFmaLane(_outp, _r11, _k3456, 1); | |||
_outp = GiSimdFmaLane(_outp, _r12, _k3456, 2); | |||
GI_FLOAT32_V2_t _r2 = GiLd2qFloat32(r2); | |||
GI_FLOAT32_V2_t _r2n = GiLd2qFloat32(r2 + 8); | |||
GI_FLOAT32_t _r20 = _r2.val[0]; | |||
GI_FLOAT32_t _r21 = _r2.val[1]; | |||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2n.val[0], 1); | |||
_outp = GiSimdFmaLane(_outp, _r20, _k6789, 0); | |||
_outp = GiSimdFmaLane(_outp, _r21, _k6789, 1); | |||
_outp = GiSimdFmaLane(_outp, _r22, _k6789, 2); | |||
GiStoreFloat32(outptr, _outp); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
} | |||
filter += 9; | |||
} | |||
} | |||
void conv_stride2::do_conv_5x5_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
GI_FLOAT32_t _k0123 = GiLoadFloat32(filter); | |||
GI_FLOAT32_t _k4567 = GiLoadFloat32(filter + 4); | |||
GI_FLOAT32_t _k891011 = GiLoadFloat32(filter + 8); | |||
GI_FLOAT32_t _k12131415 = GiLoadFloat32(filter + 12); | |||
GI_FLOAT32_t _k16171819 = GiLoadFloat32(filter + 16); | |||
GI_FLOAT32_t _k20212223 = GiLoadFloat32(filter + 20); | |||
GI_FLOAT32_t _k24242424 = GiBroadcastFloat32(filter[24]); | |||
for (size_t i = 0; i < OH; i++) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||
GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0); | |||
GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8); | |||
GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||
GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||
GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||
GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8 | |||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9 | |||
GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10 | |||
GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1); | |||
GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8); | |||
GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0]; | |||
GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1]; | |||
GI_FLOAT32_t _r10 = _r10_02461357.val[0]; | |||
GI_FLOAT32_t _r11 = _r10_02461357.val[1]; | |||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1); | |||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1); | |||
GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2); | |||
GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2); | |||
GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8); | |||
GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0]; | |||
GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1]; | |||
GI_FLOAT32_t _r20 = _r20_02461357.val[0]; | |||
GI_FLOAT32_t _r21 = _r20_02461357.val[1]; | |||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1); | |||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1); | |||
GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2); | |||
GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3); | |||
GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8); | |||
GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0]; | |||
GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1]; | |||
GI_FLOAT32_t _r30 = _r30_02461357.val[0]; | |||
GI_FLOAT32_t _r31 = _r30_02461357.val[1]; | |||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1); | |||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1); | |||
GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2); | |||
GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4); | |||
GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8); | |||
GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0]; | |||
GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1]; | |||
GI_FLOAT32_t _r40 = _r40_02461357.val[0]; | |||
GI_FLOAT32_t _r41 = _r40_02461357.val[1]; | |||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1); | |||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1); | |||
GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2); | |||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||
_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1); | |||
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2); | |||
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3); | |||
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0); | |||
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1); | |||
_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2); | |||
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3); | |||
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0); | |||
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1); | |||
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2); | |||
_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3); | |||
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0); | |||
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1); | |||
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2); | |||
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3); | |||
_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0); | |||
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1); | |||
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2); | |||
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3); | |||
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0); | |||
GiStoreFloat32(outptr, _sum); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
r4 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
} | |||
filter += 25; | |||
} | |||
} | |||
void conv_stride2::do_conv_7x7_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
rep(ic, IC) { | |||
const float* src_ptr = src + IW * IH * ic; | |||
float* outptr = dst; | |||
const float* r0 = src_ptr; | |||
const float* r1 = src_ptr + IW; | |||
const float* r2 = src_ptr + IW * 2; | |||
const float* r3 = src_ptr + IW * 3; | |||
const float* r4 = src_ptr + IW * 4; | |||
const float* r5 = src_ptr + IW * 5; | |||
const float* r6 = src_ptr + IW * 6; | |||
const float* k0 = filter; | |||
const float* k1 = filter + 7; | |||
const float* k2 = filter + 14; | |||
const float* k3 = filter + 21; | |||
const float* k4 = filter + 28; | |||
const float* k5 = filter + 35; | |||
const float* k6 = filter + 42; | |||
for (size_t i = 0; i < OH; i++) { | |||
int nn = OW >> 2; | |||
rep(i, nn) { | |||
GI_FLOAT32_t _sum = GiLoadFloat32(outptr); | |||
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0); | |||
GI_FLOAT32_t _k4567 = GiLoadFloat32(k0 + 4); | |||
GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0); | |||
GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8); | |||
GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14 | |||
GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15 | |||
GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6 | |||
GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7 | |||
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8 | |||
GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9 | |||
GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10 | |||
GI_FLOAT32_t _r05 = GiExtqFloat32(_r01, _r0_9111315, 2); // 5 7 9 11 | |||
GI_FLOAT32_t _r06 = GiExtqFloat32(_r00, _r0_8101214, 3); // 6 8 10 12 | |||
_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0); | |||
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1); | |||
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2); | |||
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3); | |||
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0); | |||
_sum = GiSimdFmaLane(_sum, _r05, _k4567, 1); | |||
_sum = GiSimdFmaLane(_sum, _r06, _k4567, 2); | |||
GI_FLOAT32_t _k78910 = GiLoadFloat32(k1); | |||
GI_FLOAT32_t _k11121314 = GiLoadFloat32(k1 + 4); | |||
GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1); | |||
GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8); | |||
GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0]; | |||
GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1]; | |||
GI_FLOAT32_t _r10 = _r10_02461357.val[0]; | |||
GI_FLOAT32_t _r11 = _r10_02461357.val[1]; | |||
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1); | |||
GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1); | |||
GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2); | |||
GI_FLOAT32_t _r15 = GiExtqFloat32(_r11, _r1_9111315, 2); | |||
GI_FLOAT32_t _r16 = GiExtqFloat32(_r10, _r1_8101214, 3); | |||
_sum = GiSimdFmaLane(_sum, _r10, _k78910, 0); | |||
_sum = GiSimdFmaLane(_sum, _r11, _k78910, 1); | |||
_sum = GiSimdFmaLane(_sum, _r12, _k78910, 2); | |||
_sum = GiSimdFmaLane(_sum, _r13, _k78910, 3); | |||
_sum = GiSimdFmaLane(_sum, _r14, _k11121314, 0); | |||
_sum = GiSimdFmaLane(_sum, _r15, _k11121314, 1); | |||
_sum = GiSimdFmaLane(_sum, _r16, _k11121314, 2); | |||
GI_FLOAT32_t _k14151617 = GiLoadFloat32(k2); | |||
GI_FLOAT32_t _k18192021 = GiLoadFloat32(k2 + 4); | |||
GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2); | |||
GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8); | |||
GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0]; | |||
GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1]; | |||
GI_FLOAT32_t _r20 = _r20_02461357.val[0]; | |||
GI_FLOAT32_t _r21 = _r20_02461357.val[1]; | |||
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1); | |||
GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1); | |||
GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2); | |||
GI_FLOAT32_t _r25 = GiExtqFloat32(_r21, _r2_9111315, 2); | |||
GI_FLOAT32_t _r26 = GiExtqFloat32(_r20, _r2_8101214, 3); | |||
_sum = GiSimdFmaLane(_sum, _r20, _k14151617, 0); | |||
_sum = GiSimdFmaLane(_sum, _r21, _k14151617, 1); | |||
_sum = GiSimdFmaLane(_sum, _r22, _k14151617, 2); | |||
_sum = GiSimdFmaLane(_sum, _r23, _k14151617, 3); | |||
_sum = GiSimdFmaLane(_sum, _r24, _k18192021, 0); | |||
_sum = GiSimdFmaLane(_sum, _r25, _k18192021, 1); | |||
_sum = GiSimdFmaLane(_sum, _r26, _k18192021, 2); | |||
GI_FLOAT32_t _k21222324 = GiLoadFloat32(k3); | |||
GI_FLOAT32_t _k25262728 = GiLoadFloat32(k3 + 4); | |||
GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3); | |||
GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8); | |||
GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0]; | |||
GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1]; | |||
GI_FLOAT32_t _r30 = _r30_02461357.val[0]; | |||
GI_FLOAT32_t _r31 = _r30_02461357.val[1]; | |||
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1); | |||
GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1); | |||
GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2); | |||
GI_FLOAT32_t _r35 = GiExtqFloat32(_r31, _r3_9111315, 2); | |||
GI_FLOAT32_t _r36 = GiExtqFloat32(_r30, _r3_8101214, 3); | |||
_sum = GiSimdFmaLane(_sum, _r30, _k21222324, 0); | |||
_sum = GiSimdFmaLane(_sum, _r31, _k21222324, 1); | |||
_sum = GiSimdFmaLane(_sum, _r32, _k21222324, 2); | |||
_sum = GiSimdFmaLane(_sum, _r33, _k21222324, 3); | |||
_sum = GiSimdFmaLane(_sum, _r34, _k25262728, 0); | |||
_sum = GiSimdFmaLane(_sum, _r35, _k25262728, 1); | |||
_sum = GiSimdFmaLane(_sum, _r36, _k25262728, 2); | |||
GI_FLOAT32_t _k28293031 = GiLoadFloat32(k4); | |||
GI_FLOAT32_t _k32333435 = GiLoadFloat32(k4 + 4); | |||
GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4); | |||
GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8); | |||
GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0]; | |||
GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1]; | |||
GI_FLOAT32_t _r40 = _r40_02461357.val[0]; | |||
GI_FLOAT32_t _r41 = _r40_02461357.val[1]; | |||
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1); | |||
GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1); | |||
GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2); | |||
GI_FLOAT32_t _r45 = GiExtqFloat32(_r41, _r4_9111315, 2); | |||
GI_FLOAT32_t _r46 = GiExtqFloat32(_r40, _r4_8101214, 3); | |||
_sum = GiSimdFmaLane(_sum, _r40, _k28293031, 0); | |||
_sum = GiSimdFmaLane(_sum, _r41, _k28293031, 1); | |||
_sum = GiSimdFmaLane(_sum, _r42, _k28293031, 2); | |||
_sum = GiSimdFmaLane(_sum, _r43, _k28293031, 3); | |||
_sum = GiSimdFmaLane(_sum, _r44, _k32333435, 0); | |||
_sum = GiSimdFmaLane(_sum, _r45, _k32333435, 1); | |||
_sum = GiSimdFmaLane(_sum, _r46, _k32333435, 2); | |||
GI_FLOAT32_t _k35363738 = GiLoadFloat32(k5); | |||
GI_FLOAT32_t _k39404142 = GiLoadFloat32(k5 + 4); | |||
GI_FLOAT32_V2_t _r50_02461357 = GiLd2qFloat32(r5); | |||
GI_FLOAT32_V2_t _r50nx2 = GiLd2qFloat32(r5 + 8); | |||
GI_FLOAT32_t _r5_8101214 = _r50nx2.val[0]; | |||
GI_FLOAT32_t _r5_9111315 = _r50nx2.val[1]; | |||
GI_FLOAT32_t _r50 = _r50_02461357.val[0]; | |||
GI_FLOAT32_t _r51 = _r50_02461357.val[1]; | |||
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r5_8101214, 1); | |||
GI_FLOAT32_t _r53 = GiExtqFloat32(_r51, _r5_9111315, 1); | |||
GI_FLOAT32_t _r54 = GiExtqFloat32(_r50, _r5_8101214, 2); | |||
GI_FLOAT32_t _r55 = GiExtqFloat32(_r51, _r5_9111315, 2); | |||
GI_FLOAT32_t _r56 = GiExtqFloat32(_r50, _r5_8101214, 3); | |||
_sum = GiSimdFmaLane(_sum, _r50, _k35363738, 0); | |||
_sum = GiSimdFmaLane(_sum, _r51, _k35363738, 1); | |||
_sum = GiSimdFmaLane(_sum, _r52, _k35363738, 2); | |||
_sum = GiSimdFmaLane(_sum, _r53, _k35363738, 3); | |||
_sum = GiSimdFmaLane(_sum, _r54, _k39404142, 0); | |||
_sum = GiSimdFmaLane(_sum, _r55, _k39404142, 1); | |||
_sum = GiSimdFmaLane(_sum, _r56, _k39404142, 2); | |||
GI_FLOAT32_t _k42434445 = GiLoadFloat32(k6); | |||
GI_FLOAT32_t _k45464748 = GiLoadFloat32(k6 + 3); | |||
GI_FLOAT32_V2_t _r60_02461357 = GiLd2qFloat32(r6); | |||
GI_FLOAT32_V2_t _r60nx2 = GiLd2qFloat32(r6 + 8); | |||
GI_FLOAT32_t _r6_8101214 = _r60nx2.val[0]; | |||
GI_FLOAT32_t _r6_9111315 = _r60nx2.val[1]; | |||
GI_FLOAT32_t _r60 = _r60_02461357.val[0]; | |||
GI_FLOAT32_t _r61 = _r60_02461357.val[1]; | |||
GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r6_8101214, 1); | |||
GI_FLOAT32_t _r63 = GiExtqFloat32(_r61, _r6_9111315, 1); | |||
GI_FLOAT32_t _r64 = GiExtqFloat32(_r60, _r6_8101214, 2); | |||
GI_FLOAT32_t _r65 = GiExtqFloat32(_r61, _r6_9111315, 2); | |||
GI_FLOAT32_t _r66 = GiExtqFloat32(_r60, _r6_8101214, 3); | |||
_sum = GiSimdFmaLane(_sum, _r60, _k42434445, 0); | |||
_sum = GiSimdFmaLane(_sum, _r61, _k42434445, 1); | |||
_sum = GiSimdFmaLane(_sum, _r62, _k42434445, 2); | |||
_sum = GiSimdFmaLane(_sum, _r63, _k42434445, 3); | |||
_sum = GiSimdFmaLane(_sum, _r64, _k45464748, 1); | |||
_sum = GiSimdFmaLane(_sum, _r65, _k45464748, 2); | |||
_sum = GiSimdFmaLane(_sum, _r66, _k45464748, 3); | |||
GiStoreFloat32(outptr, _sum); | |||
r0 += 8; | |||
r1 += 8; | |||
r2 += 8; | |||
r3 += 8; | |||
r4 += 8; | |||
r5 += 8; | |||
r6 += 8; | |||
outptr += 4; | |||
} | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
r2 += tail_step; | |||
r3 += tail_step; | |||
r4 += tail_step; | |||
r5 += tail_step; | |||
r6 += tail_step; | |||
} | |||
filter += 49; | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -13,7 +13,7 @@ | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace fp32 { | |||
namespace conv_stride2 { | |||
void do_conv_2x2_stride2( | |||
@@ -30,7 +30,7 @@ void do_conv_7x7_stride2( | |||
size_t OH, size_t OW, size_t IC); | |||
} // namespace conv_stride2 | |||
} // namespace fp32 | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -11,21 +11,21 @@ | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/arm_common/conv_bias/block_helper.h" | |||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/block_helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
using conv_fun = std::function<void( | |||
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, | |||
const CpuNDRange& ncb_range)>; | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1) | |||
MIDOUT_DECL(megdnn_fallback_conv_bias_fp32_nchw44_stride1) | |||
namespace { | |||
static inline size_t get_perthread_cache_bytes( | |||
@@ -156,7 +156,7 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable( | |||
size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_conv_bias_fp32_nchw44_stride1, | |||
megdnn_fallback_conv_bias_fp32_nchw44_stride1, | |||
midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
@@ -175,7 +175,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_k | |||
// shape runtime | |||
#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_arm_common_conv_bias_fp32_nchw44_stride1, \ | |||
megdnn_fallback_conv_bias_fp32_nchw44_stride1, \ | |||
midout_iv(#filter #bias_mode #stride #op##_hash)) { \ | |||
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \ | |||
} \ |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_stride1_nchw44_kern.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -10,10 +10,10 @@ | |||
* implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace conv_bias { | |||
template <BiasMode bias_mode, typename Op, int filter_size, int stride> | |||
@@ -28,5 +28,5 @@ void pack_src_fp32_nchw44( | |||
const int pad_top, const int pad_bottom, const int ic, const int ic_stride); | |||
} // namespace conv_bias | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn |
@@ -1,6 +1,6 @@ | |||
/** | |||
* \file | |||
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp | |||
dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -12,21 +12,21 @@ | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/arm_common/conv_bias/fp32/algos.h" | |||
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/common/nchw_nchwxx_valid.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
#include "midout.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
using conv_fun = std::function<void( | |||
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, | |||
const CpuNDRange& ncb_range)>; | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44) | |||
MIDOUT_DECL(megdnn_fallback_conv_bias_fp32_nchw_nchw44) | |||
namespace { | |||
static inline int block_helper( | |||
const int nthread, const int amount, const int per_unit_bytes) { | |||
@@ -195,7 +195,7 @@ bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( | |||
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_conv_bias_fp32_nchw_nchw44, | |||
megdnn_fallback_conv_bias_fp32_nchw_nchw44, | |||
midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
@@ -214,7 +214,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectNCHWNCHW44:: | |||
// shape runtime | |||
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_arm_common_conv_bias_fp32_nchw_nchw44, \ | |||
megdnn_fallback_conv_bias_fp32_nchw_nchw44, \ | |||
midout_iv(#stride #filter #bias_mode #op##_hash)) { \ | |||
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \ | |||
} \ |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -11,15 +11,14 @@ | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
#include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/arm_common/elemwise_helper/elemwise_op.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/gi/intrinsic_helper.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/fallback/elemwise_helper/elemwise_op.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace fp32_direct_nchw_nchw44 { | |||
static inline void pack_weight_fp32_nchw_nchw44( | |||
@@ -34,8 +33,8 @@ static inline void pack_weight_fp32_nchw_nchw44( | |||
for (int kh_idx = 0; kh_idx < kh; ++kh_idx) { | |||
for (int kw_idx = 0; kw_idx < kw; ++kw_idx) { | |||
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { | |||
float32x4_t vsrc = vld1q_f32(in_ptr_oc); | |||
vst1q_f32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc); | |||
GI_FLOAT32_t vsrc = GiLoadFloat32(in_ptr_oc); | |||
GiStoreFloat32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc); | |||
in_ptr_oc += oc_step; | |||
} | |||
dst_ptr_oc += oc_step; | |||
@@ -51,6 +50,6 @@ void conv_direct_fp32_nchw_nchw44( | |||
const int, const int); | |||
} // namespace fp32_direct_nchw_nchw44 | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/filter_transform.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -11,14 +11,13 @@ | |||
#pragma once | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> | |||
struct FilterTransform6X3 { | |||
@@ -65,8 +64,8 @@ struct FilterTransform6X3 { | |||
Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3); | |||
Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1); | |||
float32x4_t zeros = vdupq_n_f32(0.0f); | |||
g2.value = vextq_f32(g2.value, zeros, 1); | |||
GI_FLOAT32_t zeros = GiZeroFloat32(); | |||
g2.value = GiExtqFloat32(g2.value, zeros, 1); | |||
#define cb(i) Vector<float, 4> wd##i; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
@@ -106,7 +105,6 @@ struct FilterTransform6X3 { | |||
} | |||
#else | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ | |||
@@ -128,7 +126,7 @@ struct FilterTransform6X3 { | |||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) | |||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value) | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
@@ -154,7 +152,7 @@ struct FilterTransform6X3 { | |||
#undef FILTER_TRANSFORM | |||
#undef GET_VECTOR_ELEM | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,196 @@ | |||
/** | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/common/unroll_macro.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
namespace megdnn { | |||
namespace fallback { | |||
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | |||
GI_FLOAT32_V2_t a0, a1; | |||
a0.val[0] = GiLoadFloat32(src + 0 * lda); | |||
a0.val[1] = GiLoadFloat32(src + 1 * lda); | |||
a1.val[0] = GiLoadFloat32(src + 2 * lda); | |||
a1.val[1] = GiLoadFloat32(src + 3 * lda); | |||
GI_FLOAT32_V2_t b0 = GiZipqFloat32(a0.val[0], a1.val[0]); | |||
GI_FLOAT32_V2_t b1 = GiZipqFloat32(a0.val[1], a1.val[1]); | |||
GI_FLOAT32_V2_t c0 = GiZipqFloat32(b0.val[0], b1.val[0]); | |||
GI_FLOAT32_V2_t c1 = GiZipqFloat32(b0.val[1], b1.val[1]); | |||
GiStoreFloat32(dst + 0 * ldb, c0.val[0]); | |||
GiStoreFloat32(dst + 1 * ldb, c0.val[1]); | |||
GiStoreFloat32(dst + 2 * ldb, c1.val[0]); | |||
GiStoreFloat32(dst + 3 * ldb, c1.val[1]); | |||
} | |||
} // namespace fallback | |||
} // namespace megdnn | |||
#define MATRIX_MUL4x4(sum, a, b) \ | |||
sum##0 = GiMlaqLowLaneFloat32(sum##0, b##0, a##0, 0); \ | |||
sum##0 = GiMlaqLowLaneFloat32(sum##0, b##1, a##0, 1); \ | |||
sum##0 = GiMlaqHighLaneFloat32(sum##0, b##2, a##0, 2); \ | |||
sum##0 = GiMlaqHighLaneFloat32(sum##0, b##3, a##0, 3); \ | |||
sum##1 = GiMlaqLowLaneFloat32(sum##1, b##0, a##1, 0); \ | |||
sum##1 = GiMlaqLowLaneFloat32(sum##1, b##1, a##1, 1); \ | |||
sum##1 = GiMlaqHighLaneFloat32(sum##1, b##2, a##1, 2); \ | |||
sum##1 = GiMlaqHighLaneFloat32(sum##1, b##3, a##1, 3); \ | |||
sum##2 = GiMlaqLowLaneFloat32(sum##2, b##0, a##2, 0); \ | |||
sum##2 = GiMlaqLowLaneFloat32(sum##2, b##1, a##2, 1); \ | |||
sum##2 = GiMlaqHighLaneFloat32(sum##2, b##2, a##2, 2); \ | |||
sum##2 = GiMlaqHighLaneFloat32(sum##2, b##3, a##2, 3); \ | |||
sum##3 = GiMlaqLowLaneFloat32(sum##3, b##0, a##3, 0); \ | |||
sum##3 = GiMlaqLowLaneFloat32(sum##3, b##1, a##3, 1); \ | |||
sum##3 = GiMlaqHighLaneFloat32(sum##3, b##2, a##3, 2); \ | |||
sum##3 = GiMlaqHighLaneFloat32(sum##3, b##3, a##3, 3); | |||
#define CONCAT(a, idx) a##idx | |||
#if MEGDNN_AARCH64 | |||
//! ret and a are type Vector<float, 8> | |||
#define TRANSPOSE_8x8(a, ret) \ | |||
do { \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \ | |||
auto b4 = GiZipqFloat32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \ | |||
auto b5 = GiZipqFloat32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \ | |||
auto b6 = GiZipqFloat32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \ | |||
auto b7 = GiZipqFloat32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \ | |||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b6.val[1]))); \ | |||
CONCAT(ret, 4).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 4).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||
CONCAT(ret, 5).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 5).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[0]))); \ | |||
CONCAT(ret, 6).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 6).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||
CONCAT(ret, 7).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 7).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b5.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | |||
} while (0); | |||
#define TRANSPOSE_8x3(a, ret) \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[0]))); \ | |||
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[0]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); \ | |||
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b1.val[1]))); \ | |||
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||
GiReinterpretqFloat32ToS64(b2.val[1]), \ | |||
GiReinterpretqFloat32ToS64(b3.val[1]))); | |||
#else | |||
#define TRANSPOSE_8x4(a, ret) \ | |||
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ | |||
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ | |||
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ | |||
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ | |||
CONCAT(ret, 0).value.val[0] = \ | |||
GiCombineFloat32(GiGetLowFloat32(b0.val[0]), GiGetLowFloat32(b1.val[0])); \ | |||
CONCAT(ret, 1).value.val[0] = GiCombineFloat32( \ | |||
GiGetHighFloat32(b0.val[0]), GiGetHighFloat32(b1.val[0])); \ | |||
CONCAT(ret, 2).value.val[0] = \ | |||
GiCombineFloat32(GiGetLowFloat32(b0.val[1]), GiGetLowFloat32(b1.val[1])); \ | |||
CONCAT(ret, 3).value.val[0] = GiCombineFloat32( \ | |||
GiGetHighFloat32(b0.val[1]), GiGetHighFloat32(b1.val[1])); \ | |||
CONCAT(ret, 0).value.val[1] = \ | |||
GiCombineFloat32(GiGetLowFloat32(b2.val[0]), GiGetLowFloat32(b3.val[0])); \ | |||
CONCAT(ret, 1).value.val[1] = GiCombineFloat32( \ | |||
GiGetHighFloat32(b2.val[0]), GiGetHighFloat32(b3.val[0])); \ | |||
CONCAT(ret, 2).value.val[1] = \ | |||
GiCombineFloat32(GiGetLowFloat32(b2.val[1]), GiGetLowFloat32(b3.val[1])); \ | |||
CONCAT(ret, 3).value.val[1] = GiCombineFloat32( \ | |||
GiGetHighFloat32(b2.val[1]), GiGetHighFloat32(b3.val[1])); | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy.h | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -11,14 +11,15 @@ | |||
#pragma once | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#include "src/fallback/conv_bias/gi/postprocess_helper.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, winograd_2x3_4x4_f) | |||
MEGDNN_REG_WINOGRAD_STRATEGY( | |||
float, float, float, float, 2, 3, 4, 4, winograd_gi_2x3_4x4_f) | |||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f) | |||
@@ -37,7 +38,7 @@ MEGDNN_REG_WINOGRAD_STRATEGY( | |||
MEGDNN_REG_WINOGRAD_STRATEGY( | |||
float, float, float, float, 7, 3, 4, 4, winograd_F73_mk4_f_nchw44) | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,22 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23) | |||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F23) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
struct InputTransform2X3 { | |||
@@ -40,15 +39,15 @@ struct InputTransform2X3 { | |||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | |||
for (size_t ico = 0; ico < 4; ++ico) { | |||
if (ic + ico < IC) { | |||
auto v0 = vld1q_f32(input_ptr); | |||
auto v1 = vld1q_f32(input_ptr + IW); | |||
auto v2 = vld1q_f32(input_ptr + IW * 2); | |||
auto v3 = vld1q_f32(input_ptr + IW * 3); | |||
vst1q_f32(patch + ico * 4 * alpha + 0 * 4, v0); | |||
vst1q_f32(patch + ico * 4 * alpha + 1 * 4, v1); | |||
vst1q_f32(patch + ico * 4 * alpha + 2 * 4, v2); | |||
vst1q_f32(patch + ico * 4 * alpha + 3 * 4, v3); | |||
auto v0 = GiLoadFloat32(input_ptr); | |||
auto v1 = GiLoadFloat32(input_ptr + IW); | |||
auto v2 = GiLoadFloat32(input_ptr + IW * 2); | |||
auto v3 = GiLoadFloat32(input_ptr + IW * 3); | |||
GiStoreFloat32(patch + ico * 4 * alpha + 0 * 4, v0); | |||
GiStoreFloat32(patch + ico * 4 * alpha + 1 * 4, v1); | |||
GiStoreFloat32(patch + ico * 4 * alpha + 2 * 4, v2); | |||
GiStoreFloat32(patch + ico * 4 * alpha + 3 * 4, v3); | |||
input_ptr += IH * IW; | |||
} | |||
} | |||
@@ -197,18 +196,18 @@ struct OutputTransform2X3 { | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) | |||
void winograd_2x3_4x4_f::filter( | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_gi_2x3_4x4_f) | |||
void winograd_gi_2x3_4x4_f::filter( | |||
const float* filter, float* filter_transform_buf, float* transform_mid_buf, | |||
size_t OC, size_t IC, size_t oc_start, size_t oc_end) { | |||
constexpr int alpha = 2 + 3 - 1; | |||
//! G * g * GT | |||
float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, | |||
GI_FLOAT32_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, | |||
g3{0, 0, 1, 0}; | |||
float32x4_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1}, | |||
GI_FLOAT32_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1}, | |||
gt3{0, 0, 0, 0}; | |||
size_t OCB = OC / 4; | |||
size_t ICB = IC / 4; | |||
@@ -225,33 +224,33 @@ void winograd_2x3_4x4_f::filter( | |||
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0 | |||
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1 | |||
//! 0 0 1 0 0 0 0 0 0 0 0 0 | |||
float32x4_t vf0 = vld1q_f32(filter_ptr); | |||
float32x4_t vf1 = vld1q_f32(filter_ptr + 4); | |||
float32x4_t vf2 = vdupq_n_f32(filter_ptr[8]); | |||
float32x4_t v3(vdupq_n_f32(0)); | |||
auto vtmp = vextq_f32(vf1, vf2, 2); | |||
vtmp = vsetq_lane_f32(0, vtmp, 3); | |||
float32x4_t v2(vtmp); | |||
vtmp = vextq_f32(vf0, vf1, 3); | |||
vtmp = vsetq_lane_f32(0, vtmp, 3); | |||
float32x4_t v1(vtmp); | |||
vtmp = vsetq_lane_f32(0, vf0, 3); | |||
float32x4_t v0(vtmp); | |||
float32x4_t vsum0 = vdupq_n_f32(0), vsum1 = vdupq_n_f32(0), | |||
vsum2 = vdupq_n_f32(0), vsum3 = vdupq_n_f32(0); | |||
GI_FLOAT32_t vf0 = GiLoadFloat32(filter_ptr); | |||
GI_FLOAT32_t vf1 = GiLoadFloat32(filter_ptr + 4); | |||
GI_FLOAT32_t vf2 = GiBroadcastFloat32(filter_ptr[8]); | |||
GI_FLOAT32_t v3(GiBroadcastFloat32(0)); | |||
auto vtmp = GiExtqFloat32(vf1, vf2, 2); | |||
vtmp = GiSetqLaneFloat32(0, vtmp, 3); | |||
GI_FLOAT32_t v2(vtmp); | |||
vtmp = GiExtqFloat32(vf0, vf1, 3); | |||
vtmp = GiSetqLaneFloat32(0, vtmp, 3); | |||
GI_FLOAT32_t v1(vtmp); | |||
vtmp = GiSetqLaneFloat32(0, vf0, 3); | |||
GI_FLOAT32_t v0(vtmp); | |||
GI_FLOAT32_t vsum0 = GiBroadcastFloat32(0), vsum1 = GiBroadcastFloat32(0), | |||
vsum2 = GiBroadcastFloat32(0), vsum3 = GiBroadcastFloat32(0); | |||
MATRIX_MUL4x4(vsum, g, v); | |||
float32x4_t vres0 = vdupq_n_f32(0), vres1 = vdupq_n_f32(0), | |||
vres2 = vdupq_n_f32(0), vres3 = vdupq_n_f32(0); | |||
GI_FLOAT32_t vres0 = GiBroadcastFloat32(0), vres1 = GiBroadcastFloat32(0), | |||
vres2 = GiBroadcastFloat32(0), vres3 = GiBroadcastFloat32(0); | |||
MATRIX_MUL4x4(vres, vsum, gt); | |||
vst1q_f32(transform_mid_buf, vres0); | |||
vst1q_f32(transform_mid_buf + 4, vres1); | |||
vst1q_f32(transform_mid_buf + 8, vres2); | |||
vst1q_f32(transform_mid_buf + 12, vres3); | |||
GiStoreFloat32(transform_mid_buf, vres0); | |||
GiStoreFloat32(transform_mid_buf + 4, vres1); | |||
GiStoreFloat32(transform_mid_buf + 8, vres2); | |||
GiStoreFloat32(transform_mid_buf + 12, vres3); | |||
size_t ocb = oc / 4; | |||
size_t oc4 = oc % 4; | |||
@@ -266,7 +265,7 @@ void winograd_2x3_4x4_f::filter( | |||
} | |||
} | |||
void winograd_2x3_4x4_f::input( | |||
void winograd_gi_2x3_4x4_f::input( | |||
const float* input, float* input_transform_buf, float* transform_mid_buf, | |||
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
@@ -304,7 +303,7 @@ void winograd_2x3_4x4_f::input( | |||
} | |||
} | |||
void winograd_2x3_4x4_f::output( | |||
void winograd_gi_2x3_4x4_f::output( | |||
const float* output_transform_buf, const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, | |||
@@ -322,8 +321,8 @@ void winograd_2x3_4x4_f::output( | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F23, cb, float, float, bmode, | |||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_fallback_winograd_fp32_F23, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
@@ -333,7 +332,7 @@ void winograd_2x3_4x4_f::output( | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,22 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F45) | |||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F45) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
struct FilterTransform4X5 { | |||
@@ -126,9 +125,9 @@ struct FilterTransform4X5 { | |||
#undef cb | |||
FILTER_TRANSFORM(g, Gg) | |||
float32x4x2_t vgr; | |||
float32x4_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3}; | |||
float32x4_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7}; | |||
GI_FLOAT32_V2_t vgr; | |||
GI_FLOAT32_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3}; | |||
GI_FLOAT32_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7}; | |||
vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3}; | |||
vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7}; | |||
Vector<float, 8> Ggt4(vgr); | |||
@@ -167,8 +166,10 @@ struct InputTransform4X5 { | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
} while (0) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) | |||
template <bool inner> | |||
static void transform( | |||
@@ -345,22 +346,22 @@ struct OutputTransform4X5 { | |||
#undef cb | |||
if (oh_start + 4 <= OH && ow_start + 4 <= OW) { | |||
float32x4_t bias0; | |||
GI_FLOAT32_t bias0; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias0 = vdupq_n_f32(bias[oc]); | |||
bias0 = GiBroadcastFloat32(bias[oc]); | |||
} | |||
rep(i, 4) { | |||
size_t oh = oh_start + i; | |||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
item0 = vaddq_f32(item0, bias0); | |||
item0 = GiAddFloat32(item0, bias0); | |||
} else if (bmode == BiasMode::BIAS) { | |||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||
item0 = vaddq_f32(item0, bias0); | |||
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | |||
item0 = GiAddFloat32(item0, bias0); | |||
} | |||
item0 = op(item0); | |||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
mid_buf1 += 4; | |||
} | |||
} else { | |||
@@ -388,7 +389,7 @@ struct OutputTransform4X5 { | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f) | |||
@@ -448,8 +449,8 @@ void winograd_4x5_1x1_f::output( | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F45, cb, float, float, bmode, | |||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_fallback_winograd_fp32_F45, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
@@ -459,7 +460,7 @@ void winograd_4x5_1x1_f::output( | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,22 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F54) | |||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F54) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
struct FilterTransform5X4 { | |||
@@ -94,7 +93,6 @@ struct FilterTransform5X4 { | |||
transform_mid_buf[j * alpha + i]; | |||
} | |||
#else | |||
#define cb(i) \ | |||
do { \ | |||
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ | |||
@@ -117,7 +115,7 @@ struct FilterTransform5X4 { | |||
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ | |||
mid_buf1 += 8; \ | |||
} while (0); | |||
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) | |||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value) | |||
float* mid_buf1 = transform_mid_buf; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
@@ -154,8 +152,10 @@ struct InputTransform5X4 { | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
} while (0) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) | |||
template <bool inner> | |||
static void transform( | |||
@@ -348,29 +348,29 @@ struct OutputTransform5X4 { | |||
#undef cb | |||
if (oh_start + 5 <= OH && ow_start + 5 <= OW) { | |||
float32x4_t bias0; | |||
GI_FLOAT32_t bias0; | |||
float32_t bias1; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias0 = vdupq_n_f32(bias[oc]); | |||
bias0 = GiBroadcastFloat32(bias[oc]); | |||
bias1 = bias[oc]; | |||
} | |||
rep(i, 5) { | |||
size_t oh = oh_start + i; | |||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | |||
float32_t item1 = mid_buf1[4]; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
item0 = vaddq_f32(item0, bias0); | |||
item0 = GiAddFloat32(item0, bias0); | |||
item1 = item1 + bias1; | |||
} else if (bmode == BiasMode::BIAS) { | |||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | |||
bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4]; | |||
item0 = vaddq_f32(item0, bias0); | |||
item0 = GiAddFloat32(item0, bias0); | |||
item1 = item1 + bias1; | |||
} | |||
item0 = op(item0); | |||
item1 = op(item1); | |||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
output[oc * OH * OW + oh * OW + ow_start + 4] = item1; | |||
mid_buf1 += 5; | |||
@@ -400,7 +400,7 @@ struct OutputTransform5X4 { | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f) | |||
@@ -461,8 +461,8 @@ void winograd_5x4_1x1_f::output( | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F54, cb, float, float, bmode, | |||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_fallback_winograd_fp32_F54, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
@@ -472,7 +472,7 @@ void winograd_5x4_1x1_f::output( | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,22 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63) | |||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
/** | |||
@@ -57,8 +56,10 @@ namespace { | |||
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ | |||
} while (0); | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) | |||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1]) | |||
#define GET_VECTOR_LOW_ELEM(s, i, idx) \ | |||
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0]) | |||
struct InputTransform6X3 { | |||
template <bool inner> | |||
static void transform( | |||
@@ -271,31 +272,31 @@ struct OutputTransform6X3 { | |||
#undef cb | |||
if (oh_start + 6 <= OH && ow_start + 6 <= OW) { | |||
float32x4_t bias0; | |||
GI_FLOAT32_t bias0; | |||
float32x2_t bias1; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
bias0 = vdupq_n_f32(bias[oc]); | |||
bias1 = vdup_n_f32(bias[oc]); | |||
bias0 = GiBroadcastFloat32(bias[oc]); | |||
bias1 = GiDupFloat32(bias[oc]); | |||
} | |||
rep(i, 6) { | |||
size_t oh = oh_start + i; | |||
float32x4_t item0 = vld1q_f32(mid_buf1); | |||
float32x2_t item1 = vld1_f32(mid_buf1 + 4); | |||
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1); | |||
float32x2_t item1 = GiLdFloat32(mid_buf1 + 4); | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
item0 = vaddq_f32(item0, bias0); | |||
item1 = vadd_f32(item1, bias1); | |||
item0 = GiAddFloat32(item0, bias0); | |||
item1 = GiAddDFloat32(item1, bias1); | |||
} else if (bmode == BiasMode::BIAS) { | |||
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); | |||
bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start + 4); | |||
item0 = vaddq_f32(item0, bias0); | |||
item1 = vadd_f32(item1, bias1); | |||
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start); | |||
bias1 = GiLdFloat32(bias + oc * OH * OW + oh * OW + ow_start + 4); | |||
item0 = GiAddFloat32(item0, bias0); | |||
item1 = GiAddDFloat32(item1, bias1); | |||
} | |||
item0 = op(item0); | |||
item1 = vset_lane_f32(op(vget_lane_f32(item1, 0)), item1, 0); | |||
item1 = vset_lane_f32(op(vget_lane_f32(item1, 1)), item1, 1); | |||
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
vst1_f32(output + oc * OH * OW + oh * OW + ow_start + 4, item1); | |||
item1 = GiSetLaneFloat32(op(GiGetLaneFloat32(item1, 0)), item1, 0); | |||
item1 = GiSetLaneFloat32(op(GiGetLaneFloat32(item1, 1)), item1, 1); | |||
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0); | |||
GiSt1Float32(output + oc * OH * OW + oh * OW + ow_start + 4, item1); | |||
mid_buf1 += 6; | |||
} | |||
@@ -325,7 +326,7 @@ struct OutputTransform6X3 { | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f) | |||
@@ -385,8 +386,8 @@ void winograd_6x3_1x1_f::output( | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F63, cb, float, float, bmode, | |||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_fallback_winograd_fp32_F63, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
@@ -396,7 +397,7 @@ void winograd_6x3_1x1_f::output( | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,22 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/winograd/winograd_helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_4x4) | |||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63_4x4) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
@@ -41,16 +40,16 @@ struct InputTransform6X3 { | |||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | |||
for (size_t ico = 0; ico < 4; ++ico) { | |||
if (ic + ico < IC) { | |||
#define cb(i) \ | |||
auto v##i##0 = vld1q_f32(input_ptr + IW * i); \ | |||
auto v##i##1 = vld1q_f32(input_ptr + IW * i + 4); | |||
#define cb(i) \ | |||
auto v##i##0 = GiLoadFloat32(input_ptr + IW * i); \ | |||
auto v##i##1 = GiLoadFloat32(input_ptr + IW * i + 4); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) \ | |||
vst1q_f32(patch + ico * 8 * alpha + i * 8, v##i##0); \ | |||
vst1q_f32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1); | |||
#define cb(i) \ | |||
GiStoreFloat32(patch + ico * 8 * alpha + i * 8, v##i##0); \ | |||
GiStoreFloat32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
@@ -255,7 +254,7 @@ struct OutputTransform6X3 { | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f) | |||
@@ -323,8 +322,8 @@ void winograd_6x3_4x4_f::output( | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F63_4x4, cb, float, float, bmode, | |||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_fallback_winograd_fp32_F63_4x4, cb, float, float, bmode, | |||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||
nr_units_in_tile, src_dtype, dst_dtype); | |||
@@ -334,7 +333,7 @@ void winograd_6x3_4x4_f::output( | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,22 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4) | |||
MIDOUT_DECL(megdnn_fallback_winograd_nchw44_fp32_F23_mk4) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
constexpr size_t alpha = 2 + 3 - 1; | |||
@@ -72,8 +71,9 @@ struct InputTransformF23_NCHW44 { | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); | |||
vst1q_f32(patchT + iho * alpha * pack_size + iwo * pack_size, src); | |||
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); | |||
GiStoreFloat32( | |||
patchT + iho * alpha * pack_size + iwo * pack_size, src); | |||
} | |||
} | |||
#define cb(m, n) \ | |||
@@ -190,7 +190,7 @@ struct OutputTransformF23_NCHW44 { | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F23_mk4_f_nchw44) | |||
@@ -313,14 +313,14 @@ void winograd_F23_mk4_f_nchw44::output( | |||
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | |||
"NCHW44 Winograd filter transform requires OC is times of 4"); | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode, | |||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_fallback_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode, | |||
nonline_mode); | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,22 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/winograd/winograd_helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_mk4) | |||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63_mk4) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
@@ -49,11 +48,11 @@ struct InputTransformF63_NCHW44 { | |||
const float* input_ptr = | |||
input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | |||
for (size_t ih = 0; ih < alpha; ih++) { | |||
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); | |||
#define cb(i) auto v##i = GiLoadFloat32(input_ptr + pack_size * i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i); | |||
#define cb(i) GiStoreFloat32(patchT + ih * pack_size * alpha + i * pack_size, v##i); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
input_ptr += IW4; | |||
@@ -68,8 +67,9 @@ struct InputTransformF63_NCHW44 { | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); | |||
vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src); | |||
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); | |||
GiStoreFloat32( | |||
patchT + iho * pack_size * alpha + iwo * pack_size, src); | |||
} | |||
} | |||
} | |||
@@ -83,10 +83,10 @@ struct InputTransformF63_NCHW44 { | |||
size_t ICB = IC / pack_size; | |||
size_t icb = ic / pack_size; | |||
float32x4_t d0, d1, d2, d3, d4, d5, d6, d7; | |||
float32x4_t v0 = vld1q_f32(input_parameters + 0); | |||
float32x4_t v1 = vld1q_f32(input_parameters + 4); | |||
float32x4_t v2 = vld1q_f32(input_parameters + 8); | |||
GI_FLOAT32_t d0, d1, d2, d3, d4, d5, d6, d7; | |||
GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters + 0); | |||
GI_FLOAT32_t v1 = GiLoadFloat32(input_parameters + 4); | |||
GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8); | |||
//! B | |||
//! 1 0 0 0 0 0 0 0 | |||
@@ -98,57 +98,57 @@ struct InputTransformF63_NCHW44 { | |||
//! -1 1 1 1 1 1 1 0 | |||
//! 0 0 0 0 0 0 0 1 | |||
#define cb(i) \ | |||
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||
auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||
auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||
auto t##i##1 = d6; \ | |||
auto t##i##2 = d6; \ | |||
auto t##i##3 = d6; \ | |||
auto t##i##4 = d6; \ | |||
auto t##i##5 = d6; \ | |||
auto t##i##6 = d6; \ | |||
t##i##0 = t##i##0 - d6; \ | |||
t##i##1 = t##i##1 + d1; \ | |||
t##i##2 = t##i##2 - d1; \ | |||
t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \ | |||
t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \ | |||
t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \ | |||
t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \ | |||
t##i##7 = t##i##7 - d1; \ | |||
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \ | |||
t##i##1 = t##i##1 + d2; \ | |||
t##i##2 = t##i##2 + d2; \ | |||
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \ | |||
t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \ | |||
t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \ | |||
t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \ | |||
t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \ | |||
t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \ | |||
t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \ | |||
t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \ | |||
t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \ | |||
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \ | |||
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \ | |||
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \ | |||
t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \ | |||
t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \ | |||
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \ | |||
t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \ | |||
t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \ | |||
t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \ | |||
t##i##1 = t##i##1 + d5; \ | |||
t##i##2 = t##i##2 - d5; \ | |||
t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \ | |||
t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \ | |||
t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \ | |||
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \ | |||
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0); | |||
#define cb(i) \ | |||
d1 = GiLoadFloat32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||
d2 = GiLoadFloat32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||
d3 = GiLoadFloat32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||
d4 = GiLoadFloat32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||
d5 = GiLoadFloat32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||
d6 = GiLoadFloat32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||
auto t##i##0 = GiLoadFloat32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||
auto t##i##7 = GiLoadFloat32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||
auto t##i##1 = d6; \ | |||
auto t##i##2 = d6; \ | |||
auto t##i##3 = d6; \ | |||
auto t##i##4 = d6; \ | |||
auto t##i##5 = d6; \ | |||
auto t##i##6 = d6; \ | |||
t##i##0 = t##i##0 - d6; \ | |||
t##i##1 = t##i##1 + d1; \ | |||
t##i##2 = t##i##2 - d1; \ | |||
t##i##3 = GiSimdFmaLane(t##i##3, d1, v0, 2); \ | |||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d1, v0, 2); \ | |||
t##i##5 = GiSimdFmaLane(t##i##5, d1, v1, 2); \ | |||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d1, v1, 2); \ | |||
t##i##7 = t##i##7 - d1; \ | |||
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 0); \ | |||
t##i##1 = t##i##1 + d2; \ | |||
t##i##2 = t##i##2 + d2; \ | |||
t##i##3 = GiSimdFmaLane(t##i##3, d2, v0, 3); \ | |||
t##i##4 = GiSimdFmaLane(t##i##4, d2, v0, 3); \ | |||
t##i##5 = GiSimdFmaLane(t##i##5, d2, v1, 3); \ | |||
t##i##6 = GiSimdFmaLane(t##i##6, d2, v1, 3); \ | |||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d3, v0, 1); \ | |||
t##i##2 = GiSimdFmaLane(t##i##2, d3, v0, 1); \ | |||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d3, v1, 0); \ | |||
t##i##4 = GiSimdFmaLane(t##i##4, d3, v1, 0); \ | |||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d3, v1, 0); \ | |||
t##i##6 = GiSimdFmaLane(t##i##6, d3, v1, 0); \ | |||
t##i##7 = GiSimdFmaLane(t##i##7, d3, v0, 0); \ | |||
t##i##0 = GiSimdFmaLane(t##i##0, d4, v0, 0); \ | |||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d4, v0, 1); \ | |||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d4, v0, 1); \ | |||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d4, v1, 1); \ | |||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d4, v1, 1); \ | |||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d4, v2, 0); \ | |||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d4, v2, 0); \ | |||
t##i##1 = t##i##1 + d5; \ | |||
t##i##2 = t##i##2 - d5; \ | |||
t##i##3 = GiSimdFmaLane(t##i##3, d5, v1, 2); \ | |||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d5, v1, 2); \ | |||
t##i##5 = GiSimdFmaLane(t##i##5, d5, v0, 2); \ | |||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d5, v0, 2); \ | |||
t##i##7 = GiFmsqLaneQFloat32(t##i##7, d5, v0, 0); | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
@@ -164,75 +164,75 @@ struct InputTransformF63_NCHW44 { | |||
d0 = d0 - t6##i; \ | |||
d1 = d1 + t1##i; \ | |||
d2 = d2 - t1##i; \ | |||
d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \ | |||
d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \ | |||
d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \ | |||
d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \ | |||
d3 = GiSimdFmaLane(d3, t1##i, v0, 2); \ | |||
d4 = GiFmsqLaneQFloat32(d4, t1##i, v0, 2); \ | |||
d5 = GiSimdFmaLane(d5, t1##i, v1, 2); \ | |||
d6 = GiFmsqLaneQFloat32(d6, t1##i, v1, 2); \ | |||
d7 = d7 - t1##i; \ | |||
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \ | |||
d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 0); \ | |||
d1 = d1 + t2##i; \ | |||
d2 = d2 + t2##i; \ | |||
d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \ | |||
d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \ | |||
d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \ | |||
d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \ | |||
d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \ | |||
d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \ | |||
d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \ | |||
d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \ | |||
d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \ | |||
d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \ | |||
d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \ | |||
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \ | |||
d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \ | |||
d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \ | |||
d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \ | |||
d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \ | |||
d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \ | |||
d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \ | |||
d3 = GiSimdFmaLane(d3, t2##i, v0, 3); \ | |||
d4 = GiSimdFmaLane(d4, t2##i, v0, 3); \ | |||
d5 = GiSimdFmaLane(d5, t2##i, v1, 3); \ | |||
d6 = GiSimdFmaLane(d6, t2##i, v1, 3); \ | |||
d1 = GiFmsqLaneQFloat32(d1, t3##i, v0, 1); \ | |||
d2 = GiSimdFmaLane(d2, t3##i, v0, 1); \ | |||
d3 = GiFmsqLaneQFloat32(d3, t3##i, v1, 0); \ | |||
d4 = GiSimdFmaLane(d4, t3##i, v1, 0); \ | |||
d5 = GiFmsqLaneQFloat32(d5, t3##i, v1, 0); \ | |||
d6 = GiSimdFmaLane(d6, t3##i, v1, 0); \ | |||
d7 = GiSimdFmaLane(d7, t3##i, v0, 0); \ | |||
d0 = GiSimdFmaLane(d0, t4##i, v0, 0); \ | |||
d1 = GiFmsqLaneQFloat32(d1, t4##i, v0, 1); \ | |||
d2 = GiFmsqLaneQFloat32(d2, t4##i, v0, 1); \ | |||
d3 = GiFmsqLaneQFloat32(d3, t4##i, v1, 1); \ | |||
d4 = GiFmsqLaneQFloat32(d4, t4##i, v1, 1); \ | |||
d5 = GiFmsqLaneQFloat32(d5, t4##i, v2, 0); \ | |||
d6 = GiFmsqLaneQFloat32(d6, t4##i, v2, 0); \ | |||
d1 = d1 + t5##i; \ | |||
d2 = d2 - t5##i; \ | |||
d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \ | |||
d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \ | |||
d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \ | |||
d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \ | |||
d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \ | |||
vst1q_f32( \ | |||
d3 = GiSimdFmaLane(d3, t5##i, v1, 2); \ | |||
d4 = GiFmsqLaneQFloat32(d4, t5##i, v1, 2); \ | |||
d5 = GiSimdFmaLane(d5, t5##i, v0, 2); \ | |||
d6 = GiFmsqLaneQFloat32(d6, t5##i, v0, 2); \ | |||
d7 = GiFmsqLaneQFloat32(d7, t5##i, v0, 0); \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d0); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d1); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d2); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d3); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d4); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d5); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d6); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
@@ -347,7 +347,7 @@ struct OutputTransformF63_NCHW44 { | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk4_f_nchw44) | |||
@@ -488,14 +488,14 @@ void winograd_F63_mk4_f_nchw44::output( | |||
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | |||
"NCHW44 Winograd filter transform requires OC is times of 4"); | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, bmode, | |||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_fallback_winograd_fp32_F63_mk4, cb, float, float, bmode, | |||
nonline_mode); | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp | |||
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
@@ -9,22 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/arm_common/conv_bias/fp32/filter_transform.h" | |||
#include "src/arm_common/conv_bias/fp32/helper.h" | |||
#include "src/arm_common/conv_bias/fp32/strategy.h" | |||
#include "src/arm_common/elemwise_helper/op_unary.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/winograd/winograd_helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||
#include "src/fallback/conv_bias/gi/utils.h" | |||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||
#include "src/fallback/elemwise_helper/op_unary.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F73_mk4) | |||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F73_mk4) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace fallback; | |||
namespace { | |||
@@ -51,11 +50,11 @@ struct InputTransformF73_NCHW44 { | |||
const float* input_ptr = | |||
input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | |||
for (size_t ih = 0; ih < alpha; ih++) { | |||
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); | |||
#define cb(i) auto v##i = GiLoadFloat32(input_ptr + pack_size * i); | |||
UNROLL_CALL_NOWRAPPER(9, cb); | |||
#undef cb | |||
#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i); | |||
#define cb(i) GiStoreFloat32(patchT + ih * pack_size * alpha + i * pack_size, v##i); | |||
UNROLL_CALL_NOWRAPPER(9, cb); | |||
#undef cb | |||
input_ptr += IW4; | |||
@@ -70,8 +69,9 @@ struct InputTransformF73_NCHW44 { | |||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); | |||
vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src); | |||
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size); | |||
GiStoreFloat32( | |||
patchT + iho * pack_size * alpha + iwo * pack_size, src); | |||
} | |||
} | |||
} | |||
@@ -85,14 +85,14 @@ struct InputTransformF73_NCHW44 { | |||
size_t ICB = IC / pack_size; | |||
size_t icb = ic / pack_size; | |||
float32x4_t d0, d1, d2, d3, d4, d5, d6, d7, d8; | |||
float32x4_t v0 = vld1q_f32(input_parameters + 0); | |||
float32x4_t v1 = vld1q_f32(input_parameters + 4); | |||
float32x4_t v2 = vld1q_f32(input_parameters + 8); | |||
float32x4_t v3 = vld1q_f32(input_parameters + 12); | |||
float32x4_t v4 = vld1q_f32(input_parameters + 16); | |||
float32x4_t v5 = vld1q_f32(input_parameters + 20); | |||
float32x4_t v6 = vld1q_f32(input_parameters + 24); | |||
GI_FLOAT32_t d0, d1, d2, d3, d4, d5, d6, d7, d8; | |||
GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters + 0); | |||
GI_FLOAT32_t v1 = GiLoadFloat32(input_parameters + 4); | |||
GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8); | |||
GI_FLOAT32_t v3 = GiLoadFloat32(input_parameters + 12); | |||
GI_FLOAT32_t v4 = GiLoadFloat32(input_parameters + 16); | |||
GI_FLOAT32_t v5 = GiLoadFloat32(input_parameters + 20); | |||
GI_FLOAT32_t v6 = GiLoadFloat32(input_parameters + 24); | |||
//! B | |||
//! 1.5 0 0 0 0 0 0 0 0 | |||
@@ -113,77 +113,77 @@ struct InputTransformF73_NCHW44 { | |||
// 5.0f, 10.0f, 5.75f, 2.75f, v5 | |||
// 4.25f, 1.75f, 2.0f, 0.0f, v6 | |||
#define cb(i) \ | |||
d0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||
d7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||
auto t##i##8 = vld1q_f32(patchT + i * alpha * pack_size + 8 * pack_size); \ | |||
auto t##i##0 = d7; \ | |||
auto t##i##1 = d7; \ | |||
auto t##i##2 = d7; \ | |||
auto t##i##3 = d7; \ | |||
auto t##i##4 = d7; \ | |||
auto t##i##5 = d7; \ | |||
auto t##i##6 = d7; \ | |||
auto t##i##7 = d7; \ | |||
t##i##8 = vfmsq_laneq_f32(t##i##8, d7, v0, 0); \ | |||
t##i##0 = t##i##0 - d1; \ | |||
t##i##1 = vfmsq_laneq_f32(t##i##1, d1, v0, 0); \ | |||
t##i##2 = vfmaq_laneq_f32(t##i##2, d1, v0, 0); \ | |||
t##i##3 = vfmsq_laneq_f32(t##i##3, d1, v0, 1); \ | |||
t##i##4 = vfmaq_laneq_f32(t##i##4, d1, v0, 1); \ | |||
t##i##5 = vfmsq_laneq_f32(t##i##5, d1, v0, 2); \ | |||
t##i##6 = vfmaq_laneq_f32(t##i##6, d1, v0, 2); \ | |||
t##i##7 = t##i##7 - d1; \ | |||
t##i##8 = vfmaq_laneq_f32(t##i##8, d1, v0, 0); \ | |||
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 3); \ | |||
t##i##1 = vfmsq_laneq_f32(t##i##1, d2, v1, 0); \ | |||
t##i##2 = vfmsq_laneq_f32(t##i##2, d2, v1, 1); \ | |||
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v1, 2); \ | |||
t##i##4 = vfmsq_laneq_f32(t##i##4, d2, v1, 3); \ | |||
t##i##5 = vfmsq_laneq_f32(t##i##5, d2, v2, 0); \ | |||
t##i##6 = vfmsq_laneq_f32(t##i##6, d2, v2, 1); \ | |||
t##i##8 = t##i##8 - d2; \ | |||
t##i##0 = vfmaq_laneq_f32(t##i##0, d3, v2, 2); \ | |||
t##i##1 = vfmaq_laneq_f32(t##i##1, d3, v2, 3); \ | |||
t##i##2 = vfmsq_laneq_f32(t##i##2, d3, v3, 0); \ | |||
t##i##3 = vfmaq_laneq_f32(t##i##3, d3, v2, 0); \ | |||
t##i##4 = vfmsq_laneq_f32(t##i##4, d3, v3, 1); \ | |||
t##i##5 = vfmaq_laneq_f32(t##i##5, d3, v3, 2); \ | |||
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v3, 3); \ | |||
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v2, 2); \ | |||
t##i##8 = vfmsq_laneq_f32(t##i##8, d3, v0, 3); \ | |||
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 3); \ | |||
t##i##1 = vfmaq_laneq_f32(t##i##1, d4, v4, 0); \ | |||
t##i##2 = vfmaq_laneq_f32(t##i##2, d4, v4, 1); \ | |||
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v4, 2); \ | |||
t##i##4 = vfmaq_laneq_f32(t##i##4, d4, v4, 3); \ | |||
t##i##5 = vfmaq_laneq_f32(t##i##5, d4, v5, 0); \ | |||
t##i##6 = vfmaq_laneq_f32(t##i##6, d4, v5, 1); \ | |||
t##i##8 = vfmaq_laneq_f32(t##i##8, d4, v2, 2); \ | |||
t##i##0 = vfmsq_laneq_f32(t##i##0, d5, v2, 2); \ | |||
t##i##1 = vfmsq_laneq_f32(t##i##1, d5, v5, 2); \ | |||
t##i##2 = vfmsq_laneq_f32(t##i##2, d5, v5, 3); \ | |||
t##i##3 = vfmsq_laneq_f32(t##i##3, d5, v6, 0); \ | |||
t##i##4 = vfmaq_laneq_f32(t##i##4, d5, v6, 1); \ | |||
t##i##5 = vfmsq_laneq_f32(t##i##5, d5, v5, 2); \ | |||
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v6, 0); \ | |||
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v2, 2); \ | |||
t##i##8 = vfmaq_laneq_f32(t##i##8, d5, v0, 3); \ | |||
t##i##0 = vfmsq_laneq_f32(t##i##0, d6, v0, 0); \ | |||
t##i##1 = vfmsq_laneq_f32(t##i##1, d6, v1, 0); \ | |||
t##i##2 = vfmsq_laneq_f32(t##i##2, d6, v1, 1); \ | |||
t##i##3 = vfmaq_laneq_f32(t##i##3, d6, v1, 0); \ | |||
t##i##4 = vfmsq_laneq_f32(t##i##4, d6, v3, 1); \ | |||
t##i##5 = t##i##5 - d6; \ | |||
t##i##6 = vfmsq_laneq_f32(t##i##6, d6, v6, 2); \ | |||
t##i##8 = vfmsq_laneq_f32(t##i##8, d6, v2, 2); \ | |||
t##i##0 = vfmaq_laneq_f32(t##i##0, d0, v0, 0); | |||
#define cb(i) \ | |||
d0 = GiLoadFloat32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||
d1 = GiLoadFloat32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||
d2 = GiLoadFloat32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||
d3 = GiLoadFloat32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||
d4 = GiLoadFloat32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||
d5 = GiLoadFloat32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||
d6 = GiLoadFloat32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||
d7 = GiLoadFloat32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||
auto t##i##8 = GiLoadFloat32(patchT + i * alpha * pack_size + 8 * pack_size); \ | |||
auto t##i##0 = d7; \ | |||
auto t##i##1 = d7; \ | |||
auto t##i##2 = d7; \ | |||
auto t##i##3 = d7; \ | |||
auto t##i##4 = d7; \ | |||
auto t##i##5 = d7; \ | |||
auto t##i##6 = d7; \ | |||
auto t##i##7 = d7; \ | |||
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d7, v0, 0); \ | |||
t##i##0 = t##i##0 - d1; \ | |||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d1, v0, 0); \ | |||
t##i##2 = GiSimdFmaLane(t##i##2, d1, v0, 0); \ | |||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d1, v0, 1); \ | |||
t##i##4 = GiSimdFmaLane(t##i##4, d1, v0, 1); \ | |||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d1, v0, 2); \ | |||
t##i##6 = GiSimdFmaLane(t##i##6, d1, v0, 2); \ | |||
t##i##7 = t##i##7 - d1; \ | |||
t##i##8 = GiSimdFmaLane(t##i##8, d1, v0, 0); \ | |||
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 3); \ | |||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d2, v1, 0); \ | |||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d2, v1, 1); \ | |||
t##i##3 = GiSimdFmaLane(t##i##3, d2, v1, 2); \ | |||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d2, v1, 3); \ | |||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d2, v2, 0); \ | |||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d2, v2, 1); \ | |||
t##i##8 = t##i##8 - d2; \ | |||
t##i##0 = GiSimdFmaLane(t##i##0, d3, v2, 2); \ | |||
t##i##1 = GiSimdFmaLane(t##i##1, d3, v2, 3); \ | |||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d3, v3, 0); \ | |||
t##i##3 = GiSimdFmaLane(t##i##3, d3, v2, 0); \ | |||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d3, v3, 1); \ | |||
t##i##5 = GiSimdFmaLane(t##i##5, d3, v3, 2); \ | |||
t##i##6 = GiSimdFmaLane(t##i##6, d3, v3, 3); \ | |||
t##i##7 = GiSimdFmaLane(t##i##7, d3, v2, 2); \ | |||
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d3, v0, 3); \ | |||
t##i##0 = GiSimdFmaLane(t##i##0, d4, v0, 3); \ | |||
t##i##1 = GiSimdFmaLane(t##i##1, d4, v4, 0); \ | |||
t##i##2 = GiSimdFmaLane(t##i##2, d4, v4, 1); \ | |||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d4, v4, 2); \ | |||
t##i##4 = GiSimdFmaLane(t##i##4, d4, v4, 3); \ | |||
t##i##5 = GiSimdFmaLane(t##i##5, d4, v5, 0); \ | |||
t##i##6 = GiSimdFmaLane(t##i##6, d4, v5, 1); \ | |||
t##i##8 = GiSimdFmaLane(t##i##8, d4, v2, 2); \ | |||
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d5, v2, 2); \ | |||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d5, v5, 2); \ | |||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d5, v5, 3); \ | |||
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d5, v6, 0); \ | |||
t##i##4 = GiSimdFmaLane(t##i##4, d5, v6, 1); \ | |||
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d5, v5, 2); \ | |||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d5, v6, 0); \ | |||
t##i##7 = GiFmsqLaneQFloat32(t##i##7, d5, v2, 2); \ | |||
t##i##8 = GiSimdFmaLane(t##i##8, d5, v0, 3); \ | |||
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d6, v0, 0); \ | |||
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d6, v1, 0); \ | |||
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d6, v1, 1); \ | |||
t##i##3 = GiSimdFmaLane(t##i##3, d6, v1, 0); \ | |||
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d6, v3, 1); \ | |||
t##i##5 = t##i##5 - d6; \ | |||
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d6, v6, 2); \ | |||
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d6, v2, 2); \ | |||
t##i##0 = GiSimdFmaLane(t##i##0, d0, v0, 0); | |||
UNROLL_CALL_RAW(9, cb); | |||
#undef cb | |||
@@ -198,100 +198,100 @@ struct InputTransformF73_NCHW44 { | |||
d5 = t7##i; \ | |||
d6 = t7##i; \ | |||
d7 = t7##i; \ | |||
d8 = vfmsq_laneq_f32(d8, t7##i, v0, 0); \ | |||
d8 = GiFmsqLaneQFloat32(d8, t7##i, v0, 0); \ | |||
d0 = d0 - t1##i; \ | |||
d1 = vfmsq_laneq_f32(d1, t1##i, v0, 0); \ | |||
d2 = vfmaq_laneq_f32(d2, t1##i, v0, 0); \ | |||
d3 = vfmsq_laneq_f32(d3, t1##i, v0, 1); \ | |||
d4 = vfmaq_laneq_f32(d4, t1##i, v0, 1); \ | |||
d5 = vfmsq_laneq_f32(d5, t1##i, v0, 2); \ | |||
d6 = vfmaq_laneq_f32(d6, t1##i, v0, 2); \ | |||
d1 = GiFmsqLaneQFloat32(d1, t1##i, v0, 0); \ | |||
d2 = GiSimdFmaLane(d2, t1##i, v0, 0); \ | |||
d3 = GiFmsqLaneQFloat32(d3, t1##i, v0, 1); \ | |||
d4 = GiSimdFmaLane(d4, t1##i, v0, 1); \ | |||
d5 = GiFmsqLaneQFloat32(d5, t1##i, v0, 2); \ | |||
d6 = GiSimdFmaLane(d6, t1##i, v0, 2); \ | |||
d7 = d7 - t1##i; \ | |||
d8 = vfmaq_laneq_f32(d8, t1##i, v0, 0); \ | |||
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 3); \ | |||
d1 = vfmsq_laneq_f32(d1, t2##i, v1, 0); \ | |||
d2 = vfmsq_laneq_f32(d2, t2##i, v1, 1); \ | |||
d3 = vfmaq_laneq_f32(d3, t2##i, v1, 2); \ | |||
d4 = vfmsq_laneq_f32(d4, t2##i, v1, 3); \ | |||
d5 = vfmsq_laneq_f32(d5, t2##i, v2, 0); \ | |||
d6 = vfmsq_laneq_f32(d6, t2##i, v2, 1); \ | |||
d8 = GiSimdFmaLane(d8, t1##i, v0, 0); \ | |||
d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 3); \ | |||
d1 = GiFmsqLaneQFloat32(d1, t2##i, v1, 0); \ | |||
d2 = GiFmsqLaneQFloat32(d2, t2##i, v1, 1); \ | |||
d3 = GiSimdFmaLane(d3, t2##i, v1, 2); \ | |||
d4 = GiFmsqLaneQFloat32(d4, t2##i, v1, 3); \ | |||
d5 = GiFmsqLaneQFloat32(d5, t2##i, v2, 0); \ | |||
d6 = GiFmsqLaneQFloat32(d6, t2##i, v2, 1); \ | |||
d8 = d8 - t2##i; \ | |||
d0 = vfmaq_laneq_f32(d0, t3##i, v2, 2); \ | |||
d1 = vfmaq_laneq_f32(d1, t3##i, v2, 3); \ | |||
d2 = vfmsq_laneq_f32(d2, t3##i, v3, 0); \ | |||
d3 = vfmaq_laneq_f32(d3, t3##i, v2, 0); \ | |||
d4 = vfmsq_laneq_f32(d4, t3##i, v3, 1); \ | |||
d5 = vfmaq_laneq_f32(d5, t3##i, v3, 2); \ | |||
d6 = vfmaq_laneq_f32(d6, t3##i, v3, 3); \ | |||
d7 = vfmaq_laneq_f32(d7, t3##i, v2, 2); \ | |||
d8 = vfmsq_laneq_f32(d8, t3##i, v0, 3); \ | |||
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 3); \ | |||
d1 = vfmaq_laneq_f32(d1, t4##i, v4, 0); \ | |||
d2 = vfmaq_laneq_f32(d2, t4##i, v4, 1); \ | |||
d3 = vfmsq_laneq_f32(d3, t4##i, v4, 2); \ | |||
d4 = vfmaq_laneq_f32(d4, t4##i, v4, 3); \ | |||
d5 = vfmaq_laneq_f32(d5, t4##i, v5, 0); \ | |||
d6 = vfmaq_laneq_f32(d6, t4##i, v5, 1); \ | |||
d8 = vfmaq_laneq_f32(d8, t4##i, v2, 2); \ | |||
d0 = vfmsq_laneq_f32(d0, t5##i, v2, 2); \ | |||
d1 = vfmsq_laneq_f32(d1, t5##i, v5, 2); \ | |||
d2 = vfmsq_laneq_f32(d2, t5##i, v5, 3); \ | |||
d3 = vfmsq_laneq_f32(d3, t5##i, v6, 0); \ | |||
d4 = vfmaq_laneq_f32(d4, t5##i, v6, 1); \ | |||
d5 = vfmsq_laneq_f32(d5, t5##i, v5, 2); \ | |||
d6 = vfmsq_laneq_f32(d6, t5##i, v6, 0); \ | |||
d7 = vfmsq_laneq_f32(d7, t5##i, v2, 2); \ | |||
d8 = vfmaq_laneq_f32(d8, t5##i, v0, 3); \ | |||
d0 = vfmsq_laneq_f32(d0, t6##i, v0, 0); \ | |||
d1 = vfmsq_laneq_f32(d1, t6##i, v1, 0); \ | |||
d2 = vfmsq_laneq_f32(d2, t6##i, v1, 1); \ | |||
d3 = vfmaq_laneq_f32(d3, t6##i, v1, 0); \ | |||
d4 = vfmsq_laneq_f32(d4, t6##i, v3, 1); \ | |||
d0 = GiSimdFmaLane(d0, t3##i, v2, 2); \ | |||
d1 = GiSimdFmaLane(d1, t3##i, v2, 3); \ | |||
d2 = GiFmsqLaneQFloat32(d2, t3##i, v3, 0); \ | |||
d3 = GiSimdFmaLane(d3, t3##i, v2, 0); \ | |||
d4 = GiFmsqLaneQFloat32(d4, t3##i, v3, 1); \ | |||
d5 = GiSimdFmaLane(d5, t3##i, v3, 2); \ | |||
d6 = GiSimdFmaLane(d6, t3##i, v3, 3); \ | |||
d7 = GiSimdFmaLane(d7, t3##i, v2, 2); \ | |||
d8 = GiFmsqLaneQFloat32(d8, t3##i, v0, 3); \ | |||
d0 = GiSimdFmaLane(d0, t4##i, v0, 3); \ | |||
d1 = GiSimdFmaLane(d1, t4##i, v4, 0); \ | |||
d2 = GiSimdFmaLane(d2, t4##i, v4, 1); \ | |||
d3 = GiFmsqLaneQFloat32(d3, t4##i, v4, 2); \ | |||
d4 = GiSimdFmaLane(d4, t4##i, v4, 3); \ | |||
d5 = GiSimdFmaLane(d5, t4##i, v5, 0); \ | |||
d6 = GiSimdFmaLane(d6, t4##i, v5, 1); \ | |||
d8 = GiSimdFmaLane(d8, t4##i, v2, 2); \ | |||
d0 = GiFmsqLaneQFloat32(d0, t5##i, v2, 2); \ | |||
d1 = GiFmsqLaneQFloat32(d1, t5##i, v5, 2); \ | |||
d2 = GiFmsqLaneQFloat32(d2, t5##i, v5, 3); \ | |||
d3 = GiFmsqLaneQFloat32(d3, t5##i, v6, 0); \ | |||
d4 = GiSimdFmaLane(d4, t5##i, v6, 1); \ | |||
d5 = GiFmsqLaneQFloat32(d5, t5##i, v5, 2); \ | |||
d6 = GiFmsqLaneQFloat32(d6, t5##i, v6, 0); \ | |||
d7 = GiFmsqLaneQFloat32(d7, t5##i, v2, 2); \ | |||
d8 = GiSimdFmaLane(d8, t5##i, v0, 3); \ | |||
d0 = GiFmsqLaneQFloat32(d0, t6##i, v0, 0); \ | |||
d1 = GiFmsqLaneQFloat32(d1, t6##i, v1, 0); \ | |||
d2 = GiFmsqLaneQFloat32(d2, t6##i, v1, 1); \ | |||
d3 = GiSimdFmaLane(d3, t6##i, v1, 0); \ | |||
d4 = GiFmsqLaneQFloat32(d4, t6##i, v3, 1); \ | |||
d5 = d5 - t6##i; \ | |||
d6 = vfmsq_laneq_f32(d6, t6##i, v6, 2); \ | |||
d8 = vfmsq_laneq_f32(d8, t6##i, v2, 2); \ | |||
d0 = vfmaq_laneq_f32(d0, t0##i, v0, 0); \ | |||
vst1q_f32( \ | |||
d6 = GiFmsqLaneQFloat32(d6, t6##i, v6, 2); \ | |||
d8 = GiFmsqLaneQFloat32(d8, t6##i, v2, 2); \ | |||
d0 = GiSimdFmaLane(d0, t0##i, v0, 0); \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d0); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d1); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d2); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d3); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d4); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d5); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d6); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
d7); \ | |||
vst1q_f32( \ | |||
GiStoreFloat32( \ | |||
input_transform_buf + \ | |||
(8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ | |||
@@ -413,7 +413,7 @@ struct OutputTransformF73_NCHW44 { | |||
} // namespace | |||
namespace megdnn { | |||
namespace arm_common { | |||
namespace fallback { | |||
namespace winograd { | |||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F73_mk4_f_nchw44) | |||
@@ -554,14 +554,14 @@ void winograd_F73_mk4_f_nchw44::output( | |||
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, | |||
"NCHW44 Winograd filter transform requires OC is times of 4"); | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_arm_common_winograd_fp32_F73_mk4, cb, float, float, bmode, | |||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_fallback_winograd_fp32_F73_mk4, cb, float, float, bmode, | |||
nonline_mode); | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,413 @@ | |||
/** | |||
* \file dnn/src/fallback/conv_bias/gi/intrinsic_helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/common/unroll_macro.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
#include "src/fallback/general_intrinsic/gi_int.h" | |||
namespace megdnn { | |||
namespace { | |||
struct Vld1qF32S { | |||
static GI_FORCEINLINE GI_FLOAT32_t impl(const float32_t* ptr) { | |||
return GiLoadFloat32(ptr); | |||
} | |||
}; | |||
#pragma GCC diagnostic push | |||
#pragma GCC diagnostic ignored "-Wuninitialized" | |||
#ifdef __GNUC__ | |||
#ifndef __has_warning | |||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||
#else | |||
#if __has_warning("-Wmaybe-uninitialized") | |||
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" | |||
#endif | |||
#endif | |||
#endif | |||
template < | |||
int weight_number, int base_offset, int ptr_step, int oc_block, typename Func, | |||
typename T, typename T2, typename... XT> | |||
struct LoadHelper { | |||
static GI_FORCEINLINE void impl(T& weight, T2 ptr, int oc_offset, XT... args); | |||
}; | |||
#define WEIGHT_CB(step) \ | |||
src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); | |||
#define LOAD_HELPER(step) \ | |||
template < \ | |||
int base_offset, int ptr_step, typename Func, typename T, typename T2, \ | |||
typename... XT> \ | |||
struct LoadHelper<step, base_offset, ptr_step, 0, Func, T, T2, XT...> { \ | |||
static GI_FORCEINLINE void impl(T& src, T2 ptr, int, XT... args) { \ | |||
UNROLL_CALL_RAW(step, WEIGHT_CB); \ | |||
} \ | |||
} | |||
LOAD_HELPER(1); | |||
LOAD_HELPER(2); | |||
LOAD_HELPER(3); | |||
LOAD_HELPER(4); | |||
LOAD_HELPER(5); | |||
LOAD_HELPER(6); | |||
LOAD_HELPER(7); | |||
LOAD_HELPER(8); | |||
LOAD_HELPER(9); | |||
LOAD_HELPER(10); | |||
LOAD_HELPER(11); | |||
LOAD_HELPER(12); | |||
LOAD_HELPER(13); | |||
LOAD_HELPER(14); | |||
LOAD_HELPER(15); | |||
LOAD_HELPER(16); | |||
#undef LOAD_HELPER | |||
#undef WEIGHT_CB | |||
///////////////////////////c_dim = 1///////////////////////// | |||
#define WEIGHT_CB(step) src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); | |||
#define LOAD_HELPER(step) \ | |||
template <int base_offset, int ptr_step, typename Func, typename T, typename T2> \ | |||
struct LoadHelper<step, base_offset, ptr_step, 1, Func, T, T2> { \ | |||
static GI_FORCEINLINE void impl(T& src, T2 ptr, int) { \ | |||
UNROLL_CALL_RAW(step, WEIGHT_CB); \ | |||
} \ | |||
} | |||
LOAD_HELPER(1); | |||
LOAD_HELPER(2); | |||
LOAD_HELPER(3); | |||
LOAD_HELPER(4); | |||
LOAD_HELPER(5); | |||
LOAD_HELPER(6); | |||
LOAD_HELPER(7); | |||
LOAD_HELPER(8); | |||
LOAD_HELPER(9); | |||
#undef LOAD_HELPER | |||
#undef WEIGHT_CB | |||
/////////////////////////c_dim = 2/////////////////////////////// | |||
#define WEIGHT_CB(step) \ | |||
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ | |||
src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); | |||
#define LOAD_HELPER(step) \ | |||
template <int base_offset, int ptr_step, typename Func, typename T, typename T2> \ | |||
struct LoadHelper<step, base_offset, ptr_step, 2, Func, T, T2> { \ | |||
static GI_FORCEINLINE void impl(T& src, T2 ptr, int oc_offset) { \ | |||
UNROLL_CALL_RAW(step, WEIGHT_CB); \ | |||
} \ | |||
} | |||
LOAD_HELPER(1); | |||
LOAD_HELPER(2); | |||
LOAD_HELPER(3); | |||
LOAD_HELPER(4); | |||
LOAD_HELPER(5); | |||
LOAD_HELPER(6); | |||
LOAD_HELPER(7); | |||
LOAD_HELPER(8); | |||
#undef LOAD_HELPER | |||
#undef WEIGHT_CB | |||
template < | |||
int weight_number, int base_offset, int ptr_step, int c_dim, typename Func, | |||
typename T, typename T2> | |||
GI_FORCEINLINE void load_helper(T& weight, T2 ptr, int oc_offset) { | |||
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2>::impl( | |||
weight, ptr, oc_offset); | |||
} | |||
////////////////////Store_OCX_OW8_Remain///////////////////////// | |||
template <int c_dim, int ow_remain, typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc); | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24)); | |||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
op(c[1][6], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16)); | |||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
op(c[1][4], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
op(c[1][2], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
op(c[0][0], reinterpret_cast<T3>(dst_ptr)); | |||
op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
} | |||
}; | |||
template <typename Op, typename T, typename T2, typename T3> | |||
struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> { | |||
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
op(c[0][0], reinterpret_cast<T3>(dst_ptr)); | |||
} | |||
}; | |||
template <int c_dim, int ow_remain, typename Op, typename T, typename T2> | |||
GI_FORCEINLINE void store_ocx_ow8_remain_static( | |||
T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2>::impl(c, op, dst_ptr, ld_dst_oc); | |||
} | |||
#undef cb | |||
#undef cb2 | |||
#undef cb_case | |||
#undef cb_case2 | |||
#pragma GCC diagnostic pop | |||
/////////////////////////init_ocx_ow8//////////////////// | |||
template <typename T> | |||
struct GiLdqSimd; | |||
template <> | |||
struct GiLdqSimd<float> { | |||
static constexpr int simd_len = 4; | |||
}; | |||
template <int c_dim, BiasMode bias_mode, int ow_remain, typename T, typename T2> | |||
struct InitOcxOw8 { | |||
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step); | |||
}; | |||
template <int c_dim, BiasMode bias_mode, typename T, typename T2> | |||
struct InitOcxOw8<c_dim, bias_mode, 0, T, T2> { | |||
static GI_FORCEINLINE void impl(T&, const T2*, int) {} | |||
}; | |||
#define BAIS_INIT_NO_BIAS_C2(step) \ | |||
c[0][step] = GiBroadcastFloat32(static_cast<T2>(0)); \ | |||
c[1][step] = GiBroadcastFloat32(static_cast<T2>(0)); | |||
#define BAIS_INIT_NO_BIAS_C1(step) c[0][step] = GiBroadcastFloat32(static_cast<T2>(0)); | |||
#define BAIS_INIT_BROADCAST_C2(step) \ | |||
c[0][step] = GiLoadFloat32(bias_ptr); \ | |||
c[1][step] = GiLoadFloat32(bias_ptr + oc_step); | |||
#define BAIS_INIT_BROADCAST_C1(step) c[0][step] = GiLoadFloat32(bias_ptr); | |||
#define BAIS_INIT_BIAS_C2(step) \ | |||
c[0][step] = GiLoadFloat32(bias_ptr + step * simd_len); \ | |||
c[1][step] = GiLoadFloat32(bias_ptr + oc_step + step * simd_len); | |||
#define BAIS_INIT_BIAS_C1(step) c[0][step] = GiLoadFloat32(bias_ptr + step * simd_len); | |||
#define INSTANCE_InitOcxOw8(ow_remain, cdim) \ | |||
template <typename T, typename T2> \ | |||
struct InitOcxOw8<cdim, BiasMode::NO_BIAS, ow_remain, T, T2> { \ | |||
static GI_FORCEINLINE void impl(T& c, const T2*, int) { \ | |||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_NO_BIAS_C##cdim); \ | |||
} \ | |||
}; \ | |||
template <typename T, typename T2> \ | |||
struct InitOcxOw8<cdim, BiasMode::BROADCAST_CHANNEL_BIAS, ow_remain, T, T2> { \ | |||
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step) { \ | |||
(void)oc_step; \ | |||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BROADCAST_C##cdim); \ | |||
} \ | |||
}; \ | |||
template <typename T, typename T2> \ | |||
struct InitOcxOw8<cdim, BiasMode::BIAS, ow_remain, T, T2> { \ | |||
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step) { \ | |||
constexpr int simd_len = GiLdqSimd<T2>::simd_len; \ | |||
(void)oc_step; \ | |||
UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BIAS_C##cdim); \ | |||
} \ | |||
}; | |||
#define INSTANCE_InitOcxOw8_C(ow_remain) \ | |||
INSTANCE_InitOcxOw8(ow_remain, 2); \ | |||
INSTANCE_InitOcxOw8(ow_remain, 1); | |||
INSTANCE_InitOcxOw8_C(1); | |||
INSTANCE_InitOcxOw8_C(2); | |||
INSTANCE_InitOcxOw8_C(3); | |||
INSTANCE_InitOcxOw8_C(4); | |||
INSTANCE_InitOcxOw8_C(5); | |||
INSTANCE_InitOcxOw8_C(6); | |||
INSTANCE_InitOcxOw8_C(7); | |||
INSTANCE_InitOcxOw8_C(8); | |||
#undef INSTANCE_InitOcxOw8 | |||
#undef INSTANCE_InitOcxOw8_C | |||
#undef BAIS_INIT_BIAS_C1 | |||
#undef BAIS_INIT_BIAS_C2 | |||
#undef BAIS_INIT_BROADCAST_C1 | |||
#undef BAIS_INIT_BROADCAST_C2 | |||
#undef BAIS_INIT_NO_BIAS_C1 | |||
#undef BAIS_INIT_NO_BIAS_C2 | |||
template <int c_dim, BiasMode bias_mode, int ow_remain, typename T, typename T2> | |||
GI_FORCEINLINE void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) { | |||
InitOcxOw8<c_dim, bias_mode, ow_remain, T, T2>::impl(c, bias_ptr, oc_step); | |||
} | |||
} // namespace | |||
} // namespace megdnn | |||
#undef GI_FORCEINLINE | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,86 @@ | |||
/** | |||
* \file dnn/src/fallback/conv_bias/gi/postprocess_helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(fallback_gi_conv_bias_postprocess_helper) | |||
namespace { | |||
#define GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ | |||
_midout_tag, cb, _bias_id, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ | |||
switch (_nonline_mode) { \ | |||
case param::ConvBias::NonlineMode::IDENTITY: { \ | |||
MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ | |||
cb(_bmode, NoneOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
break; \ | |||
} \ | |||
case param::ConvBias::NonlineMode::RELU: { \ | |||
MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ | |||
cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
break; \ | |||
} \ | |||
case param::ConvBias::NonlineMode::SIGMOID: { \ | |||
MIDOUT_BEGIN(_midout_tag, _bias_id, 2) { \ | |||
cb(_bmode, SigmoidOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
break; \ | |||
} \ | |||
case param::ConvBias::NonlineMode::H_SWISH: { \ | |||
MIDOUT_BEGIN(_midout_tag, _bias_id, 3) { \ | |||
cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
break; \ | |||
} \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
#define GI_DISPATCH_CONV_WINOGRAD_BIAS( \ | |||
_midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ | |||
switch (_bmode) { \ | |||
case BiasMode::BIAS: { \ | |||
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ | |||
_midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \ | |||
_nonline_mode, __VA_ARGS__) \ | |||
break; \ | |||
} \ | |||
case BiasMode::NO_BIAS: { \ | |||
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ | |||
_midout_tag, cb, 1, _src_type, _dst_type, BiasMode::NO_BIAS, \ | |||
_nonline_mode, __VA_ARGS__) \ | |||
break; \ | |||
} \ | |||
case BiasMode::BROADCAST_CHANNEL_BIAS: { \ | |||
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \ | |||
_midout_tag, cb, 2, _src_type, _dst_type, \ | |||
BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, __VA_ARGS__) \ | |||
break; \ | |||
} \ | |||
default: \ | |||
megdnn_assert(0); \ | |||
break; \ | |||
} | |||
} // namespace |
@@ -0,0 +1,193 @@ | |||
/** | |||
* \file dnn/src/fallback/conv_bias/gi/utils.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstring> | |||
#include "src/common/utils.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
namespace megdnn { | |||
namespace fallback { | |||
template <typename ctype, size_t len> | |||
struct Vector; | |||
template <> | |||
struct Vector<float, 4> { | |||
GI_FLOAT32_t value; | |||
Vector() {} | |||
Vector(const float v) { value = GiBroadcastFloat32(v); } | |||
Vector(const Vector& lr) { value = lr.value; } | |||
Vector(const Vector&& lr) { value = std::move(lr.value); } | |||
Vector(const GI_FLOAT32_t& v) { value = v; } | |||
static Vector load(const float* addr) { | |||
Vector v; | |||
v.value = GiLoadFloat32(addr); | |||
return v; | |||
} | |||
static void save(float* addr, const Vector& v) { GiStoreFloat32(addr, v.value); } | |||
void save(float* addr) { save(addr, *this); } | |||
Vector operator+(const Vector& lr) { | |||
Vector dst; | |||
dst.value = GiAddFloat32(value, lr.value); | |||
return dst; | |||
} | |||
Vector& operator+=(const Vector& lr) { | |||
value = GiAddFloat32(value, lr.value); | |||
return *this; | |||
} | |||
Vector operator-(const Vector& lr) { | |||
Vector dst; | |||
dst.value = GiSubtractFloat32(value, lr.value); | |||
return dst; | |||
} | |||
Vector& operator-=(const Vector& lr) { | |||
value = GiSubtractFloat32(value, lr.value); | |||
return *this; | |||
} | |||
Vector operator*(float lr) { | |||
Vector dst; | |||
dst.value = GiMultiplyScalerFloat32(value, lr); | |||
return dst; | |||
} | |||
Vector operator*(const Vector& lr) { | |||
Vector dst; | |||
dst.value = GiMultiplyFloat32(value, lr.value); | |||
return dst; | |||
} | |||
Vector& operator*=(const Vector& lr) { | |||
value = GiMultiplyFloat32(value, lr.value); | |||
return *this; | |||
} | |||
Vector& operator=(const Vector& lr) { | |||
value = lr.value; | |||
return *this; | |||
} | |||
Vector& operator=(const Vector&& lr) { | |||
value = std::move(lr.value); | |||
return *this; | |||
} | |||
Vector operator-() { | |||
Vector dst; | |||
dst.value = -value; | |||
return dst; | |||
} | |||
}; | |||
template <> | |||
struct Vector<float, 8> { | |||
GI_FLOAT32_V2_t value; | |||
Vector() {} | |||
Vector(const float v) { | |||
value.val[0] = GiBroadcastFloat32(v); | |||
value.val[1] = GiBroadcastFloat32(v); | |||
} | |||
Vector(const Vector& lr) { value = lr.value; } | |||
Vector(const Vector&& lr) { value = std::move(lr.value); } | |||
Vector(const GI_FLOAT32_V2_t& v) { value = v; } | |||
static Vector load(const float* addr) { | |||
Vector v; | |||
#if defined(GI_TEST_NAIVE) | |||
v.value.val[0] = GiLoadFloat32(addr); | |||
v.value.val[1] = GiLoadFloat32(addr + 4); | |||
#elif defined(__arm__) || defined(__aarch64__) | |||
v.value = vld1q_f32_x2(addr); | |||
#else | |||
v.value.val[0] = GiLoadFloat32(addr); | |||
v.value.val[1] = GiLoadFloat32(addr + 4); | |||
#endif | |||
return v; | |||
} | |||
static void save(float* addr, const Vector& v) { | |||
#if defined(GI_TEST_NAIVE) | |||
GiStoreFloat32(addr, v.value.val[0]); | |||
GiStoreFloat32(addr + 4, v.value.val[1]); | |||
#elif defined(__arm__) || defined(__aarch64__) | |||
vst1q_f32_x2(addr, v.value); | |||
#else | |||
GiStoreFloat32(addr, v.value.val[0]); | |||
GiStoreFloat32(addr + 4, v.value.val[1]); | |||
#endif | |||
} | |||
void save(float* addr) { save(addr, *this); } | |||
Vector operator+(const Vector& lr) { | |||
Vector dst; | |||
dst.value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); | |||
dst.value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); | |||
return dst; | |||
} | |||
Vector& operator+=(const Vector& lr) { | |||
value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); | |||
value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); | |||
return *this; | |||
} | |||
Vector& add(const Vector& lr) { | |||
value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]); | |||
value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]); | |||
return *this; | |||
} | |||
Vector operator-(const Vector& lr) { | |||
Vector dst; | |||
dst.value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]); | |||
dst.value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]); | |||
return dst; | |||
} | |||
Vector& operator-=(const Vector& lr) { | |||
value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]); | |||
value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]); | |||
return *this; | |||
} | |||
Vector operator*(float lr) { | |||
Vector dst; | |||
dst.value.val[0] = GiMultiplyScalerFloat32(value.val[0], lr); | |||
dst.value.val[1] = GiMultiplyScalerFloat32(value.val[1], lr); | |||
return dst; | |||
} | |||
//! val + lr * n | |||
Vector& mla(const Vector& lr, float n) { | |||
value.val[0] = GiMultiplyAddScalarFloat32(value.val[0], lr.value.val[0], n); | |||
value.val[1] = GiMultiplyAddScalarFloat32(value.val[1], lr.value.val[1], n); | |||
return *this; | |||
} | |||
Vector operator*(const Vector& lr) { | |||
Vector dst; | |||
dst.value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]); | |||
dst.value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]); | |||
return dst; | |||
} | |||
Vector& operator*=(const Vector& lr) { | |||
value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]); | |||
value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]); | |||
return *this; | |||
} | |||
Vector& operator=(const Vector& lr) { | |||
value = lr.value; | |||
return *this; | |||
} | |||
Vector& operator=(const Vector&& lr) { | |||
value = std::move(lr.value); | |||
return *this; | |||
} | |||
Vector operator-() { | |||
Vector dst; | |||
dst.value.val[0] = -value.val[0]; | |||
dst.value.val[1] = -value.val[1]; | |||
return dst; | |||
} | |||
}; | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -16,6 +16,7 @@ | |||
#include "src/fallback/conv_bias/algos.h" | |||
#include "src/fallback/conv_bias/conv1x1/algos.h" | |||
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" | |||
#include "src/fallback/conv_bias/gi/fp32/algos.h" | |||
#include "src/fallback/conv_bias/im2col/algos.h" | |||
#include "src/fallback/convolution/opr_impl.h" | |||
#include "src/naive/convolution/algorithms.h" | |||
@@ -34,6 +35,14 @@ | |||
using namespace megdnn; | |||
using namespace fallback; | |||
namespace { | |||
//! TODO: imp is_fallback_exclude_gi_or_naive | |||
bool is_naive(const detail::Algorithm* algo) { | |||
return algo->handle_type() == Handle::HandleType::NAIVE; | |||
} | |||
} // anonymous namespace | |||
size_t megdnn::fallback::pack_size(param::ConvBias::Format format) { | |||
switch (format) { | |||
case param::ConvBias::Format::NCHW44: | |||
@@ -73,16 +82,95 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | |||
SmallVector<AlgoBase*> m_all_algos; | |||
AlgoBase::Mapper m_all_algos_map; | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_gi_winograd_algos; | |||
AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; | |||
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44; | |||
AlgoF32DirectNCHW44 f32_direct_nchw44; | |||
AlgoF32Direct f32_direct; | |||
AlgoF32DirectStride2 f32_direct_stride2; | |||
AlgoF32DirectStride1 f32_direct_stride1; | |||
public: | |||
AlgoPack() { | |||
// fallback gi fp32 algo | |||
m_all_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||
m_all_algos.emplace_back(&f32_chanel_wise_nchw44); | |||
m_all_algos.emplace_back(&f32_direct_nchw44); | |||
m_all_algos.emplace_back(&f32_direct_stride1); | |||
m_all_algos.emplace_back(&f32_direct_stride2); | |||
m_all_algos.emplace_back(&f32_direct); | |||
static CpuOprDelegationStorage<2> storage; | |||
auto matmul_opr = storage.get<MatrixMul, 0>(); | |||
using MatmulFormat = param::MatrixMul::Format; | |||
auto&& matmul_algos = | |||
static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||
->select_algo_type({AlgoDataType::FLOAT32, MatmulFormat::MK4}); | |||
for (auto&& algo : matmul_algos) { | |||
if (is_naive(algo)) | |||
continue; | |||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||
//! uncomment this when low precision mode is done | |||
#if 0 | |||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||
#endif | |||
} | |||
} | |||
//! TODO: move arm_v7 MatrixMulImpl::AlgoF32 matmul to gi fallback, for nchw | |||
//! prefetch algo, also need update dnn/test/common/conv_bias.cpp:check_winograd | |||
matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||
->select_algo_type( | |||
{AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); | |||
for (auto&& algo : matmul_algos) { | |||
if (is_naive(algo)) | |||
continue; | |||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
for (auto&& algo : m_gi_winograd_algos) { | |||
m_all_algos.emplace_back(algo); | |||
} | |||
// end fallback gi fp32 algo | |||
refhold.emplace_back(new AlgoConv1x1Gemv()); | |||
m_all_algos.emplace_back(refhold.back().get()); | |||
static CpuOprDelegationStorage<> storage; | |||
auto matmul_opr = storage.get<MatrixMul>(); | |||
auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||
->get_all_packed_algo(); | |||
matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr) | |||
->get_all_packed_algo(); | |||
for (auto&& algo : matmul_algos) { | |||
#if MEGDNN_X86 | |||
//! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may | |||
@@ -226,6 +226,20 @@ public: | |||
FB_CONV1x1, | |||
FB_CONV1x1_GEMV, | |||
FB_IM2COL, | |||
GI_COMMON_WINOGRAD_F23_4X4_FP32, | |||
GI_COMMON_WINOGRAD_F63_FP32, | |||
GI_COMMON_WINOGRAD_F63_4X4_FP32, | |||
GI_COMMON_WINOGRAD_F54_FP32, | |||
GI_COMMON_WINOGRAD_F45_FP32, | |||
GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, | |||
GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, | |||
GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, | |||
GI_COMMON_DIRECT_FP32, | |||
GI_COMMON_DIRECT_STRD1_FP32, | |||
GI_COMMON_DIRECT_STRD2_FP32, | |||
GI_COMMON_DIRECT_NCHW44_FP32, | |||
GI_COMMON_DIRECT_NCHW_NCHW44_FP32, | |||
GI_COMMON_CHWNWISE_NCHW44_F32, | |||
#if MEGDNN_X86 | |||
X86_DIRECT = 1 << 8, | |||
@@ -248,20 +262,6 @@ public: | |||
ARM_COMMON_DIRECT_STRD1_FP16, | |||
ARM_COMMON_CHWNWISE_NCHW88_F16, | |||
ARM_COMMON_DIRECT_NCHW88_FP16, | |||
ARM_COMMON_WINOGRAD_F23_4X4_FP32, | |||
ARM_COMMON_WINOGRAD_F63_FP32, | |||
ARM_COMMON_WINOGRAD_F63_4X4_FP32, | |||
ARM_COMMON_WINOGRAD_F54_FP32, | |||
ARM_COMMON_WINOGRAD_F45_FP32, | |||
ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32, | |||
ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32, | |||
ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32, | |||
ARM_COMMON_DIRECT_FP32, | |||
ARM_COMMON_DIRECT_STRD1_FP32, | |||
ARM_COMMON_DIRECT_STRD2_FP32, | |||
ARM_COMMON_DIRECT_NCHW44_FP32, | |||
ARM_COMMON_DIRECT_NCHW_NCHW44_FP32, | |||
ARM_COMMON_CHWNWISE_NCHW44_F32, | |||
ARM_COMMON_DIRECT_STRD1_S8, | |||
ARM_COMMON_DIRECT_STRD2_S8, | |||
ARM_COMMON_DIRECT_NCHW44, | |||
@@ -383,6 +383,23 @@ private: | |||
class AlgoWinogradF32_4x4; | |||
class AlgoWinogradQS8; | |||
class AlgoWinogradQS8_8x8; | |||
class AlgoFP32WinogradF23_4x4; | |||
class AlgoFP32WinogradF63; | |||
class AlgoFP32WinogradF63_4x4; | |||
class AlgoFP32WinogradF54; | |||
class AlgoFP32WinogradF45; | |||
class AlgoFP32WinogradF23_4x4_NCHW44; | |||
class AlgoFP32WinogradF63_4x4_NCHW44; | |||
class AlgoFP32WinogradF73_4x4_NCHW44; | |||
class AlgoF32Direct; | |||
class AlgoF32DirectStride1; | |||
class AlgoF32DirectStride2; | |||
class AlgoF32DirectNCHWNCHW44; | |||
class AlgoF32ChannelWiseNCHW44; | |||
class AlgoF32DirectNCHW44; | |||
class AlgoPack; | |||
NCBKernSizeParam m_prev_selected_algo_sizep; | |||
@@ -81,23 +81,6 @@ TEST_F(ARM_COMMON, CONV_BIAS_RECORD) { | |||
} | |||
} | |||
TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD_F63_4) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||
Checker<ConvBiasForward> checker(handle()); | |||
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | |||
} | |||
TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD_F63_4_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
handle()); | |||
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | |||
} | |||
#define CONV_BIAS_MATMUL_QU8_MODE(MODE) \ | |||
using namespace conv_bias; \ | |||
std::vector<TestArg> args = get_quantized_args_with_nlmode(MODE); \ | |||
@@ -1015,14 +998,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_4x4) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4); | |||
#else | |||
benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4); | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); | |||
@@ -1031,14 +1006,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_4x4) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4); | |||
#else | |||
benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4); | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:5", handle(), 4); | |||
@@ -1212,30 +1179,10 @@ void benchmark_winograd_nchw_vs_nchw44( | |||
} | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"AARCH64_F32_MK4_4x16:4:2", "AARCH64_F32_MK4_4x16:4:2", handle()); | |||
#else | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"ARMV7_F32_MK4_4x8:4:2", "ARMV7_F32_MK4_4x8:4:2", handle()); | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"AARCH64_F32_MK4_4x16:4:6", "AARCH64_F32_MK4_4x16:4:6", handle()); | |||
#else | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:6", handle()); | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F73_MK4_NCHW_VS_NCHW44) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"AARCH64_F32_MK4_4x16:4:6", "ARM_COMMON_F32_GEMV_MK4:4:7", handle()); | |||
"AARCH64_F32_MK4_4x16:4:6", "FB_GI_F32_GEMV_MK4:4:7", handle()); | |||
#else | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:7", handle()); | |||
@@ -1609,156 +1556,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2) { | |||
computations / used0, used1, computations / used1, used1 / used0); | |||
} | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) { | |||
// have to remove preferred restrict in usable func before run the benchmark | |||
using namespace conv_bias; | |||
param::ConvBias param; | |||
param.stride_h = 1; | |||
param.stride_w = 1; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.nonlineMode = NonlineMode::RELU; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
constexpr size_t RUN = 50; | |||
Benchmarker<ConvBias> benchmark0(handle()); | |||
benchmark0.set_display(false); | |||
benchmark0.set_param(param); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1")); | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
opr->param() = param; | |||
param.format = param::ConvBias::Format::NCHW44; | |||
Benchmarker<ConvBias> benchmark1(handle()); | |||
benchmark1.set_display(false); | |||
benchmark1.set_param(param); | |||
benchmark1.set_times(RUN); | |||
benchmark1.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44")); | |||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { | |||
TensorLayout dst_layout; | |||
opr->deduce_layout( | |||
{{1, group * 4, h, w}, dtype::Int8()}, | |||
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, | |||
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); | |||
//! dst.nr_elems * IC * FH * FW * 2 | |||
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / | |||
(1024 * 1024 * 1024) * 1e3; | |||
auto used0 = benchmark0.exec( | |||
{{1, group * 4, h, w}, | |||
{group * 4, 1, 1, kernel, kernel}, | |||
{1, group * 4, 1, 1}, | |||
{}, | |||
{}}) / | |||
RUN; | |||
auto used1 = benchmark1.exec( | |||
{{1, group, h, w, 4}, | |||
{group, 1, 1, kernel, kernel, 4}, | |||
{1, group, 1, 1, 4}, | |||
{}, | |||
{}}) / | |||
RUN; | |||
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " | |||
"nchw44: " | |||
"%f ms %f GFlops " | |||
"speedup: %f\n", | |||
group, h, w, kernel, used0, computations / used0, used1, | |||
computations / used1, used0 / used1); | |||
}; | |||
for (size_t group : {8, 16, 32, 64}) { | |||
for (size_t kerenl : {2, 3, 5}) { | |||
run(group, 112, 112, kerenl); | |||
run(group, 56, 56, kerenl); | |||
run(group, 48, 48, kerenl); | |||
run(group, 28, 28, kerenl); | |||
run(group, 14, 14, kerenl); | |||
} | |||
} | |||
run(8, 112, 112, 3); | |||
run(32, 56, 56, 3); | |||
run(64, 28, 28, 3); | |||
run(128, 14, 14, 3); | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) { | |||
// have to remove preferred restrict in usable func before run the benchmark | |||
using namespace conv_bias; | |||
param::ConvBias param; | |||
param.stride_h = 2; | |||
param.stride_w = 2; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.nonlineMode = NonlineMode::RELU; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
constexpr size_t RUN = 50; | |||
Benchmarker<ConvBias> benchmark0(handle()); | |||
benchmark0.set_display(false); | |||
benchmark0.set_param(param); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2")); | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
opr->param() = param; | |||
param.format = param::ConvBias::Format::NCHW44; | |||
Benchmarker<ConvBias> benchmark1(handle()); | |||
benchmark1.set_display(false); | |||
benchmark1.set_param(param); | |||
benchmark1.set_times(RUN); | |||
benchmark1.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44")); | |||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { | |||
TensorLayout dst_layout; | |||
opr->deduce_layout( | |||
{{1, group * 4, h, w}, dtype::Int8()}, | |||
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, | |||
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); | |||
//! dst.nr_elems * IC * FH * FW * 2 | |||
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / | |||
(1024 * 1024 * 1024) * 1e3; | |||
auto used0 = benchmark0.exec( | |||
{{1, group * 4, h, w}, | |||
{group * 4, 1, 1, kernel, kernel}, | |||
{1, group * 4, 1, 1}, | |||
{}, | |||
{}}) / | |||
RUN; | |||
auto used1 = benchmark1.exec( | |||
{{1, group, h, w, 4}, | |||
{group, 1, 1, kernel, kernel, 4}, | |||
{1, group, 1, 1, 4}, | |||
{}, | |||
{}}) / | |||
RUN; | |||
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " | |||
"nchw44: " | |||
"%f ms %f GFlops " | |||
"speedup: %f\n", | |||
group, h, w, kernel, used0, computations / used0, used1, | |||
computations / used1, used0 / used1); | |||
}; | |||
for (size_t group : {8, 16, 32, 64}) { | |||
for (size_t kerenl : {2, 3, 5}) { | |||
run(group, 112, 112, kerenl); | |||
run(group, 56, 56, kerenl); | |||
run(group, 48, 48, kerenl); | |||
run(group, 28, 28, kerenl); | |||
run(group, 14, 14, kerenl); | |||
} | |||
} | |||
run(8, 112, 112, 3); | |||
run(32, 56, 56, 3); | |||
run(64, 28, 28, 3); | |||
run(128, 14, 14, 3); | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { | |||
// have to remove preferred restrict in usable func before run the benchmark | |||
using namespace conv_bias; | |||
@@ -303,84 +303,6 @@ void checker_conv_bias_int8x8x32_multi( | |||
} | |||
} | |||
/**********************************F32 direct************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) { | |||
check_conv_bias( | |||
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(), | |||
"F32DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { | |||
//! k=7 s=1 | |||
check_conv_bias( | |||
get_nchw44_conv_bias_args({7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { | |||
check_conv_bias( | |||
get_nchw44_conv_bias_args({2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { | |||
check_conv_bias( | |||
get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), handle(), | |||
"F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | |||
check_conv_bias( | |||
get_nchw44_conv_bias_args({2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) { | |||
check_conv_bias( | |||
get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), handle(), | |||
"F32STRD1"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) { | |||
check_conv_bias( | |||
get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), handle(), | |||
"F32STRD2"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) { | |||
check_conv_bias( | |||
get_nchw44_conv_bias_args( | |||
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false, | |||
true), | |||
handle(), "F32_CONV_NCHW_NCHW44"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) { | |||
check_conv_bias( | |||
get_nchw44_conv_bias_args( | |||
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false, | |||
true), | |||
handle(), "F32_CONV_NCHW_NCHW44"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) { | |||
check_conv_bias( | |||
get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(), | |||
"F32_CHANNEL_WISE_NCHW44"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) { | |||
check_conv_bias( | |||
get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(), | |||
"F32_CHANNEL_WISE_NCHW44"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { | |||
check_conv_bias( | |||
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(), | |||
"F32_CHANNEL_WISE_NCHW44"); | |||
} | |||
/**********************************F16 direct************************/ | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) { | |||
@@ -787,50 +709,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) { | |||
#endif | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||
Checker<ConvBiasForward> checker(handle()); | |||
check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = | |||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||
Checker<ConvBiasForward> checker(handle()); | |||
check_winograd( | |||
"4:2:32", checker, args, param::MatrixMul::Format::MK4, | |||
param::ConvBias::Format::NCHW44); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_args(3); | |||
Checker<ConvBiasForward> checker(handle()); | |||
check_winograd("1:6:32", checker, args); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||
Checker<ConvBiasForward> checker(handle()); | |||
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = | |||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||
Checker<ConvBiasForward> checker(handle()); | |||
check_winograd( | |||
"4:6:16", checker, args, param::MatrixMul::Format::MK4, | |||
param::ConvBias::Format::NCHW44); | |||
} | |||
//! uncomment it when low precision mode is ok | |||
#if 0 | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { | |||
@@ -853,22 +731,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCE | |||
} | |||
#endif | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_args(4); | |||
Checker<ConvBiasForward> checker(handle()); | |||
check_winograd("1:5:32", checker, args); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_args(5); | |||
Checker<ConvBiasForward> checker(handle()); | |||
check_winograd("1:4:32", checker, args); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) { | |||
using namespace conv_bias; | |||
@@ -81,207 +81,6 @@ void benchmark_impl( | |||
} | |||
} // namespace | |||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) { | |||
constexpr size_t RUNS = 50; | |||
param::ConvBias param; | |||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.stride_h = 1; | |||
param.stride_w = 1; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||
size_t group) { | |||
SmallVector<TensorShape> shapes{ | |||
{N, IC, H, W}, | |||
{group, OC / group, IC / group, FS, FS}, | |||
{1, OC, 1, 1}, | |||
{}, | |||
{N, OC, H, W}}; | |||
TensorShape dst{N, OC, H, W}; | |||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||
dst.total_nr_elems()) * | |||
1e-6; | |||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||
}; | |||
bench_case(1, 32, 32, 200, 200, 3, 4); | |||
bench_case(1, 32, 32, 200, 200, 3, 32); | |||
bench_case(1, 32, 32, 128, 128, 3, 4); | |||
bench_case(1, 32, 32, 128, 128, 3, 32); | |||
bench_case(1, 32, 32, 100, 100, 3, 4); | |||
bench_case(1, 32, 32, 100, 100, 3, 32); | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "F32DIRECT"; | |||
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32DIRECT"; | |||
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
} | |||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) { | |||
constexpr size_t RUNS = 50; | |||
param::ConvBias param; | |||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.stride_h = 1; | |||
param.stride_w = 1; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||
size_t group) { | |||
SmallVector<TensorShape> shapes{ | |||
{N, IC, H, W}, | |||
{group, OC / group, IC / group, FS, FS}, | |||
{1, OC, 1, 1}, | |||
{}, | |||
{N, OC, H, W}}; | |||
TensorShape dst{N, OC, H, W}; | |||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||
dst.total_nr_elems()) * | |||
1e-6; | |||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||
}; | |||
bench_case(1, 32, 32, 200, 200, 3, 4); | |||
bench_case(1, 32, 32, 200, 200, 3, 32); | |||
bench_case(1, 32, 32, 128, 128, 3, 4); | |||
bench_case(1, 32, 32, 128, 128, 3, 32); | |||
bench_case(1, 32, 32, 100, 100, 3, 4); | |||
bench_case(1, 32, 32, 100, 100, 3, 32); | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "F32STRD1"; | |||
printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32STRD1"; | |||
printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
} | |||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) { | |||
constexpr size_t RUNS = 50; | |||
param::ConvBias param; | |||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.stride_h = 2; | |||
param.stride_w = 2; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||
size_t group, size_t P, size_t S) { | |||
SmallVector<TensorShape> shapes{ | |||
{N, IC, H, W}, | |||
{group, OC / group, IC / group, FS, FS}, | |||
{1, OC, 1, 1}, | |||
{}, | |||
{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; | |||
TensorShape dst{N, OC, H, W}; | |||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||
dst.total_nr_elems()) * | |||
1e-6; | |||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||
}; | |||
bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); | |||
bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||
std::string algo_name = "F32STRD2"; | |||
printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32STRD2"; | |||
printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) { | |||
constexpr size_t RUNS = 50; | |||
@@ -20,91 +20,7 @@ | |||
using namespace megdnn; | |||
using namespace test; | |||
using namespace conv_bias; | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
handle()); | |||
check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = | |||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
handle()); | |||
check_winograd( | |||
"4:2:32", checker, args, param::MatrixMul::Format::MK4, | |||
param::ConvBias::Format::NCHW44); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_args(3); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
handle()); | |||
check_winograd("1:6:32", checker, args); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_mk_packed_args(); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
handle()); | |||
check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = | |||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
handle()); | |||
check_winograd( | |||
"4:6:16", checker, args, param::MatrixMul::Format::MK4, | |||
param::ConvBias::Format::NCHW44); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_args(4); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
handle()); | |||
check_winograd("1:5:32", checker, args); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_winograd_args(5); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
handle()); | |||
check_winograd("1:4:32", checker, args); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> nchw44_args = | |||
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||
Checker<ConvBiasForward> checker(handle()); | |||
auto run = [&checker]( | |||
const std::vector<TestArg>& args, DType A_dtype, DType B_dtype, | |||
DType C_dtype, DType D_dtype, const float eps) { | |||
for (auto&& arg : args) { | |||
checker.set_dtype(0, A_dtype) | |||
.set_dtype(1, B_dtype) | |||
.set_dtype(2, C_dtype) | |||
.set_dtype(4, D_dtype) | |||
.set_epsilon(eps) | |||
.set_param(arg.param) | |||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
}; | |||
//! uncomment this when low precision mode is ok | |||
// run(handle(), nchw44_args, {2, 6, 7}, dtype::Float32(), dtype::Float32(), | |||
// dtype::Float32(), dtype::Float32(), 1e-2f); | |||
//! remove this when low precision mode is ok | |||
run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(), | |||
dtype::Float32(), 1e-3f); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
@@ -286,30 +286,6 @@ TEST_F(ARM_COMMON, FP32_GEVM) { | |||
run(M, K, N); | |||
} | |||
TEST_F(ARM_COMMON, FP32_GEMV_MK4) { | |||
Checker<MatrixMul> checker(handle()); | |||
using Param = MatrixMul::Param; | |||
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV_MK4")); | |||
checker.set_epsilon(1e-2); | |||
auto run = [&](size_t M, size_t K) { | |||
Param param; | |||
param.format = param::MatrixMul::Format::MK4; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
TensorShape A, B; | |||
A = TensorShape{M / 4, K / 4, 4, 4}; | |||
B = TensorShape{K / 4, 1, 4}; | |||
checker.set_param(param).execs({A, B, {}}); | |||
}; | |||
// N = 1 | |||
for (size_t M : {4, 16, 128, 1024}) | |||
for (size_t K : {4, 8, 12, 128, 256, 4096}) | |||
run(M, K); | |||
} | |||
TEST_F(ARM_COMMON, MATRIX_MUL_RECORD) { | |||
TaskRecordChecker<MatrixMul> checker(0); | |||
checker.set_epsilon(1e-2); | |||
@@ -117,6 +117,30 @@ TEST_F(FALLBACK, CONV_BIAS_FORWARD_RECORD) { | |||
} | |||
} | |||
TEST_F(FALLBACK, FP32_GEMV_MK4_GI) { | |||
Checker<MatrixMul> checker(handle()); | |||
using Param = MatrixMul::Param; | |||
checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_GI_F32_GEMV_MK4")); | |||
checker.set_epsilon(1e-2); | |||
auto run = [&](size_t M, size_t K) { | |||
Param param; | |||
param.format = param::MatrixMul::Format::MK4; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
TensorShape A, B; | |||
A = TensorShape{M / 4, K / 4, 4, 4}; | |||
B = TensorShape{K / 4, 1, 4}; | |||
checker.set_param(param).execs({A, B, {}}); | |||
}; | |||
// N = 1 | |||
for (size_t M : {4, 16, 128, 1024}) | |||
for (size_t K : {4, 8, 12, 128, 256, 4096}) | |||
run(M, K); | |||
} | |||
std::vector<conv_bias::TestArg> get_conv_bias_args( | |||
std::vector<size_t> kernel, std::vector<size_t> padv, | |||
std::vector<param::ConvBias::NonlineMode> nlmodev, std::vector<size_t> stridev, | |||
@@ -257,6 +281,189 @@ TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { | |||
dtype::Float32{}, dtype::Float32{}, "FALLBACK_NAIVE"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_NCHW_NCHW44_F32_S2) { | |||
check_conv_bias( | |||
conv_bias::get_nchw44_conv_bias_args( | |||
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false, | |||
true), | |||
handle(), "F32_CONV_NCHW_NCHW44"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_NCHW_NCHW44_F32_S1) { | |||
check_conv_bias( | |||
conv_bias::get_nchw44_conv_bias_args( | |||
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false, | |||
true), | |||
handle(), "F32_CONV_NCHW_NCHW44"); | |||
} | |||
std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args( | |||
std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode, | |||
bool no_full_bias) { | |||
using namespace conv_bias; | |||
using Param = param::ConvBias; | |||
using NLMode = param::ConvBias::NonlineMode; | |||
std::vector<TestArg> args; | |||
auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel, | |||
size_t stride, NLMode nlmode, bool pad) { | |||
Param param; | |||
param.stride_h = stride; | |||
param.stride_w = stride; | |||
if (pad) { | |||
param.pad_h = kernel / 2; | |||
param.pad_w = kernel / 2; | |||
} else { | |||
param.pad_h = 0; | |||
param.pad_w = 0; | |||
} | |||
param.nonlineMode = nlmode; | |||
param.format = param::ConvBias::Format::NCHW44; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
args.emplace_back( | |||
param, TensorShape{n, group, h, w, 4}, | |||
TensorShape{group, 1, 1, kernel, kernel, 4}, TensorShape{}); | |||
if (!no_bias) { | |||
args.emplace_back( | |||
param, TensorShape{n, group, h, w, 4}, | |||
TensorShape{group, 1, 1, kernel, kernel, 4}, | |||
TensorShape{1, group, 1, 1, 4}); | |||
} | |||
if (!no_full_bias) { | |||
args.emplace_back( | |||
param, TensorShape{n, group, h, w, 4}, | |||
TensorShape{group, 1, 1, kernel, kernel, 4}, | |||
TensorShape{ | |||
n, group, (h + 2 * param.pad_w - kernel) / stride + 1, | |||
(w + 2 * param.pad_w - kernel) / stride + 1, 4}); | |||
} | |||
}; | |||
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY}; | |||
if (!no_nonlinemode) { | |||
nonlinemode.emplace_back(NLMode::RELU); | |||
nonlinemode.emplace_back(NLMode::H_SWISH); | |||
} | |||
for (size_t n : {1, 2}) { | |||
for (auto nlmode : nonlinemode) { | |||
for (bool pad : {true}) { | |||
for (size_t group : {1, 2, 4, 7, 16}) { | |||
for (size_t size : {4, 6, 7, 9, 20}) { | |||
for (size_t kern : kernel) { | |||
pack(n, group, size, size, kern, stride, nlmode, pad); | |||
} | |||
} | |||
} | |||
} | |||
for (bool pad : {false}) { | |||
for (size_t group : {1, 2, 7, 16}) { | |||
for (size_t size : {7, 9, 20}) { | |||
for (size_t kern : kernel) { | |||
pack(n, group, size, size, kern, stride, nlmode, pad); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
return args; | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) { | |||
check_conv_bias( | |||
get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(), | |||
"F32_CHANNEL_WISE_NCHW44"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) { | |||
check_conv_bias( | |||
get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(), | |||
"F32_CHANNEL_WISE_NCHW44"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE2_FP32_NCHW44) { | |||
check_conv_bias( | |||
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(), | |||
"F32_CHANNEL_WISE_NCHW44"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K7) { | |||
//! k=7 s=1 | |||
check_conv_bias( | |||
conv_bias::get_nchw44_conv_bias_args( | |||
{7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K2K3) { | |||
check_conv_bias( | |||
conv_bias::get_nchw44_conv_bias_args( | |||
{2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K5) { | |||
check_conv_bias( | |||
conv_bias::get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S2) { | |||
check_conv_bias( | |||
conv_bias::get_nchw44_conv_bias_args( | |||
{2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2), | |||
handle(), "F32_CONV_NCHW44_DIRECT"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32) { | |||
check_conv_bias( | |||
conv_bias::get_conv_bias_args( | |||
{1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), | |||
handle(), "F32DIRECT"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_STR2) { | |||
check_conv_bias( | |||
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | |||
handle(), "F32STRD2"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_STR1) { | |||
check_conv_bias( | |||
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | |||
handle(), "F32STRD1"); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_PREPROCESS_NCHW44) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> nchw44_args = conv_bias::get_nchw44_conv_bias_args( | |||
{3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||
Checker<ConvBiasForward> checker(handle()); | |||
auto run = [&checker]( | |||
const std::vector<TestArg>& args, DType A_dtype, DType B_dtype, | |||
DType C_dtype, DType D_dtype, const float eps) { | |||
for (auto&& arg : args) { | |||
checker.set_dtype(0, A_dtype) | |||
.set_dtype(1, B_dtype) | |||
.set_dtype(2, C_dtype) | |||
.set_dtype(4, D_dtype) | |||
.set_epsilon(eps) | |||
.set_param(arg.param) | |||
.execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
}; | |||
//! uncomment this when low precision mode is ok | |||
// run(handle(), nchw44_args, {2, 6, 7}, dtype::Float32(), dtype::Float32(), | |||
// dtype::Float32(), dtype::Float32(), 1e-2f); | |||
//! remove this when low precision mode is ok | |||
run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(), | |||
dtype::Float32(), 1e-3f); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) { | |||
using namespace conv_bias; | |||
param::ConvBias cur_param; | |||
@@ -273,6 +480,422 @@ TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) { | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
namespace { | |||
void benchmark_impl( | |||
const param::ConvBias param, | |||
std::vector<std::pair<SmallVector<TensorShape>, float>>& shapes_and_computation, | |||
const std::string algo_name, size_t RUNS, | |||
TaskExecutorConfig&& multi_thread_config, | |||
TaskExecutorConfig&& single_thread_config, std::vector<DType>& data_type) { | |||
std::vector<float> multi_thread_times, single_thread_times; | |||
{ | |||
auto multi_thread_hanle = create_cpu_handle(0, true, &multi_thread_config); | |||
auto benchmarker = Benchmarker<ConvBias>(multi_thread_hanle.get()); | |||
benchmarker.set_times(RUNS) | |||
.set_display(false) | |||
.set_param(param) | |||
.set_dtype(0, data_type[0]) | |||
.set_dtype(1, data_type[1]) | |||
.set_dtype(2, data_type[2]) | |||
.set_dtype(4, data_type[3]) | |||
.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str())); | |||
for (auto shape : shapes_and_computation) { | |||
multi_thread_times.push_back(benchmarker.exec(shape.first) / RUNS); | |||
} | |||
} | |||
{ | |||
auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config); | |||
auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get()); | |||
benchmarker.set_times(RUNS) | |||
.set_display(false) | |||
.set_param(param) | |||
.set_dtype(0, data_type[0]) | |||
.set_dtype(1, data_type[1]) | |||
.set_dtype(2, data_type[2]) | |||
.set_dtype(4, data_type[3]) | |||
.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str())); | |||
for (auto shape : shapes_and_computation) { | |||
single_thread_times.push_back(benchmarker.exec(shape.first) / RUNS); | |||
} | |||
} | |||
printf("Benchmark : Multi threads %zu, ", multi_thread_config.nr_thread); | |||
printf("core_ids:"); | |||
for (size_t i = 0; i < multi_thread_config.affinity_core_set.size(); i++) { | |||
printf("%zu ", multi_thread_config.affinity_core_set[i]); | |||
} | |||
printf(", Single thread core_id %zu\n", single_thread_config.affinity_core_set[0]); | |||
for (size_t i = 0; i < shapes_and_computation.size(); i++) { | |||
auto shapes = shapes_and_computation[i]; | |||
printf("Bench case: "); | |||
for (auto&& shape : shapes.first) { | |||
printf("%s ", shape.to_string().c_str()); | |||
} | |||
float computations = shapes.second; | |||
printf("%zu threads gflops: %f,\n single thread gflops: " | |||
"%f. spead up = %f, speedup/cores=%f\n", | |||
multi_thread_config.nr_thread, computations / multi_thread_times[i], | |||
computations / single_thread_times[i], | |||
single_thread_times[i] / multi_thread_times[i], | |||
single_thread_times[i] / multi_thread_times[i] / | |||
multi_thread_config.nr_thread); | |||
} | |||
} | |||
} // namespace | |||
TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32) { | |||
constexpr size_t RUNS = 50; | |||
param::ConvBias param; | |||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.stride_h = 1; | |||
param.stride_w = 1; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||
size_t group) { | |||
SmallVector<TensorShape> shapes{ | |||
{N, IC, H, W}, | |||
{group, OC / group, IC / group, FS, FS}, | |||
{1, OC, 1, 1}, | |||
{}, | |||
{N, OC, H, W}}; | |||
TensorShape dst{N, OC, H, W}; | |||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||
dst.total_nr_elems()) * | |||
1e-6; | |||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||
}; | |||
bench_case(1, 32, 32, 200, 200, 3, 4); | |||
bench_case(1, 32, 32, 200, 200, 3, 32); | |||
bench_case(1, 32, 32, 128, 128, 3, 4); | |||
bench_case(1, 32, 32, 128, 128, 3, 32); | |||
bench_case(1, 32, 32, 100, 100, 3, 4); | |||
bench_case(1, 32, 32, 100, 100, 3, 32); | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "F32DIRECT"; | |||
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32DIRECT"; | |||
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32_STR1) { | |||
constexpr size_t RUNS = 50; | |||
param::ConvBias param; | |||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.stride_h = 1; | |||
param.stride_w = 1; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||
size_t group) { | |||
SmallVector<TensorShape> shapes{ | |||
{N, IC, H, W}, | |||
{group, OC / group, IC / group, FS, FS}, | |||
{1, OC, 1, 1}, | |||
{}, | |||
{N, OC, H, W}}; | |||
TensorShape dst{N, OC, H, W}; | |||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||
dst.total_nr_elems()) * | |||
1e-6; | |||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||
}; | |||
bench_case(1, 32, 32, 200, 200, 3, 4); | |||
bench_case(1, 32, 32, 200, 200, 3, 32); | |||
bench_case(1, 32, 32, 128, 128, 3, 4); | |||
bench_case(1, 32, 32, 128, 128, 3, 32); | |||
bench_case(1, 32, 32, 100, 100, 3, 4); | |||
bench_case(1, 32, 32, 100, 100, 3, 32); | |||
bench_case(1, 32, 32, 80, 80, 3, 4); | |||
bench_case(1, 32, 32, 80, 80, 3, 32); | |||
std::string algo_name = "F32STRD1"; | |||
printf("Benchmark F32STRD1_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32STRD1"; | |||
printf("Benchmark F32STRD1_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1); | |||
bench_case(1, 32, 32, 128, 128, 3, 1); | |||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32_STR2) { | |||
constexpr size_t RUNS = 50; | |||
param::ConvBias param; | |||
param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.stride_h = 2; | |||
param.stride_w = 2; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation; | |||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||
size_t group, size_t P, size_t S) { | |||
SmallVector<TensorShape> shapes{ | |||
{N, IC, H, W}, | |||
{group, OC / group, IC / group, FS, FS}, | |||
{1, OC, 1, 1}, | |||
{}, | |||
{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}}; | |||
TensorShape dst{N, OC, H, W}; | |||
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + | |||
dst.total_nr_elems()) * | |||
1e-6; | |||
shapes_and_computation.push_back(std::make_pair(shapes, computations)); | |||
}; | |||
bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2); | |||
bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2); | |||
std::string algo_name = "F32STRD2"; | |||
printf("Benchmark F32STRD2_LARGE_GROUP algo\n"); | |||
std::vector<DType> data_type = { | |||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
shapes_and_computation.clear(); | |||
algo_name = "F32STRD2"; | |||
printf("Benchmark F32STRD2_SMALL_GROUP algo\n"); | |||
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2); | |||
bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}}, | |||
data_type); | |||
benchmark_impl( | |||
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, | |||
data_type); | |||
} | |||
TEST_F(FALLBACK, BENCHMARK_GI_CHANNEL_WISE_F32_STRIDE1_NCHW44) { | |||
// have to remove preferred restrict in usable func before run the benchmark | |||
using namespace conv_bias; | |||
param::ConvBias param; | |||
param.stride_h = 1; | |||
param.stride_w = 1; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.nonlineMode = NonlineMode::RELU; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
constexpr size_t RUN = 50; | |||
Benchmarker<ConvBias> benchmark0(handle()); | |||
benchmark0.set_display(false); | |||
benchmark0.set_param(param); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1")); | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
opr->param() = param; | |||
param.format = param::ConvBias::Format::NCHW44; | |||
Benchmarker<ConvBias> benchmark1(handle()); | |||
benchmark1.set_display(false); | |||
benchmark1.set_param(param); | |||
benchmark1.set_times(RUN); | |||
benchmark1.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44")); | |||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { | |||
TensorLayout dst_layout; | |||
opr->deduce_layout( | |||
{{1, group * 4, h, w}, dtype::Int8()}, | |||
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, | |||
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); | |||
//! dst.nr_elems * IC * FH * FW * 2 | |||
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / | |||
(1024 * 1024 * 1024) * 1e3; | |||
auto used0 = benchmark0.exec( | |||
{{1, group * 4, h, w}, | |||
{group * 4, 1, 1, kernel, kernel}, | |||
{1, group * 4, 1, 1}, | |||
{}, | |||
{}}) / | |||
RUN; | |||
auto used1 = benchmark1.exec( | |||
{{1, group, h, w, 4}, | |||
{group, 1, 1, kernel, kernel, 4}, | |||
{1, group, 1, 1, 4}, | |||
{}, | |||
{}}) / | |||
RUN; | |||
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " | |||
"nchw44: " | |||
"%f ms %f GFlops " | |||
"speedup: %f\n", | |||
group, h, w, kernel, used0, computations / used0, used1, | |||
computations / used1, used0 / used1); | |||
}; | |||
for (size_t group : {8, 16, 32, 64}) { | |||
for (size_t kerenl : {2, 3, 5}) { | |||
run(group, 112, 112, kerenl); | |||
run(group, 56, 56, kerenl); | |||
run(group, 48, 48, kerenl); | |||
run(group, 28, 28, kerenl); | |||
run(group, 14, 14, kerenl); | |||
} | |||
} | |||
run(8, 112, 112, 3); | |||
run(32, 56, 56, 3); | |||
run(64, 28, 28, 3); | |||
run(128, 14, 14, 3); | |||
} | |||
TEST_F(FALLBACK, BENCHMARK_GI_CHANNEL_WISE_F32_STRIDE2_NCHW44) { | |||
// have to remove preferred restrict in usable func before run the benchmark | |||
using namespace conv_bias; | |||
param::ConvBias param; | |||
param.stride_h = 2; | |||
param.stride_w = 2; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.nonlineMode = NonlineMode::RELU; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
constexpr size_t RUN = 50; | |||
Benchmarker<ConvBias> benchmark0(handle()); | |||
benchmark0.set_display(false); | |||
benchmark0.set_param(param); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2")); | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
opr->param() = param; | |||
param.format = param::ConvBias::Format::NCHW44; | |||
Benchmarker<ConvBias> benchmark1(handle()); | |||
benchmark1.set_display(false); | |||
benchmark1.set_param(param); | |||
benchmark1.set_times(RUN); | |||
benchmark1.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44")); | |||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) { | |||
TensorLayout dst_layout; | |||
opr->deduce_layout( | |||
{{1, group * 4, h, w}, dtype::Int8()}, | |||
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()}, | |||
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout); | |||
//! dst.nr_elems * IC * FH * FW * 2 | |||
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 / | |||
(1024 * 1024 * 1024) * 1e3; | |||
auto used0 = benchmark0.exec( | |||
{{1, group * 4, h, w}, | |||
{group * 4, 1, 1, kernel, kernel}, | |||
{1, group * 4, 1, 1}, | |||
{}, | |||
{}}) / | |||
RUN; | |||
auto used1 = benchmark1.exec( | |||
{{1, group, h, w, 4}, | |||
{group, 1, 1, kernel, kernel, 4}, | |||
{1, group, 1, 1, 4}, | |||
{}, | |||
{}}) / | |||
RUN; | |||
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops " | |||
"nchw44: " | |||
"%f ms %f GFlops " | |||
"speedup: %f\n", | |||
group, h, w, kernel, used0, computations / used0, used1, | |||
computations / used1, used0 / used1); | |||
}; | |||
for (size_t group : {8, 16, 32, 64}) { | |||
for (size_t kerenl : {2, 3, 5}) { | |||
run(group, 112, 112, kerenl); | |||
run(group, 56, 56, kerenl); | |||
run(group, 48, 48, kerenl); | |||
run(group, 28, 28, kerenl); | |||
run(group, 14, 14, kerenl); | |||
} | |||
} | |||
run(8, 112, 112, 3); | |||
run(32, 56, 56, 3); | |||
run(64, 28, 28, 3); | |||
run(128, 14, 14, 3); | |||
} | |||
TEST_F(FALLBACK, BENCHMARK_CONVBIAS) { | |||
constexpr size_t RUNS = 10; | |||
param::ConvBias param; | |||
@@ -320,6 +943,164 @@ TEST_F(FALLBACK, BENCHMARK_CONVBIAS) { | |||
} | |||
} | |||
} | |||
TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F23_4x4) { | |||
#if MEGDNN_AARCH64 | |||
conv_bias::benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4); | |||
#elif MEGDNN_ARMV7 | |||
conv_bias::benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4); | |||
#else | |||
conv_bias::benchmark_winograd("WINOGRAD:FB_GI_F32_MK4_4x8:4:2", handle(), 3, 4); | |||
#endif | |||
} | |||
void benchmark_winograd_nchw_vs_nchw44( | |||
const char* algo_name0, const char* algo_name1, Handle* handle) { | |||
using namespace conv_bias; | |||
using NLMode = param::ConvBias::NonlineMode; | |||
std::vector<conv_bias::TestArg> args_nchw44; | |||
std::vector<conv_bias::TestArg> args_nchw; | |||
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t group, | |||
NLMode nlmode) { | |||
param::ConvBias param; | |||
param.format = param::ConvBias::Format::NCHW44; | |||
param.stride_h = 1; | |||
param.stride_w = 1; | |||
param.pad_h = 1; | |||
param.pad_w = 1; | |||
param.nonlineMode = nlmode; | |||
if (group == 1) { | |||
param.sparse = param::ConvBias::Sparse::DENSE; | |||
args_nchw44.emplace_back( | |||
param, TensorShape{n, ic / 4, h, w, 4}, | |||
TensorShape{oc / 4, ic / 4, 3, 3, 4, 4}, TensorShape{}); | |||
param.format = param::ConvBias::Format::NCHW; | |||
args_nchw.emplace_back( | |||
param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, 3, 3}, | |||
TensorShape{}); | |||
} else { | |||
auto oc_per_group = oc / group; | |||
auto ic_per_group = ic / group; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
args_nchw44.emplace_back( | |||
param, TensorShape{n, ic_per_group / 4, h, w, 4}, | |||
TensorShape{group, oc_per_group / 4, ic_per_group / 4, 3, 3, 4, 4}, | |||
TensorShape{}); | |||
param.format = param::ConvBias::Format::NCHW; | |||
args_nchw.emplace_back( | |||
param, TensorShape{n, ic, h, w}, | |||
TensorShape{group, oc_per_group, ic_per_group, 3, 3}, | |||
TensorShape{}); | |||
} | |||
}; | |||
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY}; | |||
for (auto nlmode : nonlinemode) | |||
for (size_t n : {1}) | |||
for (size_t group = 1; group <= 1; ++group) { | |||
pack(n, 512, 512, 15, 15, group, nlmode); | |||
pack(n, 512, 256, 15, 15, group, nlmode); | |||
pack(n, 256, 256, 29, 29, group, nlmode); | |||
pack(n, 256, 128, 29, 29, group, nlmode); | |||
pack(n, 128, 128, 57, 57, group, nlmode); | |||
pack(n, 128, 64, 57, 57, group, nlmode); | |||
pack(n, 24, 24, 224, 224, group, nlmode); | |||
pack(n, 64, 24, 123, 123, group, nlmode); | |||
pack(n, 64, 64, 56, 56, group, nlmode); | |||
pack(n, 128, 128, 28, 28, group, nlmode); | |||
pack(n, 256, 256, 14, 14, group, nlmode); | |||
pack(n, 512, 512, 7, 7, group, nlmode); | |||
} | |||
using namespace conv_bias; | |||
constexpr size_t RUN = 10; | |||
Benchmarker<ConvBias> benchmark_winograd_nchw(handle); | |||
benchmark_winograd_nchw.set_display(false); | |||
benchmark_winograd_nchw.set_times(RUN); | |||
Benchmarker<ConvBias> benchmark_winograd_nchw44(handle); | |||
benchmark_winograd_nchw44.set_display(false); | |||
benchmark_winograd_nchw44.set_times(RUN); | |||
std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name0); | |||
std::string winograd_nchw44_algo_name = ssprintf("WINOGRAD_NCHW44:%s", algo_name1); | |||
for (size_t i = 0; i < args_nchw.size(); ++i) { | |||
auto arg_nchw = args_nchw[i]; | |||
auto arg_nchw44 = args_nchw44[i]; | |||
TensorLayout dst_layout; | |||
auto opr = handle->create_operator<ConvBias>(); | |||
opr->param() = arg_nchw.param; | |||
opr->deduce_layout( | |||
{arg_nchw.src, dtype::Float32()}, {arg_nchw.filter, dtype::Float32()}, | |||
{arg_nchw.bias, dtype::Float32()}, {}, dst_layout); | |||
//! dst.nr_elems * IC * FH * FW * 2 | |||
float computations = dst_layout.total_nr_elems() * arg_nchw.filter[1] * | |||
arg_nchw.filter[2] * arg_nchw.filter[3] * 2.0 / | |||
(1024 * 1024 * 1024) * 1e3; | |||
benchmark_winograd_nchw.set_param(arg_nchw.param); | |||
auto nchw_used = algo_benchmark<ConvBias>( | |||
benchmark_winograd_nchw, | |||
{arg_nchw.src, arg_nchw.filter, {}, {}, {}}, | |||
winograd_nchw_algo_name.c_str()) / | |||
RUN; | |||
benchmark_winograd_nchw44.set_param(arg_nchw44.param); | |||
auto nchw44_used = algo_benchmark<ConvBias>( | |||
benchmark_winograd_nchw44, | |||
{arg_nchw44.src, arg_nchw44.filter, {}, {}, {}}, | |||
winograd_nchw44_algo_name.c_str()) / | |||
RUN; | |||
printf("%s %s: nchw: %f ms %f Gflops nchw44: %f ms %f GFlops " | |||
"speedup: " | |||
"%f\n", | |||
arg_nchw.src.to_string().c_str(), arg_nchw.filter.to_string().c_str(), | |||
nchw_used, computations / nchw_used, nchw44_used, | |||
computations / nchw44_used, nchw_used / nchw44_used); | |||
} | |||
} | |||
TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"AARCH64_F32_MK4_4x16:4:2", "AARCH64_F32_MK4_4x16:4:2", handle()); | |||
#elif MEGDNN_ARMV7 | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"ARMV7_F32_MK4_4x8:4:2", "ARMV7_F32_MK4_4x8:4:2", handle()); | |||
#else | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"FB_GI_F32_MK4_4x8:4:2", "FB_GI_F32_MK4_4x8:4:2", handle()); | |||
#endif | |||
} | |||
TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_4x4) { | |||
#if MEGDNN_AARCH64 | |||
conv_bias::benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4); | |||
#elif MEGDNN_ARMV7 | |||
conv_bias::benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4); | |||
#else | |||
conv_bias::benchmark_winograd("WINOGRAD:FB_GI_F32_MK4_4x8:4:6", handle(), 3, 4); | |||
#endif | |||
} | |||
TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { | |||
#if MEGDNN_AARCH64 | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"AARCH64_F32_MK4_4x16:4:6", "AARCH64_F32_MK4_4x16:4:6", handle()); | |||
#elif MEGDNN_ARMV7 | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:6", handle()); | |||
#else | |||
benchmark_winograd_nchw_vs_nchw44( | |||
"FB_GI_F32_MK4_4x8:4:6", "FB_GI_F32_MK4_4x8:4:6", handle()); | |||
#endif | |||
} | |||
#endif | |||
} // namespace test | |||