
feat(fallback): move arm_common f32 convbias to fallback gi

GitOrigin-RevId: ccf8b589be
release-1.10
Megvii Engine Team, 3 years ago
commit e4cc85e52c
99 changed files with 4502 additions and 3431 deletions
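For context: the "gi" in the commit title is the generic intrinsics layer under dnn/src/fallback, which wraps 4-wide f32 SIMD operations behind portable types so the former arm_common NEON kernels can also build for non-NEON targets. The sketch below shows the idea only; the wrapper names and mappings are illustrative assumptions, not the actual gi API.

// Illustrative sketch of a generic-intrinsics wrapper (hypothetical names,
// not the actual MegEngine gi API).
#if defined(__ARM_NEON)
#include <arm_neon.h>
typedef float32x4_t GI_FLOAT32_t;  // 4 x f32 vector
static inline GI_FLOAT32_t GiLoadFloat32(const float* p) { return vld1q_f32(p); }
static inline void GiStoreFloat32(float* p, GI_FLOAT32_t v) { vst1q_f32(p, v); }
static inline GI_FLOAT32_t GiMlaFloat32(
        GI_FLOAT32_t acc, GI_FLOAT32_t a, GI_FLOAT32_t b) {
    return vmlaq_f32(acc, a, b);  // acc + a * b, lane-wise
}
#else
struct GI_FLOAT32_t { float v[4]; };  // scalar fallback with the same interface
static inline GI_FLOAT32_t GiLoadFloat32(const float* p) {
    GI_FLOAT32_t r;
    for (int i = 0; i < 4; ++i) r.v[i] = p[i];
    return r;
}
static inline void GiStoreFloat32(float* p, GI_FLOAT32_t x) {
    for (int i = 0; i < 4; ++i) p[i] = x.v[i];
}
static inline GI_FLOAT32_t GiMlaFloat32(
        GI_FLOAT32_t acc, GI_FLOAT32_t a, GI_FLOAT32_t b) {
    for (int i = 0; i < 4; ++i) acc.v[i] += a.v[i] * b.v[i];
    return acc;
}
#endif

The diffs below move kernels written against NEON-backed MEGDNN_SIMD_* macros into the gi tree, where loads, stores, and fused multiply-adds on 4-wide f32 vectors are expressed in this portable style.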
1. +0 -725 dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp
2. +0 -512 dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp
3. +0 -54 dnn/src/arm_common/conv_bias/opr_impl.cpp
4. +0 -15 dnn/src/arm_common/conv_bias/opr_impl.h
5. +2 -0 dnn/src/fallback/conv_bias/direct/multi_thread_common.h
6. +37 -0 dnn/src/fallback/conv_bias/gi/block_helper.h
7. +37 -38 dnn/src/fallback/conv_bias/gi/fp32/algos.cpp
8. +18 -18 dnn/src/fallback/conv_bias/gi/fp32/algos.h
9. +73 -80 dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
10. +4 -4 dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h
11. +50 -56 dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
12. +4 -4 dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h
13. +5 -5 dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp
14. +162 -152 dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp
15. +4 -4 dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h
16. +264 -264 dnn/src/fallback/conv_bias/gi/fp32/direct.cpp
17. +3 -3 dnn/src/fallback/conv_bias/gi/fp32/direct.h
18. +51 -49 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
19. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp
20. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
21. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp
22. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp
23. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
24. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp
25. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp
26. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
27. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp
28. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp
29. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
30. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp
31. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp
32. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
33. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp
34. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp
35. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
36. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp
37. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp
38. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
39. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp
40. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp
41. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
42. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp
43. +67 -68 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
44. +83 -84 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
45. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp
46. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
47. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp
48. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp
49. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
50. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp
51. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp
52. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
53. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp
54. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp
55. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
56. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp
57. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp
58. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
59. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp
60. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp
61. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
62. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp
63. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp
64. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
65. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp
66. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp
67. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
68. +2 -2 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp
69. +51 -57 dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
70. +723 -0 dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp
71. +3 -3 dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h
72. +503 -0 dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp
73. +3 -3 dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h
74. +9 -9 dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp
75. +4 -4 dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h
76. +9 -9 dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp
77. +8 -9 dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h
78. +8 -10 dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h
79. +196 -0 dnn/src/fallback/conv_bias/gi/fp32/helper.h
80. +6 -5 dnn/src/fallback/conv_bias/gi/fp32/strategy.h
81. +48 -49 dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp
82. +25 -24 dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp
83. +23 -23 dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp
84. +31 -30 dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp
85. +18 -19 dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp
86. +14 -14 dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp
87. +108 -108 dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp
88. +155 -155 dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp
89. +413 -0 dnn/src/fallback/conv_bias/gi/intrinsic_helper.h
90. +86 -0 dnn/src/fallback/conv_bias/gi/postprocess_helper.h
91. +193 -0 dnn/src/fallback/conv_bias/gi/utils.h
92. +92 -4 dnn/src/fallback/conv_bias/opr_impl.cpp
93. +31 -14 dnn/src/fallback/conv_bias/opr_impl.h
94. +1 -204 dnn/test/arm_common/conv_bias.cpp
95. +0 -138 dnn/test/arm_common/conv_bias_multi_thread.cpp
96. +0 -201 dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
97. +0 -84 dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp
98. +0 -24 dnn/test/arm_common/matrix_mul.cpp
99. +781 -0 dnn/test/fallback/conv_bias.cpp

dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp (+0, -725)

@@ -1,725 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include <algorithm>

#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/arm_common/simd_macro/neon_helper.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs1)

using namespace megdnn;
using namespace arm_common;
using namespace fp32;
using namespace conv_stride1;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;

void conv_stride1::do_conv_2x2_stride1(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
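//! stride 1: the inner loop advances the row pointers by OW floats per output
//! row, so tail_step = IW - OW moves them to the start of the next input row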
const size_t tail_step = IW - OW;
//! unroll input channels by 2
size_t ic = 0;
for (; ic + 1 < IC; ic += 2) {
const float* src_ptr = src + IW * IH * ic;
const float* src_ptr1 = src_ptr + IW * IH;
float* outptr = dst;

const float* r00 = src_ptr;
const float* r01 = src_ptr + IW;
const float* r10 = src_ptr1;
const float* r11 = src_ptr1 + IW;

const float* k0 = filter + ic * 4;
const float* k1 = k0 + 4;

MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_LOADU(k1);
rep(h, OH) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _r000 = MEGDNN_SIMD_LOADU(r00);
MEGDNN_SIMD_TYPE _r010 = MEGDNN_SIMD_LOADU(r01);
MEGDNN_SIMD_TYPE _r001 = MEGDNN_SIMD_LOADU(r00 + 1);
MEGDNN_SIMD_TYPE _r011 = MEGDNN_SIMD_LOADU(r01 + 1);

MEGDNN_SIMD_TYPE _r100 = MEGDNN_SIMD_LOADU(r10);
MEGDNN_SIMD_TYPE _r110 = MEGDNN_SIMD_LOADU(r11);
MEGDNN_SIMD_TYPE _r101 = MEGDNN_SIMD_LOADU(r10 + 1);
MEGDNN_SIMD_TYPE _r111 = MEGDNN_SIMD_LOADU(r11 + 1);

MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000, MEGDNN_SIMD_GET_LOW(_k0), 0);
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001, MEGDNN_SIMD_GET_LOW(_k0), 1);
_sum = MEGDNN_SIMD_VMLAQ_LANE(
_sum, _r010, MEGDNN_SIMD_GET_HIGH(_k0), 0);
_sum = MEGDNN_SIMD_VMLAQ_LANE(
_sum, _r011, MEGDNN_SIMD_GET_HIGH(_k0), 1);

_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100, MEGDNN_SIMD_GET_LOW(_k1), 0);
_sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101, MEGDNN_SIMD_GET_LOW(_k1), 1);
_sum = MEGDNN_SIMD_VMLAQ_LANE(
_sum, _r110, MEGDNN_SIMD_GET_HIGH(_k1), 0);
_sum = MEGDNN_SIMD_VMLAQ_LANE(
_sum, _r111, MEGDNN_SIMD_GET_HIGH(_k1), 1);

MEGDNN_SIMD_STOREU(outptr, _sum);

r00 += 4;
r01 += 4;
r10 += 4;
r11 += 4;
outptr += 4;
}

r00 += tail_step;
r01 += tail_step;
r10 += tail_step;
r11 += tail_step;
}
}
for (; ic < IC; ic++) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;

const float* k0 = filter + ic * 4;

MEGDNN_SIMD_TYPE _k0 = MEGDNN_SIMD_SET1(k0[0]);
MEGDNN_SIMD_TYPE _k1 = MEGDNN_SIMD_SET1(k0[1]);
MEGDNN_SIMD_TYPE _k2 = MEGDNN_SIMD_SET1(k0[2]);
MEGDNN_SIMD_TYPE _k3 = MEGDNN_SIMD_SET1(k0[3]);
rep(h, OH) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_LOADU(r0 + 1);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_LOADU(r1 + 1);

MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);
MEGDNN_SIMD_TYPE _sum2;

_sum = MEGDNN_SIMD_FMADD(_r00, _k0, _sum);
_sum2 = MEGDNN_SIMD_MUL(_r01, _k1);
_sum = MEGDNN_SIMD_FMADD(_r10, _k2, _sum);
_sum2 = MEGDNN_SIMD_FMADD(_r11, _k3, _sum2);

_sum = MEGDNN_SIMD_ADD(_sum, _sum2);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 4;
r1 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
}
}
}
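
For reference, here is a plain scalar sketch (not part of the commit) of what do_conv_2x2_stride1 above computes; dst is assumed to be pre-initialized, e.g. with bias, since the kernel accumulates into it:

#include <cstddef>

// Scalar reference for do_conv_2x2_stride1: accumulate a 2x2 stride-1
// convolution over IC input channels into one output channel (illustrative).
void conv_2x2_stride1_ref(
        const float* src, const float* filter, float* dst, std::size_t IH,
        std::size_t IW, std::size_t OH, std::size_t OW, std::size_t IC) {
    for (std::size_t ic = 0; ic < IC; ++ic) {
        const float* s = src + ic * IH * IW;
        const float* k = filter + ic * 4;  // 2x2 = 4 weights per input channel
        for (std::size_t oh = 0; oh < OH; ++oh) {
            for (std::size_t ow = 0; ow < OW; ++ow) {
                dst[oh * OW + ow] += s[oh * IW + ow] * k[0] +
                                     s[oh * IW + ow + 1] * k[1] +
                                     s[(oh + 1) * IW + ow] * k[2] +
                                     s[(oh + 1) * IW + ow + 1] * k[3];
            }
        }
    }
}

The vectorized version above produces four such outputs per iteration and unrolls two input channels to hide load latency.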

void conv_stride1::do_conv_3x3_stride1(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;
float* outptr2 = outptr + OW;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;

const float* k0 = filter;
const float* k1 = filter + 3;
const float* k2 = filter + 5;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1);
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2);
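//! the 3x3 filter holds only 9 floats, so loading k[6..9] directly would read
//! past the end of the last channel's filter; load from filter + 5 and shift
//! one lane instead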
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1);

size_t h = 0;
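//! compute two output rows per iteration; input rows r1 and r2 are loaded once
//! and shared by both output rows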
for (; h + 1 < OH; h += 2) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr);
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f);
MEGDNN_SIMD_TYPE _sum3 = MEGDNN_SIMD_LOADU(outptr2);
MEGDNN_SIMD_TYPE _sum4 = MEGDNN_SIMD_SET1(0.f);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1);
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2);

MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU_2(r3 + 4);
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1);
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2);

_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2);

_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r10, _k0123, 0);
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r11, _k0123, 1);
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r12, _k0123, 2);
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r20, _k3456, 0);
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r21, _k3456, 1);
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r22, _k3456, 2);
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r30, _k6789, 0);
_sum4 = MEGDNN_SIMD_FMA_LANE(_sum4, _r31, _k6789, 1);
_sum3 = MEGDNN_SIMD_FMA_LANE(_sum3, _r32, _k6789, 2);

_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2);
_sum3 = MEGDNN_SIMD_ADD(_sum3, _sum4);

MEGDNN_SIMD_STOREU(outptr, _sum1);
MEGDNN_SIMD_STOREU(outptr2, _sum3);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
outptr += 4;
outptr2 += 4;
}

r0 += tail_step + IW;
r1 += tail_step + IW;
r2 += tail_step + IW;
r3 += tail_step + IW;

outptr += OW;
outptr2 += OW;
}

for (; h < OH; h++) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum1 = MEGDNN_SIMD_LOADU(outptr);
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_SET1(0.f);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 4);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r00n, 1);
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r00n, 2);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r10n, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r10n, 2);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r20n, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2);

_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r00, _k0123, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r01, _k0123, 1);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r02, _k0123, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k3456, 0);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r11, _k3456, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k3456, 2);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r20, _k6789, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k6789, 1);
_sum1 = MEGDNN_SIMD_FMA_LANE(_sum1, _r22, _k6789, 2);

_sum1 = MEGDNN_SIMD_ADD(_sum1, _sum2);

MEGDNN_SIMD_STOREU(outptr, _sum1);

r0 += 4;
r1 += 4;
r2 += 4;
outptr += 4;
}
r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
}

filter += 9;
}
}

void conv_stride1::do_conv_5x5_stride1(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;
float* outptr2 = outptr + OW;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter);
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4);
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8);
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12);
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16);
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20);
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]);

size_t h = 0;
for (; h + 1 < OH; h += 2) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);
MEGDNN_SIMD_TYPE _sum2 = MEGDNN_SIMD_LOADU(outptr2);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1);
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2);
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3);

MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4);
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1);
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3);

MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4);
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1);
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3);

MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5);
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4);
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1);
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2);
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r10, _k0123, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r11, _k0123, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r12, _k0123, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r13, _k0123, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r14, _k4567, 0);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r20, _k4567, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r21, _k4567, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r22, _k4567, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r23, _k891011, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r24, _k891011, 1);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r30, _k891011, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r31, _k891011, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r32, _k12131415, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r33, _k12131415, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r34, _k12131415, 2);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r40, _k12131415, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r41, _k16171819, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r42, _k16171819, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r43, _k16171819, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r44, _k16171819, 3);

_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r50, _k20212223, 0);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r51, _k20212223, 1);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r52, _k20212223, 2);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r53, _k20212223, 3);
_sum2 = MEGDNN_SIMD_FMA_LANE(_sum2, _r54, _k24242424, 0);

MEGDNN_SIMD_STOREU(outptr, _sum);
MEGDNN_SIMD_STOREU(outptr2, _sum2);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
r5 += 4;
outptr += 4;
outptr2 += 4;
}

r0 += tail_step + IW;
r1 += tail_step + IW;
r2 += tail_step + IW;
r3 += tail_step + IW;
r4 += tail_step + IW;
r5 += tail_step + IW;

outptr += OW;
outptr2 += OW;
}

for (; h < OH; h++) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0);
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4);
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1);
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2);
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3);

MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4);
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1);
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3);

MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4);
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1);
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
}

filter += 25;
}
}

void conv_stride1::do_conv_7x7_stride1(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;
const float* r6 = src_ptr + IW * 6;

const float* k0 = filter;
const float* k1 = filter + 7;
const float* k2 = filter + 14;
const float* k3 = filter + 21;
const float* k4 = filter + 28;
const float* k5 = filter + 35;
const float* k6 = filter + 42;

for (size_t i = 0; i < OH; i++) {
int width = OW >> 2;

rep(i, width) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4);

MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3
MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7
MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11
MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4
MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5
MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6
MEGDNN_SIMD_TYPE _r05 = MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8
MEGDNN_SIMD_TYPE _r06 = MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2);

MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1);
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4);

MEGDNN_SIMD_TYPE _r10 = MEGDNN_SIMD_LOADU(r1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_LOADU(r1 + 4);
MEGDNN_SIMD_TYPE _r10n = MEGDNN_SIMD_LOADU(r1 + 8);
MEGDNN_SIMD_TYPE _r11 = MEGDNN_SIMD_EXT(_r10, _r14, 1);
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r14, 2);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r10, _r14, 3);
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r14, _r10n, 1);
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r14, _r10n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2);

MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2);
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4);

MEGDNN_SIMD_TYPE _r20 = MEGDNN_SIMD_LOADU(r2);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_LOADU(r2 + 4);
MEGDNN_SIMD_TYPE _r20n = MEGDNN_SIMD_LOADU(r2 + 8);
MEGDNN_SIMD_TYPE _r21 = MEGDNN_SIMD_EXT(_r20, _r24, 1);
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r24, 2);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r20, _r24, 3);
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r24, _r20n, 1);
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r24, _r20n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2);

MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3);
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4);

MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_LOADU(r3 + 4);
MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 8);
MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r34, 1);
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r34, 2);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r30, _r34, 3);
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r34, _r30n, 1);
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r34, _r30n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2);

MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4);
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4);

MEGDNN_SIMD_TYPE _r40 = MEGDNN_SIMD_LOADU(r4);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_LOADU(r4 + 4);
MEGDNN_SIMD_TYPE _r40n = MEGDNN_SIMD_LOADU(r4 + 8);
MEGDNN_SIMD_TYPE _r41 = MEGDNN_SIMD_EXT(_r40, _r44, 1);
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r44, 2);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r40, _r44, 3);
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r44, _r40n, 1);
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r44, _r40n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2);

MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5);
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4);

MEGDNN_SIMD_TYPE _r50 = MEGDNN_SIMD_LOADU(r5);
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_LOADU(r5 + 4);
MEGDNN_SIMD_TYPE _r50n = MEGDNN_SIMD_LOADU(r5 + 8);
MEGDNN_SIMD_TYPE _r51 = MEGDNN_SIMD_EXT(_r50, _r54, 1);
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r54, 2);
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r50, _r54, 3);
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r54, _r50n, 1);
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r54, _r50n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2);

MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6);
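//! k6 + 4 points at filter[46]; only three floats remain in the 49-element
//! filter, so a three-element load keeps the final read in bounds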
MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU_3(k6 + 4);

MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6);
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4);
MEGDNN_SIMD_TYPE _r60n = MEGDNN_SIMD_LOADU(r6 + 8);
MEGDNN_SIMD_TYPE _r61 = MEGDNN_SIMD_EXT(_r60, _r64, 1);
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r64, 2);
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r60, _r64, 3);
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r64, _r60n, 1);
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r64, _r60n, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k46474849, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k46474849, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k46474849, 2);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
r5 += 4;
r6 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
r5 += tail_step;
r6 += tail_step;
}
filter += 49;
}
}

#include "src/common/simd_macro/epilogue.h"
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp (+0, -512)

@@ -1,512 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <algorithm>

#include "./do_conv_stride2.h"
#include "midout.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/arm_common/simd_macro/neon_helper.h"

MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2)

using namespace megdnn;
using namespace arm_common;
using namespace fp32;
using namespace conv_stride2;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;

void conv_stride2::do_conv_2x2_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
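//! stride 2: each output row consumes two input rows, and the loop body
//! advances the row pointers by 2 * OW floats, so tail_step = 2 * IW - 2 * OW
//! lands them exactly two input rows down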
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;

const float* k0 = filter;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
rep(h, OH) {
int nn = OW >> 2;

rep(i, nn) {
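//! MEGDNN_SIMD_LOAD2 below deinterleaves eight floats into even and odd lanes
//! (vld2-style), which directly yields the stride-2 input columns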
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0);

MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1);

MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1);

MEGDNN_SIMD_TYPE _r10 = _r1.val[0];
MEGDNN_SIMD_TYPE _r11 = _r1.val[1];

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k0123, 2);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k0123, 3);

MEGDNN_SIMD_STOREU(outptr, _outp);

r0 += 8;
r1 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
}

filter += 4;
}
}

void conv_stride2::do_conv_3x3_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;

const float* k0 = filter;
const float* k1 = filter + 3;
const float* k2 = filter + 5;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k3456 = MEGDNN_SIMD_LOADU(k1);
MEGDNN_SIMD_TYPE _k5678 = MEGDNN_SIMD_LOADU(k2);
MEGDNN_SIMD_TYPE _k6789 = MEGDNN_SIMD_EXT(_k5678, _k5678, 1);
rep(h, OH) {
int nn = OW >> 2;

rep(i, nn) {
MEGDNN_SIMD_TYPE _outp = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE2 _r0 = MEGDNN_SIMD_LOAD2(r0);
MEGDNN_SIMD_TYPE2 _r0n = MEGDNN_SIMD_LOAD2(r0 + 8);

MEGDNN_SIMD_TYPE _r00 = _r0.val[0]; // 0 2 4 6
MEGDNN_SIMD_TYPE _r01 = _r0.val[1]; // 1 3 5 7
MEGDNN_SIMD_TYPE _r02 =
MEGDNN_SIMD_EXT(_r00, _r0n.val[0], 1); // 2 4 6 8

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r00, _k0123, 0);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r01, _k0123, 1);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r02, _k0123, 2);

MEGDNN_SIMD_TYPE2 _r1 = MEGDNN_SIMD_LOAD2(r1);
MEGDNN_SIMD_TYPE2 _r1n = MEGDNN_SIMD_LOAD2(r1 + 8);

MEGDNN_SIMD_TYPE _r10 = _r1.val[0];
MEGDNN_SIMD_TYPE _r11 = _r1.val[1];
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1n.val[0], 1);

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r10, _k3456, 0);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r11, _k3456, 1);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r12, _k3456, 2);

MEGDNN_SIMD_TYPE2 _r2 = MEGDNN_SIMD_LOAD2(r2);
MEGDNN_SIMD_TYPE2 _r2n = MEGDNN_SIMD_LOAD2(r2 + 8);

MEGDNN_SIMD_TYPE _r20 = _r2.val[0];
MEGDNN_SIMD_TYPE _r21 = _r2.val[1];
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2n.val[0], 1);

_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r20, _k6789, 0);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r21, _k6789, 1);
_outp = MEGDNN_SIMD_FMA_LANE(_outp, _r22, _k6789, 2);

MEGDNN_SIMD_STOREU(outptr, _outp);

r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
}

filter += 9;
}
}

void conv_stride2::do_conv_5x5_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(filter);
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(filter + 4);
MEGDNN_SIMD_TYPE _k891011 = MEGDNN_SIMD_LOADU(filter + 8);
MEGDNN_SIMD_TYPE _k12131415 = MEGDNN_SIMD_LOADU(filter + 12);
MEGDNN_SIMD_TYPE _k16171819 = MEGDNN_SIMD_LOADU(filter + 16);
MEGDNN_SIMD_TYPE _k20212223 = MEGDNN_SIMD_LOADU(filter + 20);
MEGDNN_SIMD_TYPE _k24242424 = MEGDNN_SIMD_SET1(filter[24]);

for (size_t i = 0; i < OH; i++) {
int nn = OW >> 2;

rep(i, nn) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0);
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8);
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7
MEGDNN_SIMD_TYPE _r02 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8
MEGDNN_SIMD_TYPE _r03 =
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9
MEGDNN_SIMD_TYPE _r04 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10

MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1);
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8);
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0];
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1];
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0];
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1];
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2);

MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2);
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8);
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0];
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1];
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0];
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1];
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2);

MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3);
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8);
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0];
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1];
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0];
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1];
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2);

MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4);
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8);
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0];
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1];
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0];
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1];
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k4567, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k4567, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k891011, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k891011, 1);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k891011, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k891011, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k12131415, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k12131415, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k12131415, 2);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k12131415, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k16171819, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k16171819, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k16171819, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k16171819, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k20212223, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k20212223, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k20212223, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k20212223, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k24242424, 0);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
r4 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
}

filter += 25;
}
}

void conv_stride2::do_conv_7x7_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;
const float* r6 = src_ptr + IW * 6;

const float* k0 = filter;
const float* k1 = filter + 7;
const float* k2 = filter + 14;
const float* k3 = filter + 21;
const float* k4 = filter + 28;
const float* k5 = filter + 35;
const float* k6 = filter + 42;

for (size_t i = 0; i < OH; i++) {
int nn = OW >> 2;

rep(i, nn) {
MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr);

MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0);
MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4);

MEGDNN_SIMD_TYPE2 _r00_02461357 = MEGDNN_SIMD_LOAD2(r0);
MEGDNN_SIMD_TYPE2 _r00nx2 = MEGDNN_SIMD_LOAD2(r0 + 8);
MEGDNN_SIMD_TYPE _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14
MEGDNN_SIMD_TYPE _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15
MEGDNN_SIMD_TYPE _r00 = _r00_02461357.val[0]; // 0 2 4 6
MEGDNN_SIMD_TYPE _r01 = _r00_02461357.val[1]; // 1 3 5 7
MEGDNN_SIMD_TYPE _r02 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 1); // 2 4 6 8
MEGDNN_SIMD_TYPE _r03 =
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 1); // 3 5 7 9
MEGDNN_SIMD_TYPE _r04 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 2); // 4 6 8 10
MEGDNN_SIMD_TYPE _r05 =
MEGDNN_SIMD_EXT(_r01, _r0_9111315, 2); // 5 7 9 11
MEGDNN_SIMD_TYPE _r06 =
MEGDNN_SIMD_EXT(_r00, _r0_8101214, 3); // 6 8 10 12

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r02, _k0123, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r03, _k0123, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r04, _k4567, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r05, _k4567, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r06, _k4567, 2);

MEGDNN_SIMD_TYPE _k78910 = MEGDNN_SIMD_LOADU(k1);
MEGDNN_SIMD_TYPE _k11121314 = MEGDNN_SIMD_LOADU(k1 + 4);

MEGDNN_SIMD_TYPE2 _r10_02461357 = MEGDNN_SIMD_LOAD2(r1);
MEGDNN_SIMD_TYPE2 _r10nx2 = MEGDNN_SIMD_LOAD2(r1 + 8);
MEGDNN_SIMD_TYPE _r1_8101214 = _r10nx2.val[0];
MEGDNN_SIMD_TYPE _r1_9111315 = _r10nx2.val[1];
MEGDNN_SIMD_TYPE _r10 = _r10_02461357.val[0];
MEGDNN_SIMD_TYPE _r11 = _r10_02461357.val[1];
MEGDNN_SIMD_TYPE _r12 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 1);
MEGDNN_SIMD_TYPE _r13 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 1);
MEGDNN_SIMD_TYPE _r14 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 2);
MEGDNN_SIMD_TYPE _r15 = MEGDNN_SIMD_EXT(_r11, _r1_9111315, 2);
MEGDNN_SIMD_TYPE _r16 = MEGDNN_SIMD_EXT(_r10, _r1_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r10, _k78910, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r11, _k78910, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r12, _k78910, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r13, _k78910, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r14, _k11121314, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r15, _k11121314, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r16, _k11121314, 2);

MEGDNN_SIMD_TYPE _k14151617 = MEGDNN_SIMD_LOADU(k2);
MEGDNN_SIMD_TYPE _k18192021 = MEGDNN_SIMD_LOADU(k2 + 4);

MEGDNN_SIMD_TYPE2 _r20_02461357 = MEGDNN_SIMD_LOAD2(r2);
MEGDNN_SIMD_TYPE2 _r20nx2 = MEGDNN_SIMD_LOAD2(r2 + 8);
MEGDNN_SIMD_TYPE _r2_8101214 = _r20nx2.val[0];
MEGDNN_SIMD_TYPE _r2_9111315 = _r20nx2.val[1];
MEGDNN_SIMD_TYPE _r20 = _r20_02461357.val[0];
MEGDNN_SIMD_TYPE _r21 = _r20_02461357.val[1];
MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 1);
MEGDNN_SIMD_TYPE _r23 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 1);
MEGDNN_SIMD_TYPE _r24 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 2);
MEGDNN_SIMD_TYPE _r25 = MEGDNN_SIMD_EXT(_r21, _r2_9111315, 2);
MEGDNN_SIMD_TYPE _r26 = MEGDNN_SIMD_EXT(_r20, _r2_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r20, _k14151617, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r21, _k14151617, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r22, _k14151617, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r23, _k14151617, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r24, _k18192021, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r25, _k18192021, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r26, _k18192021, 2);

MEGDNN_SIMD_TYPE _k21222324 = MEGDNN_SIMD_LOADU(k3);
MEGDNN_SIMD_TYPE _k25262728 = MEGDNN_SIMD_LOADU(k3 + 4);

MEGDNN_SIMD_TYPE2 _r30_02461357 = MEGDNN_SIMD_LOAD2(r3);
MEGDNN_SIMD_TYPE2 _r30nx2 = MEGDNN_SIMD_LOAD2(r3 + 8);
MEGDNN_SIMD_TYPE _r3_8101214 = _r30nx2.val[0];
MEGDNN_SIMD_TYPE _r3_9111315 = _r30nx2.val[1];
MEGDNN_SIMD_TYPE _r30 = _r30_02461357.val[0];
MEGDNN_SIMD_TYPE _r31 = _r30_02461357.val[1];
MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 1);
MEGDNN_SIMD_TYPE _r33 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 1);
MEGDNN_SIMD_TYPE _r34 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 2);
MEGDNN_SIMD_TYPE _r35 = MEGDNN_SIMD_EXT(_r31, _r3_9111315, 2);
MEGDNN_SIMD_TYPE _r36 = MEGDNN_SIMD_EXT(_r30, _r3_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r30, _k21222324, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r31, _k21222324, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r32, _k21222324, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r33, _k21222324, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r34, _k25262728, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r35, _k25262728, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r36, _k25262728, 2);

MEGDNN_SIMD_TYPE _k28293031 = MEGDNN_SIMD_LOADU(k4);
MEGDNN_SIMD_TYPE _k32333435 = MEGDNN_SIMD_LOADU(k4 + 4);

MEGDNN_SIMD_TYPE2 _r40_02461357 = MEGDNN_SIMD_LOAD2(r4);
MEGDNN_SIMD_TYPE2 _r40nx2 = MEGDNN_SIMD_LOAD2(r4 + 8);
MEGDNN_SIMD_TYPE _r4_8101214 = _r40nx2.val[0];
MEGDNN_SIMD_TYPE _r4_9111315 = _r40nx2.val[1];
MEGDNN_SIMD_TYPE _r40 = _r40_02461357.val[0];
MEGDNN_SIMD_TYPE _r41 = _r40_02461357.val[1];
MEGDNN_SIMD_TYPE _r42 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 1);
MEGDNN_SIMD_TYPE _r43 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 1);
MEGDNN_SIMD_TYPE _r44 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 2);
MEGDNN_SIMD_TYPE _r45 = MEGDNN_SIMD_EXT(_r41, _r4_9111315, 2);
MEGDNN_SIMD_TYPE _r46 = MEGDNN_SIMD_EXT(_r40, _r4_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r40, _k28293031, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r41, _k28293031, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r42, _k28293031, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r43, _k28293031, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r44, _k32333435, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r45, _k32333435, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r46, _k32333435, 2);

MEGDNN_SIMD_TYPE _k35363738 = MEGDNN_SIMD_LOADU(k5);
MEGDNN_SIMD_TYPE _k39404142 = MEGDNN_SIMD_LOADU(k5 + 4);

MEGDNN_SIMD_TYPE2 _r50_02461357 = MEGDNN_SIMD_LOAD2(r5);
MEGDNN_SIMD_TYPE2 _r50nx2 = MEGDNN_SIMD_LOAD2(r5 + 8);
MEGDNN_SIMD_TYPE _r5_8101214 = _r50nx2.val[0];
MEGDNN_SIMD_TYPE _r5_9111315 = _r50nx2.val[1];
MEGDNN_SIMD_TYPE _r50 = _r50_02461357.val[0];
MEGDNN_SIMD_TYPE _r51 = _r50_02461357.val[1];
MEGDNN_SIMD_TYPE _r52 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 1);
MEGDNN_SIMD_TYPE _r53 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 1);
MEGDNN_SIMD_TYPE _r54 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 2);
MEGDNN_SIMD_TYPE _r55 = MEGDNN_SIMD_EXT(_r51, _r5_9111315, 2);
MEGDNN_SIMD_TYPE _r56 = MEGDNN_SIMD_EXT(_r50, _r5_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r50, _k35363738, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r51, _k35363738, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r52, _k35363738, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r53, _k35363738, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r54, _k39404142, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r55, _k39404142, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2);

MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6);
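//! a four-float load from k6 + 4 would touch filter[49], one past the end of
//! the 49-element filter; load from k6 + 3 and use lanes 1..3 in the FMAs below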
MEGDNN_SIMD_TYPE _k45464748 = MEGDNN_SIMD_LOADU(k6 + 3);

MEGDNN_SIMD_TYPE2 _r60_02461357 = MEGDNN_SIMD_LOAD2(r6);
MEGDNN_SIMD_TYPE2 _r60nx2 = MEGDNN_SIMD_LOAD2(r6 + 8);
MEGDNN_SIMD_TYPE _r6_8101214 = _r60nx2.val[0];
MEGDNN_SIMD_TYPE _r6_9111315 = _r60nx2.val[1];
MEGDNN_SIMD_TYPE _r60 = _r60_02461357.val[0];
MEGDNN_SIMD_TYPE _r61 = _r60_02461357.val[1];
MEGDNN_SIMD_TYPE _r62 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 1);
MEGDNN_SIMD_TYPE _r63 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 1);
MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 2);
MEGDNN_SIMD_TYPE _r65 = MEGDNN_SIMD_EXT(_r61, _r6_9111315, 2);
MEGDNN_SIMD_TYPE _r66 = MEGDNN_SIMD_EXT(_r60, _r6_8101214, 3);

_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r60, _k42434445, 0);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r61, _k42434445, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r62, _k42434445, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r63, _k42434445, 3);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r64, _k45464748, 1);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r65, _k45464748, 2);
_sum = MEGDNN_SIMD_FMA_LANE(_sum, _r66, _k45464748, 3);

MEGDNN_SIMD_STOREU(outptr, _sum);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
r4 += 8;
r5 += 8;
r6 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
r5 += tail_step;
r6 += tail_step;
}
filter += 49;
}
}
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/opr_impl.cpp (+0, -54)

@@ -28,7 +28,6 @@

#include "include/megdnn/oprs/nn.h"
#include "src/arm_common/conv_bias/f16/algos.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/int8/stride1.h"
#include "src/arm_common/conv_bias/int8/stride2.h"
#include "src/arm_common/conv_bias/quint8/stride1.h"
@@ -69,14 +68,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44;
#endif

AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoF32ChannelWiseNCHW44 f32_chanel_wise_nchw44;
AlgoF32DirectNCHW44 f32_direct_nchw44;

AlgoF32Direct f32_direct;
AlgoF32DirectStride2 f32_direct_stride2;
AlgoF32DirectStride1 f32_direct_stride1;

AlgoI8x8x16Direct i8x8x16_direct;
AlgoI8x8x16Stride2 i8x8x16_stride2;
AlgoI8x8x16Stride2Filter2 i8x8x16_stride2_filter2;
@@ -127,14 +118,6 @@ public:
m_direct_algos.emplace_back(&i8x8x16_stride2);
m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44);

m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
m_direct_algos.emplace_back(&f32_chanel_wise_nchw44);
m_direct_algos.emplace_back(&f32_direct_nchw44);

m_direct_algos.emplace_back(&f32_direct_stride1);
m_direct_algos.emplace_back(&f32_direct_stride2);
m_direct_algos.emplace_back(&f32_direct);

static CpuOprDelegationStorage<2> storage;
auto matmul_opr = storage.get<MatrixMul, 0>();
using MatmulFormat = param::MatrixMul::Format;
@@ -145,22 +128,6 @@ public:
if (is_fallback_or_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_winograd_algos.emplace_back(refhold.back().get());
//! uncomment this when low precision mode is done
#if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
@@ -175,27 +142,6 @@ public:
m_winograd_algos.emplace_back(refhold.back().get());
}
}
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
->select_algo_type(
{AlgoDataType::FLOAT32, MatmulFormat::DEFAULT});
for (auto&& algo : matmul_algos) {
if (is_fallback_or_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF63(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF54(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF45(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_winograd_algos.emplace_back(refhold.back().get());
}
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)


dnn/src/arm_common/conv_bias/opr_impl.h (+0, -15)

@@ -49,15 +49,6 @@ private:
class AlgoS8DirectNCHWNCHW44;
class AlgoQU8DirectStride1;
class AlgoQU8DirectStride2;
class AlgoFP32WinogradF23_4x4;
class AlgoFP32WinogradF63;
class AlgoFP32WinogradF63_4x4;
class AlgoFP32WinogradF54;
class AlgoFP32WinogradF45;

class AlgoFP32WinogradF23_4x4_NCHW44;
class AlgoFP32WinogradF63_4x4_NCHW44;
class AlgoFP32WinogradF73_4x4_NCHW44;

class AlgoS8ChanWiseStride1NCHW44;
class AlgoS8ChanWiseStride2NCHW44;
@@ -78,12 +69,6 @@ private:

class AlgoDotS8Direct_NCHW44;
#endif
class AlgoF32Direct;
class AlgoF32DirectStride1;
class AlgoF32DirectStride2;
class AlgoF32DirectNCHWNCHW44;
class AlgoF32ChannelWiseNCHW44;
class AlgoF32DirectNCHW44;

class AlgoI8x8x16Direct;
class AlgoI8x8x16Stride2;


+ 2
- 0
dnn/src/fallback/conv_bias/direct/multi_thread_common.h

@@ -10,6 +10,8 @@
*/
#pragma once

#include "megbrain_build_config.h"

#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/matrix_mul/opr_impl.h"



+ 37
- 0
dnn/src/fallback/conv_bias/gi/block_helper.h

@@ -0,0 +1,37 @@
/**
* \file dnn/src/fallback/conv_bias/gi/block_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/common/utils.h"
namespace megdnn {
namespace {
// block_helper is used to calculate the oh (output height) block size
static inline int l2_block_helper(
const int nthread, const int amount, const int size_per_unit) {
//! TODO: optimize the config or dynamically configure l2_cache_size for different archs
constexpr int l2_cache_size = 256 * 1024;
const int block_per_thread = div_ceil(amount, nthread);
const int best_block =
std::min(amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
const int max_block_num = div_ceil(block_per_thread, best_block);
const int min_block_num = std::max(max_block_num - 1, 1);
const int max_block = div_ceil(block_per_thread, max_block_num);
const int min_block = div_ceil(block_per_thread, min_block_num);
const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
int block = max_loss > min_loss ? min_block : max_block;
return block;
}

} // namespace
} // namespace megdnn

// vim: syntax=cpp.doxygen

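The heuristic above is easier to follow with concrete numbers. Below is a standalone sketch of the same computation, with megdnn's div_ceil inlined so it compiles on its own; the nthread/amount/size_per_unit values in main are illustrative, not from the source.

#include <algorithm>
#include <cstdio>
#include <cstdlib>

static int div_ceil(int a, int b) {
    return (a + b - 1) / b;
}

// Same heuristic as l2_block_helper in block_helper.h: pick the largest oh
// block whose working set fits the assumed 256 KiB L2, then choose between
// the two nearest block counts by which one wastes fewer rows.
static int l2_block_helper_sketch(int nthread, int amount, int size_per_unit) {
    constexpr int l2_cache_size = 256 * 1024;
    const int block_per_thread = div_ceil(amount, nthread);
    const int best_block =
            std::min(amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
    const int max_block_num = div_ceil(block_per_thread, best_block);
    const int min_block_num = std::max(max_block_num - 1, 1);
    const int max_block = div_ceil(block_per_thread, max_block_num);
    const int min_block = div_ceil(block_per_thread, min_block_num);
    const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
    const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
    return max_loss > min_loss ? min_block : max_block;
}

int main() {
    // 4 threads, 64 output rows, 8 KiB per row: best_block = 32 fits in L2,
    // block_per_thread = 16, so each thread processes one 16-row block.
    std::printf("oh block = %d\n", l2_block_helper_sketch(4, 64, 8 * 1024));
    return 0;
}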
dnn/src/arm_common/conv_bias/fp32/algos.cpp → dnn/src/fallback/conv_bias/gi/fp32/algos.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/algos.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,23 +10,22 @@
* implied.
*/

#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/direct.h"
#include "src/arm_common/conv_bias/fp32/do_conv_stride1.h"
#include "src/arm_common/conv_bias/fp32/do_conv_stride2.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/conv_bias/img2col_helper.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/fallback/conv_bias/gi/fp32/algos.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/direct/multi_thread_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct.h"
#include "src/fallback/conv_bias/gi/fp32/do_conv_stride1.h"
#include "src/fallback/conv_bias/gi/fp32/do_conv_stride2.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/postprocess_helper.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_winograd_fp32)
MIDOUT_DECL(megdnn_fallback_winograd_fp32)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;

/* ======================= AlgoFP32WinogradF23_4x4 ======================== */

@@ -34,10 +33,10 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 0, 0) {
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 0, 0) {
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
using Strategy = winograd::winograd_2x3_4x4_f;
using Strategy = winograd::winograd_gi_2x3_4x4_f;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
@@ -62,8 +61,8 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable(
}

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF23_4x4, winograd::winograd_2x3_4x4_f,
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4);
AlgoFP32WinogradF23_4x4, winograd::winograd_gi_2x3_4x4_f,
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4);

/* ======================= AlgoFP32WinogradF63 ======================== */

@@ -71,7 +70,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) {
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 1, 0) {
using Strategy = winograd::winograd_6x3_1x1_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
@@ -95,7 +94,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable(

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f,
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT);
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT);

/* ======================= AlgoFP32WinogradF54 ======================== */

@@ -103,7 +102,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) {
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 2, 0) {
using Strategy = winograd::winograd_5x4_1x1_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
@@ -127,7 +126,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable(

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF54, winograd::winograd_5x4_1x1_f,
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT);
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT);

/* ======================= AlgoFP32WinogradF45 ======================== */

@@ -135,7 +134,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) {
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 3, 0) {
using Strategy = winograd::winograd_4x5_1x1_f;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
@@ -159,7 +158,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable(

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF45, winograd::winograd_4x5_1x1_f,
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT);
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT);

/* ======================= AlgoFP32WinogradF63_4x4 ======================== */

@@ -167,7 +166,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 4, 0) {
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 4, 0) {
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
using Strategy = winograd::winograd_6x3_4x4_f;
@@ -197,7 +196,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable(

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF63_4x4, winograd::winograd_6x3_4x4_f,
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4);
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4);

/* =================== AlgoFP32WinogradF23_4x4_NCHW44 =================== */

@@ -206,7 +205,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable(
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(
megdnn_arm_common_winograd_fp32,
megdnn_fallback_winograd_fp32,
midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) {
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
@@ -236,7 +235,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable(

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF23_4x4_NCHW44, winograd::winograd_F23_mk4_f_nchw44,
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4);
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4);

/* =================== AlgoFP32WinogradF63_4x4_NCHW44 ===================== */

@@ -245,7 +244,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable(
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(param);
MIDOUT_BEGIN(
megdnn_arm_common_winograd_fp32,
megdnn_fallback_winograd_fp32,
midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) {
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
@@ -276,7 +275,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable(

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF63_4x4_NCHW44, winograd::winograd_F63_mk4_f_nchw44,
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4);
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4);

/* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */

@@ -284,7 +283,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MIDOUT_BEGIN(
megdnn_arm_common_winograd_fp32,
megdnn_fallback_winograd_fp32,
midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) {
if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
return false;
@@ -314,14 +313,14 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable(

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(
AlgoFP32WinogradF73_4x4_NCHW44, winograd::winograd_F73_mk4_f_nchw44,
megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4);
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::MK4);

/* ===================== direct algo ===================== */
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl);
MIDOUT_DECL(megdnn_fallback_conv_bias_f32_kimpl);

bool ConvBiasImpl::AlgoF32Direct::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto SH = fm.stride[0], SW = fm.stride[1];
@@ -341,7 +340,7 @@ bool ConvBiasImpl::AlgoF32Direct::usable(
return false;
}
size_t ConvBiasImpl::AlgoF32Direct::get_workspace(const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto wbundle = fallback::MultithreadDirectConvCommon<float, float>::get_bundle(
param, large_group);
@@ -426,7 +425,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::get_kimpls(

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) {
return get_kimpls(param);
}
MIDOUT_END();
@@ -435,7 +434,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32Direct::dispatch_kerns(
/* ===================== stride-1 algo ===================== */
bool ConvBiasImpl::AlgoF32DirectStride1::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 1) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
@@ -452,7 +451,7 @@ bool ConvBiasImpl::AlgoF32DirectStride1::usable(

size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle =
fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
@@ -548,7 +547,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::get_kimpl

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 1, 2) {
return get_kimpls(param);
}
MIDOUT_END();
@@ -559,7 +558,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride1::dispatch_

bool ConvBiasImpl::AlgoF32DirectStride2::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
@@ -575,7 +574,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2::usable(
}
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 1) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle =
fallback::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
@@ -670,7 +669,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpl

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) {
MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 2, 2) {
return get_kimpls(param);
}
MIDOUT_END();

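Every change in algos.cpp above is a rename: the arm_common namespace becomes fallback, and the midout tags follow suit; no kernel logic is touched. A sketch of the wrapper pattern those tags drive, using the renamed tag from the diff (the body is a placeholder, not the real workspace computation):

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_conv_bias_f32_kimpl)

// Pattern used by usable()/get_workspace()/dispatch_kerns() in this file:
// the body runs inside a tagged region so that kernels never reached by a
// profiled workload can be stripped from trimmed builds.
size_t example_get_workspace_sketch() {
    MIDOUT_BEGIN(megdnn_fallback_conv_bias_f32_kimpl, 0, 1) {
        return 0;  // the real function sizes a workspace bundle here
    }
    MIDOUT_END();
    return 0;
}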
dnn/src/arm_common/conv_bias/fp32/algos.h → dnn/src/fallback/conv_bias/gi/fp32/algos.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/algos.h
* \file dnn/src/fallback/conv_bias/gi/fp32/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,11 +12,11 @@

#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/matrix_mul/opr_impl.h"

namespace megdnn {
namespace arm_common {
namespace fallback {
class ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase {
public:
AlgoFP32WinogradF23_4x4(
@@ -31,7 +31,7 @@ public:
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_4X4_FP32)
};

class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase {
@@ -50,7 +50,7 @@ public:
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_FP32)
};

class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase {
@@ -67,7 +67,7 @@ public:
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_FP32)
};

class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase {
@@ -86,7 +86,7 @@ public:
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F54_FP32)
};

class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase {
@@ -105,7 +105,7 @@ public:
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F45_FP32)
};

//===================== NCHW44 Winograd Support =====================//
@@ -124,7 +124,7 @@ public:
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32)
};

class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase {
@@ -142,7 +142,7 @@ public:
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32)
};

class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase {
@@ -160,7 +160,7 @@ public:
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32)
};
// ================================================================= //

@@ -180,7 +180,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_FP32)
};

class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
@@ -199,7 +199,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_STRD1_FP32)
};

class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
@@ -218,7 +218,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_STRD2_FP32)
};

class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
@@ -238,7 +238,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_NCHW44_FP32)
};

class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase {
@@ -258,7 +258,7 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_DIRECT_NCHW_NCHW44_FP32)
};

class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase {
@@ -277,10 +277,10 @@ public:
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32)
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_CHWNWISE_NCHW44_F32)
};

} // namespace arm_common
} // namespace fallback
} // namespace megdnn

#undef MEGDNN_WINOGRAD_ALGO_FUN_DECLARE

dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp → dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,29 +10,22 @@
* implied.
*/

#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

#pragma GCC diagnostic ignored "-Wunused-parameter"

using namespace megdnn;
using namespace arm_common;
using namespace fallback;

namespace {

#if defined(__ARM_FEATURE_FMA)
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
#else
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
#endif

template <int shift>
static inline void shift_src(float32x4_t rsrc[3][4]) {
float32x4_t t[4];
static inline void shift_src(GI_FLOAT32_t rsrc[3][4]) {
GI_FLOAT32_t t[4];

t[0] = rsrc[0][(shift + 0) % 4];
t[1] = rsrc[0][(shift + 1) % 4];
@@ -63,9 +56,9 @@ static inline void shift_src(float32x4_t rsrc[3][4]) {
}

template <BiasMode bias_mode>
static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) {
static inline GI_FLOAT32_t load_bias(const float* bias, const GI_FLOAT32_t& init) {
if (bias_mode == BiasMode::BIAS) {
return vld1q_f32(bias);
return GiLoadFloat32(bias);
} else {
return init;
}
@@ -76,35 +69,35 @@ struct compute_element {
template <typename Op>
static inline void call(
const float*& src0, const float*& src1, const float*& src2, float*& dst,
const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4],
float32x4_t rfilter[3][3], const Op& op) {
const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4],
GI_FLOAT32_t rfilter[3][3], const Op& op) {
#define RSRC(i, j) rsrc[i][((j) + bw) % 4]
float32x4_t rdst = load_bias<bias_mode>(bias, init);
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init);
if (has_top) {
RSRC(0, 3) = vld1q_f32(src0 + 8);
RSRC(0, 3) = GiLoadFloat32(src0 + 8);
}
{ RSRC(1, 3) = vld1q_f32(src1 + 8); }
{ RSRC(1, 3) = GiLoadFloat32(src1 + 8); }
if (has_bottom) {
RSRC(2, 3) = vld1q_f32(src2 + 8);
RSRC(2, 3) = GiLoadFloat32(src2 + 8);
}

if (has_top) {
rdst = Vfmaq_f32(rdst, RSRC(0, 0), rfilter[0][0]);
rdst = Vfmaq_f32(rdst, RSRC(0, 1), rfilter[0][1]);
rdst = Vfmaq_f32(rdst, RSRC(0, 2), rfilter[0][2]);
rdst = GiMlaqFloat32(rdst, RSRC(0, 0), rfilter[0][0]);
rdst = GiMlaqFloat32(rdst, RSRC(0, 1), rfilter[0][1]);
rdst = GiMlaqFloat32(rdst, RSRC(0, 2), rfilter[0][2]);
}
{
rdst = Vfmaq_f32(rdst, RSRC(1, 0), rfilter[1][0]);
rdst = Vfmaq_f32(rdst, RSRC(1, 1), rfilter[1][1]);
rdst = Vfmaq_f32(rdst, RSRC(1, 2), rfilter[1][2]);
rdst = GiMlaqFloat32(rdst, RSRC(1, 0), rfilter[1][0]);
rdst = GiMlaqFloat32(rdst, RSRC(1, 1), rfilter[1][1]);
rdst = GiMlaqFloat32(rdst, RSRC(1, 2), rfilter[1][2]);
}
if (has_bottom) {
rdst = Vfmaq_f32(rdst, RSRC(2, 0), rfilter[2][0]);
rdst = Vfmaq_f32(rdst, RSRC(2, 1), rfilter[2][1]);
rdst = Vfmaq_f32(rdst, RSRC(2, 2), rfilter[2][2]);
rdst = GiMlaqFloat32(rdst, RSRC(2, 0), rfilter[2][0]);
rdst = GiMlaqFloat32(rdst, RSRC(2, 1), rfilter[2][1]);
rdst = GiMlaqFloat32(rdst, RSRC(2, 2), rfilter[2][2]);
}

vst1q_f32(dst, op(rdst));
GiStoreFloat32(dst, op(rdst));

if (has_top) {
src0 += 4;
@@ -131,27 +124,27 @@ template <bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element_right {
template <typename Op>
static inline void call(
float*& dst, const float*& bias, const float32x4_t& init,
float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) {
float32x4_t rdst = load_bias<bias_mode>(bias, init);
float*& dst, const float*& bias, const GI_FLOAT32_t& init,
GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) {
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init);

if (has_top) {
rdst = Vfmaq_f32(rdst, rsrc[0][0], rfilter[0][0]);
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][1]);
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][2]);
rdst = GiMlaqFloat32(rdst, rsrc[0][0], rfilter[0][0]);
rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][1]);
rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][2]);
}
{
rdst = Vfmaq_f32(rdst, rsrc[1][0], rfilter[1][0]);
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][1]);
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][2]);
rdst = GiMlaqFloat32(rdst, rsrc[1][0], rfilter[1][0]);
rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][1]);
rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][2]);
}
if (has_bottom) {
rdst = Vfmaq_f32(rdst, rsrc[2][0], rfilter[2][0]);
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][1]);
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][2]);
rdst = GiMlaqFloat32(rdst, rsrc[2][0], rfilter[2][0]);
rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][1]);
rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][2]);
}

vst1q_f32(dst, op(rdst));
GiStoreFloat32(dst, op(rdst));

dst += 4;
bias += 4;
@@ -162,24 +155,24 @@ template <bool has_top, bool has_bottom, BiasMode bias_mode>
struct compute_element_right_pad {
template <typename Op>
static inline void call(
float*& dst, const float*& bias, const float32x4_t& init,
float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) {
float32x4_t rdst = load_bias<bias_mode>(bias, init);
float*& dst, const float*& bias, const GI_FLOAT32_t& init,
GI_FLOAT32_t rsrc[3][4], GI_FLOAT32_t rfilter[3][3], const Op& op) {
GI_FLOAT32_t rdst = load_bias<bias_mode>(bias, init);

if (has_top) {
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][0]);
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][1]);
rdst = GiMlaqFloat32(rdst, rsrc[0][1], rfilter[0][0]);
rdst = GiMlaqFloat32(rdst, rsrc[0][2], rfilter[0][1]);
}
{
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][0]);
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][1]);
rdst = GiMlaqFloat32(rdst, rsrc[1][1], rfilter[1][0]);
rdst = GiMlaqFloat32(rdst, rsrc[1][2], rfilter[1][1]);
}
if (has_bottom) {
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][0]);
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][1]);
rdst = GiMlaqFloat32(rdst, rsrc[2][1], rfilter[2][0]);
rdst = GiMlaqFloat32(rdst, rsrc[2][2], rfilter[2][1]);
}

vst1q_f32(dst, op(rdst));
GiStoreFloat32(dst, op(rdst));
dst += 4;
bias += 4;
}
@@ -190,22 +183,22 @@ struct compute_row {
template <typename Op>
static inline void call(
const float*& src0, const float*& src1, const float*& src2, float*& dst,
const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4],
float32x4_t rfilter[3][3], int W, const Op& op) {
const float*& bias, const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[3][4],
GI_FLOAT32_t rfilter[3][3], int W, const Op& op) {
if (has_top) {
rsrc[0][0] = vdupq_n_f32(0);
rsrc[0][1] = vld1q_f32(src0 + 0);
rsrc[0][2] = vld1q_f32(src0 + 4);
rsrc[0][0] = GiZeroFloat32();
rsrc[0][1] = GiLoadFloat32(src0 + 0);
rsrc[0][2] = GiLoadFloat32(src0 + 4);
}
{
rsrc[1][0] = vdupq_n_f32(0);
rsrc[1][1] = vld1q_f32(src1 + 0);
rsrc[1][2] = vld1q_f32(src1 + 4);
rsrc[1][0] = GiZeroFloat32();
rsrc[1][1] = GiLoadFloat32(src1 + 0);
rsrc[1][2] = GiLoadFloat32(src1 + 4);
}
if (has_bottom) {
rsrc[2][0] = vdupq_n_f32(0);
rsrc[2][1] = vld1q_f32(src2 + 0);
rsrc[2][2] = vld1q_f32(src2 + 4);
rsrc[2][0] = GiZeroFloat32();
rsrc[2][1] = GiLoadFloat32(src2 + 0);
rsrc[2][2] = GiLoadFloat32(src2 + 4);
}

int w = 0;
@@ -256,27 +249,27 @@ void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1(
int W) {
Op op;

float32x4_t init = vdupq_n_f32(0);
GI_FLOAT32_t init = GiZeroFloat32();
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
init = GiLoadFloat32(bias);
}

const float* src0 = src - W * 4;
const float* src1 = src;
const float* src2 = src + W * 4;

float32x4_t rfilter[3][3];
rfilter[0][0] = vld1q_f32(filter + 0);
rfilter[0][1] = vld1q_f32(filter + 4);
rfilter[0][2] = vld1q_f32(filter + 8);
rfilter[1][0] = vld1q_f32(filter + 12);
rfilter[1][1] = vld1q_f32(filter + 16);
rfilter[1][2] = vld1q_f32(filter + 20);
rfilter[2][0] = vld1q_f32(filter + 24);
rfilter[2][1] = vld1q_f32(filter + 28);
rfilter[2][2] = vld1q_f32(filter + 32);
float32x4_t rsrc[3][4];
GI_FLOAT32_t rfilter[3][3];
rfilter[0][0] = GiLoadFloat32(filter + 0);
rfilter[0][1] = GiLoadFloat32(filter + 4);
rfilter[0][2] = GiLoadFloat32(filter + 8);
rfilter[1][0] = GiLoadFloat32(filter + 12);
rfilter[1][1] = GiLoadFloat32(filter + 16);
rfilter[1][2] = GiLoadFloat32(filter + 20);
rfilter[2][0] = GiLoadFloat32(filter + 24);
rfilter[2][1] = GiLoadFloat32(filter + 28);
rfilter[2][2] = GiLoadFloat32(filter + 32);
GI_FLOAT32_t rsrc[3][4];

compute_row<false, true, bias_mode>::call(
src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op);

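The kernel edits above are a mechanical substitution from NEON intrinsics to the GI (generic intrinsic) layer: float32x4_t becomes GI_FLOAT32_t, vld1q_f32/vst1q_f32 become GiLoadFloat32/GiStoreFloat32, vdupq_n_f32(0) becomes GiZeroFloat32, and the Vfmaq_f32 macro collapses into GiMlaqFloat32. The scalar model below, runnable anywhere, makes the per-lane semantics of those four primitives explicit; the real GI layer dispatches to NEON or SSE instead:

#include <array>
#include <cstdio>

using GI_FLOAT32_t = std::array<float, 4>;  // models a 4 x f32 vector lane

static GI_FLOAT32_t GiZeroFloat32() {
    return {0.f, 0.f, 0.f, 0.f};
}
static GI_FLOAT32_t GiLoadFloat32(const float* p) {
    return {p[0], p[1], p[2], p[3]};
}
static void GiStoreFloat32(float* p, const GI_FLOAT32_t& v) {
    for (int i = 0; i < 4; ++i)
        p[i] = v[i];
}
// d + n * m per lane: what Vfmaq_f32 (vfmaq_f32/vmlaq_f32) computed before.
static GI_FLOAT32_t GiMlaqFloat32(
        GI_FLOAT32_t d, const GI_FLOAT32_t& n, const GI_FLOAT32_t& m) {
    for (int i = 0; i < 4; ++i)
        d[i] += n[i] * m[i];
    return d;
}

int main() {
    // One filter tap of a channel-wise NCHW44 kernel: acc += src * weight.
    float src[4] = {1.f, 2.f, 3.f, 4.f};
    float w[4] = {0.5f, 0.5f, 0.5f, 0.5f};
    float out[4];
    GI_FLOAT32_t acc = GiZeroFloat32();
    acc = GiMlaqFloat32(acc, GiLoadFloat32(src), GiLoadFloat32(w));
    GiStoreFloat32(out, acc);
    std::printf("%.1f %.1f %.1f %.1f\n", out[0], out[1], out[2], out[3]);
    return 0;
}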
dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h → dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,11 +12,11 @@

#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace channel_wise_nchw44_float {

template <BiasMode bias_mode, typename Op>
@@ -25,7 +25,7 @@ void do_conv_kern_3x3_stride1_padding1(
int W);

} // namespace channel_wise_nchw44_float
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp → dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,29 +10,22 @@
* implied.
*/

#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

#pragma GCC diagnostic ignored "-Wunused-parameter"

using namespace megdnn;
using namespace arm_common;
using namespace fallback;

namespace {

#if defined(__ARM_FEATURE_FMA)
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
#else
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
#endif

template <int shift>
static inline void shift_src(float32x4_t rsrc[6]) {
float32x4_t t[6];
static inline void shift_src(GI_FLOAT32_t rsrc[6]) {
GI_FLOAT32_t t[6];

t[0] = rsrc[(shift + 0) % 6];
t[1] = rsrc[(shift + 1) % 6];
@@ -48,18 +41,18 @@ static inline void shift_src(float32x4_t rsrc[6]) {
rsrc[5] = t[5];
}

static inline void load_filter(const float* filter, float32x4_t rfilter[5]) {
rfilter[0] = vld1q_f32(filter + 0);
rfilter[1] = vld1q_f32(filter + 4);
rfilter[2] = vld1q_f32(filter + 8);
rfilter[3] = vld1q_f32(filter + 12);
rfilter[4] = vld1q_f32(filter + 16);
static inline void load_filter(const float* filter, GI_FLOAT32_t rfilter[5]) {
rfilter[0] = GiLoadFloat32(filter + 0);
rfilter[1] = GiLoadFloat32(filter + 4);
rfilter[2] = GiLoadFloat32(filter + 8);
rfilter[3] = GiLoadFloat32(filter + 12);
rfilter[4] = GiLoadFloat32(filter + 16);
}

template <BiasMode bias_mode>
static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) {
static inline GI_FLOAT32_t load_bias(const float* bias, const GI_FLOAT32_t& init) {
if (bias_mode == BiasMode::BIAS) {
return vld1q_f32(bias);
return GiLoadFloat32(bias);
} else {
return init;
}
@@ -69,27 +62,28 @@ template <int BW, int bw, BiasMode bias_mode, bool need_load_bias, bool need_do_
struct compute_element {
template <typename Op>
static inline void call(
const float*& src, float*& dst, const float*& bias, const float32x4_t& init,
float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) {
const float*& src, float*& dst, const float*& bias,
const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5],
const Op& op) {
#define RSRC(i) rsrc[((i) + bw) % 6]
float32x4_t rdst;
GI_FLOAT32_t rdst;
if (need_load_bias) {
rdst = load_bias<bias_mode>(bias, init);
} else {
rdst = vld1q_f32(dst);
rdst = GiLoadFloat32(dst);
}
RSRC(5) = vld1q_f32(src + 12);
RSRC(5) = GiLoadFloat32(src + 12);

rdst = Vfmaq_f32(rdst, RSRC(0), rfilter[0]);
rdst = Vfmaq_f32(rdst, RSRC(1), rfilter[1]);
rdst = Vfmaq_f32(rdst, RSRC(2), rfilter[2]);
rdst = Vfmaq_f32(rdst, RSRC(3), rfilter[3]);
rdst = Vfmaq_f32(rdst, RSRC(4), rfilter[4]);
rdst = GiMlaqFloat32(rdst, RSRC(0), rfilter[0]);
rdst = GiMlaqFloat32(rdst, RSRC(1), rfilter[1]);
rdst = GiMlaqFloat32(rdst, RSRC(2), rfilter[2]);
rdst = GiMlaqFloat32(rdst, RSRC(3), rfilter[3]);
rdst = GiMlaqFloat32(rdst, RSRC(4), rfilter[4]);

if (need_do_op) {
rdst = op(rdst);
}
vst1q_f32(dst, rdst);
GiStoreFloat32(dst, rdst);

src += 4;
dst += 4;
@@ -110,29 +104,29 @@ template <size_t padding, BiasMode bias_mode, bool need_load_bias, bool need_do_
struct compute_element_right {
template <typename Op>
static inline void call(
float*& dst, const float*& bias, const float32x4_t& init,
float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) {
float32x4_t rdst;
float*& dst, const float*& bias, const GI_FLOAT32_t& init,
GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], const Op& op) {
GI_FLOAT32_t rdst;
if (need_load_bias) {
rdst = load_bias<bias_mode>(bias, init);
} else {
rdst = vld1q_f32(dst);
rdst = GiLoadFloat32(dst);
}

rdst = Vfmaq_f32(rdst, rsrc[0 + padding], rfilter[0]);
rdst = Vfmaq_f32(rdst, rsrc[1 + padding], rfilter[1]);
rdst = Vfmaq_f32(rdst, rsrc[2 + padding], rfilter[2]);
rdst = GiMlaqFloat32(rdst, rsrc[0 + padding], rfilter[0]);
rdst = GiMlaqFloat32(rdst, rsrc[1 + padding], rfilter[1]);
rdst = GiMlaqFloat32(rdst, rsrc[2 + padding], rfilter[2]);
if (padding < 2) {
rdst = Vfmaq_f32(rdst, rsrc[3 + padding], rfilter[3]);
rdst = GiMlaqFloat32(rdst, rsrc[3 + padding], rfilter[3]);
}
if (padding < 1) {
rdst = Vfmaq_f32(rdst, rsrc[4 + padding], rfilter[4]);
rdst = GiMlaqFloat32(rdst, rsrc[4 + padding], rfilter[4]);
}

if (need_do_op) {
rdst = op(rdst);
}
vst1q_f32(dst, rdst);
GiStoreFloat32(dst, rdst);

dst += 4;
bias += 4;
@@ -143,13 +137,13 @@ template <BiasMode bias_mode, bool need_load_bias, bool need_do_op>
struct compute_row_src_1x5 {
template <typename Op>
static inline void call(
const float* src, float* dst, const float* bias, const float32x4_t& init,
float32x4_t rsrc[6], float32x4_t rfilter[5], int W, const Op& op) {
rsrc[0] = vdupq_n_f32(0);
rsrc[1] = vdupq_n_f32(0);
rsrc[2] = vld1q_f32(src + 0);
rsrc[3] = vld1q_f32(src + 4);
rsrc[4] = vld1q_f32(src + 8);
const float* src, float* dst, const float* bias, const GI_FLOAT32_t& init,
GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5], int W, const Op& op) {
rsrc[0] = GiZeroFloat32();
rsrc[1] = GiZeroFloat32();
rsrc[2] = GiLoadFloat32(src + 0);
rsrc[3] = GiLoadFloat32(src + 4);
rsrc[4] = GiLoadFloat32(src + 8);

int w = 0;

@@ -190,8 +184,8 @@ struct compute_row {
template <typename Op>
static inline void call(
const float*& src, float*& dst, const float* filter, const float*& bias,
const float32x4_t& init, float32x4_t rsrc[6], float32x4_t rfilter[5], int W,
const Op& op) {
const GI_FLOAT32_t& init, GI_FLOAT32_t rsrc[6], GI_FLOAT32_t rfilter[5],
int W, const Op& op) {
if (top_padding < 1) {
load_filter(filter + 0, rfilter);
compute_row_src_1x5<bias_mode, top_padding == 0, false>::call(
@@ -235,13 +229,13 @@ void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2(
int W) {
Op op;

float32x4_t init = vdupq_n_f32(0);
GI_FLOAT32_t init = GiZeroFloat32();
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
init = GiLoadFloat32(bias);
}

float32x4_t rsrc[6];
float32x4_t rfilter[5];
GI_FLOAT32_t rsrc[6];
GI_FLOAT32_t rfilter[5];

compute_row<2, 0, bias_mode>::call(
src, dst, filter, bias, init, rsrc, rfilter, W, op);

dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h → dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,11 +12,11 @@

#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace channel_wise_nchw44_float {

template <BiasMode bias_mode, typename Op>
@@ -25,7 +25,7 @@ void do_conv_kern_5x5_stride1_padding2(
int W);

} // namespace channel_wise_nchw44_float
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp → dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,14 +10,14 @@
* implied.
*/

#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/fallback/conv_bias/gi/fp32/algos.h"
#include "src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

#include "midout.h"

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
using conv_fun = std::function<void(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,

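The conv_fun alias visible at the end of the hunk above is the file's dispatch mechanism: each templated kernel instantiation is erased behind one std::function signature and selected at runtime. A minimal standalone sketch of that pattern follows; the stub kernel, the selection map, and the trailing PH/PW parameters (elided by the diff context) are assumptions for illustration:

#include <cstddef>
#include <functional>
#include <map>

// Mirrors the shape of the conv_fun alias in channel_wise_nchw44_algo.cpp;
// the parameter tail past OW is assumed here.
using conv_fun = std::function<void(
        const float* src, const float* filter, const float* bias, float* dst,
        size_t IH, size_t IW, size_t OH, size_t OW, size_t PH, size_t PW)>;

// Hypothetical stand-in for the real kernels in channel_wise_nchw44_kern.cpp.
template <int filter_size>
void conv_kern_stub(
        const float*, const float*, const float*, float*, size_t, size_t,
        size_t, size_t, size_t, size_t) {}

// Hypothetical selection step: map the runtime filter size to the matching
// template instantiation.
conv_fun select_kernel(int filter_size) {
    static const std::map<int, conv_fun> table{
            {2, conv_kern_stub<2>},
            {3, conv_kern_stub<3>},
            {5, conv_kern_stub<5>},
    };
    return table.at(filter_size);
}

int main() {
    conv_fun kern = select_kernel(3);
    kern(nullptr, nullptr, nullptr, nullptr, 0, 0, 0, 0, 0, 0);
    return 0;
}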
dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp → dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/direct.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,29 +10,28 @@
* implied.
*/

#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

using namespace megdnn;
using namespace arm_common;
using namespace fallback;

namespace {

template <int size>
void load_vec(float32x4_t* dst, const float* src);
void load_vec(GI_FLOAT32_t* dst, const float* src);

#define cb(i) dst[i] = vld1q_f32(src + i * 4);
#define LOAD_MACRO(n) \
template <> \
inline void load_vec<n>(float32x4_t * dst, const float* src) { \
UNROLL_CALL_NOWRAPPER(n, cb); \
#define cb(i) dst[i] = GiLoadFloat32(src + i * 4);
#define LOAD_MACRO(n) \
template <> \
inline void load_vec<n>(GI_FLOAT32_t * dst, const float* src) { \
UNROLL_CALL_NOWRAPPER(n, cb); \
}
LOAD_MACRO(2);
LOAD_MACRO(3);
@@ -46,14 +45,14 @@ LOAD_MACRO(9);
#undef LOAD_MACRO

template <int size>
void compute_vec(float32x4_t& dst, float32x4_t* src, float32x4_t* filter);
void compute_vec(GI_FLOAT32_t& dst, GI_FLOAT32_t* src, GI_FLOAT32_t* filter);

#define cb(i) dst = vmlaq_f32(dst, src[i], filter[i]);
#define COMPUTE_MACRO(n) \
template <> \
inline void compute_vec<n>( \
float32x4_t & dst, float32x4_t * src, float32x4_t * filter) { \
UNROLL_CALL_NOWRAPPER(n, cb); \
#define cb(i) dst = GiMlaqFloat32(dst, src[i], filter[i]);
#define COMPUTE_MACRO(n) \
template <> \
inline void compute_vec<n>( \
GI_FLOAT32_t & dst, GI_FLOAT32_t * src, GI_FLOAT32_t * filter) { \
UNROLL_CALL_NOWRAPPER(n, cb); \
}
COMPUTE_MACRO(2);
COMPUTE_MACRO(3);
@@ -64,20 +63,20 @@ COMPUTE_MACRO(5);
template <BiasMode bias_mode, int size>
struct load_bias_vec;

#define cb_bias(i) dst[i] = vld1q_f32((bptr) + i * 4);
#define cb_bias(i) dst[i] = GiLoadFloat32((bptr) + i * 4);
#define cb_init(i) dst[i] = init;

#define INIT_BIAS_MACRO(n) \
template <BiasMode bias_mode> \
struct load_bias_vec<bias_mode, n> { \
static void impl( \
float32x4_t* dst, const float32x4_t& init, const float* bptr) { \
if (bias_mode == BiasMode::BIAS) { \
UNROLL_CALL_NOWRAPPER(n, cb_bias); \
} else { \
UNROLL_CALL_NOWRAPPER(n, cb_init); \
} \
} \
#define INIT_BIAS_MACRO(n) \
template <BiasMode bias_mode> \
struct load_bias_vec<bias_mode, n> { \
static void impl( \
GI_FLOAT32_t* dst, const GI_FLOAT32_t& init, const float* bptr) { \
if (bias_mode == BiasMode::BIAS) { \
UNROLL_CALL_NOWRAPPER(n, cb_bias); \
} else { \
UNROLL_CALL_NOWRAPPER(n, cb_init); \
} \
} \
};

INIT_BIAS_MACRO(1);
@@ -91,7 +90,7 @@ INIT_BIAS_MACRO(4);
#define COMPUTE_PADDING_KERNEL() \
do { \
int iw = ow * stride - PW; \
float32x4_t result; \
GI_FLOAT32_t result; \
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4 + ow * 4); \
for (int kh = 0; kh < fh; kh++) { \
if (kh + ih < 0 || kh + ih >= static_cast<int>(IH)) \
@@ -100,7 +99,8 @@ INIT_BIAS_MACRO(4);
if (kw + iw < 0 || kw + iw >= static_cast<int>(IW)) \
continue; \
const float* sptr = src + (kh + ih) * IW * 4 + (kw + iw) * 4; \
result = vmlaq_f32(result, kernel[kh * fh + kw], vld1q_f32(sptr)); \
result = GiMlaqFloat32( \
result, kernel[kh * fh + kw], GiLoadFloat32(sptr)); \
} \
} \
float* output = dst + oh * OW * 4 + ow * 4; \
@@ -113,7 +113,7 @@ struct PaddingCompute {
const float* src, const float* bias, float* dst, const int fh,
const int stride, const size_t IH, const size_t IW, const size_t OH,
const size_t OW, const size_t PH, const size_t PW,
const float32x4_t* kernel, const float32x4_t& init) {
const GI_FLOAT32_t* kernel, const GI_FLOAT32_t& init) {
size_t oh_start = (PH + stride - 1) / stride;
size_t ow_start = (PW + stride - 1) / stride;
size_t oh_end = (IH + PH - fh) / stride + 1;
@@ -148,7 +148,7 @@ struct PaddingComputeK3P1 {
static void compute(
const float* src, const float* bias, float* dst, const size_t stride,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const float32x4_t* kernel, const float32x4_t& init) {
const GI_FLOAT32_t* kernel, const GI_FLOAT32_t& init) {
constexpr size_t PH = 1, PW = 1, FH = 3;
size_t oh_start = (PH + stride - 1) / stride;
size_t ow_start = (PW + stride - 1) / stride;
@@ -162,39 +162,39 @@ struct PaddingComputeK3P1 {
Op op;
// line one left
{
float32x4_t result;
GI_FLOAT32_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init, bias);
result = vmlaq_f32(result, kernel[4], vld1q_f32(src));
result = vmlaq_f32(result, kernel[5], vld1q_f32(src + 4));
result = vmlaq_f32(result, kernel[7], vld1q_f32(src + IW * 4));
result = vmlaq_f32(result, kernel[8], vld1q_f32(src + IW * 4 + 4));
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(src));
result = GiMlaqFloat32(result, kernel[5], GiLoadFloat32(src + 4));
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(src + IW * 4));
result = GiMlaqFloat32(result, kernel[8], GiLoadFloat32(src + IW * 4 + 4));
float* output = dst;
op(result, output);
}
// line one mid
for (size_t ow = ow_start; ow < ow_end; ow++) {
int iw = ow * stride - PW;
float32x4_t result;
GI_FLOAT32_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + ow * 4);
const float* sptr = src + iw * 4;
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + 8));
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4));
result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + IW * 4 + 8));
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr));
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + 4));
result = GiMlaqFloat32(result, kernel[5], GiLoadFloat32(sptr + 8));
result = GiMlaqFloat32(result, kernel[6], GiLoadFloat32(sptr + IW * 4));
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(sptr + IW * 4 + 4));
result = GiMlaqFloat32(result, kernel[8], GiLoadFloat32(sptr + IW * 4 + 8));
float* output = dst + ow * 4;
op(result, output);
}
// line one right
if (OW != ow_end) {
float32x4_t result;
GI_FLOAT32_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + (OW - 1) * 4);
const float* sptr = src + (ow_end * stride - PW) * 4;
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + IW * 4 + 4));
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr));
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + 4));
result = GiMlaqFloat32(result, kernel[6], GiLoadFloat32(sptr + IW * 4));
result = GiMlaqFloat32(result, kernel[7], GiLoadFloat32(sptr + IW * 4 + 4));
float* output = dst + ow_end * 4;
op(result, output);
}
@@ -203,30 +203,36 @@ struct PaddingComputeK3P1 {
int ih = oh * stride - PH;
// left
{
float32x4_t result;
GI_FLOAT32_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4);
const float* sptr = src + ih * IW * 4;
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4));
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4));
result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + 2 * IW * 4 + 4));
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr));
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 4));
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + IW * 4));
result = GiMlaqFloat32(
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 4));
result = GiMlaqFloat32(
result, kernel[7], GiLoadFloat32(sptr + 2 * IW * 4));
result = GiMlaqFloat32(
result, kernel[8], GiLoadFloat32(sptr + 2 * IW * 4 + 4));
float* output = dst + oh * OW * 4;
op(result, output);
}
// right
if (OW != ow_end) {
float32x4_t result;
GI_FLOAT32_t result;
load_bias_vec<bias_mode, 1>::impl(
&result, init, bias + oh * OW * 4 + (OW - 1) * 4);
const float* sptr = src + ih * IW * 4 + (ow_end * stride - PW) * 4;
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4));
result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + 2 * IW * 4));
result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4 + 4));
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr));
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4));
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4));
result = GiMlaqFloat32(
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4));
result = GiMlaqFloat32(
result, kernel[6], GiLoadFloat32(sptr + 2 * IW * 4));
result = GiMlaqFloat32(
result, kernel[7], GiLoadFloat32(sptr + 2 * IW * 4 + 4));
float* output = dst + oh * OW * 4 + ow_end * 4;
op(result, output);
}
@@ -235,43 +241,47 @@ struct PaddingComputeK3P1 {
if (OH != oh_end) {
size_t oh = OH - 1;
{
float32x4_t result;
GI_FLOAT32_t result;
load_bias_vec<bias_mode, 1>::impl(&result, init, bias + oh * OW * 4);
const float* sptr = src + (oh_end * stride - PH) * IW * 4;
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4));
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr));
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 4));
result = GiMlaqFloat32(result, kernel[4], GiLoadFloat32(sptr + IW * 4));
result = GiMlaqFloat32(
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 4));
float* output = dst + oh_end * OW * 4;
op(result, output);
}
// last line mid
for (size_t ow = ow_start; ow < ow_end; ow++) {
int iw = ow * stride - PW;
float32x4_t result;
GI_FLOAT32_t result;
load_bias_vec<bias_mode, 1>::impl(
&result, init, bias + oh * OW * 4 + ow * 4);
const float* sptr = src + (oh_end * stride - PH) * IW * 4 + iw * 4;
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 8));
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4));
result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 8));
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr));
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4));
result = GiMlaqFloat32(result, kernel[2], GiLoadFloat32(sptr + 8));
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4));
result = GiMlaqFloat32(
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4));
result = GiMlaqFloat32(
result, kernel[5], GiLoadFloat32(sptr + IW * 4 + 8));
float* output = dst + oh_end * OW * 4 + ow * 4;
op(result, output);
}
// last line right
if (OW != ow_end) {
float32x4_t result;
GI_FLOAT32_t result;
load_bias_vec<bias_mode, 1>::impl(
&result, init, bias + oh * OW * 4 + (OW - 1) * 4);
const float* sptr = src + (oh_end * stride - PH) * IW * 4 +
(ow_end * stride - PW) * 4;
result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr));
result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4));
result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4));
result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4));
result = GiMlaqFloat32(result, kernel[0], GiLoadFloat32(sptr));
result = GiMlaqFloat32(result, kernel[1], GiLoadFloat32(sptr + 4));
result = GiMlaqFloat32(result, kernel[3], GiLoadFloat32(sptr + IW * 4));
result = GiMlaqFloat32(
result, kernel[4], GiLoadFloat32(sptr + IW * 4 + 4));
float* output = dst + oh_end * OW * 4 + ow_end * 4;
op(result, output);
}
@@ -286,12 +296,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
float32x4_t kernel[4];
GI_FLOAT32_t kernel[4];
load_vec<4>(kernel, filter);
Op op;
float32x4_t init = vdupq_n_f32(0.f);
GI_FLOAT32_t init = GiZeroFloat32();
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
init = GiLoadFloat32(bias);
}
size_t oh_start = PH;
size_t ow_start = PW;
@@ -315,12 +325,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2(
size_t iw = ow - ow_start;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][4];
GI_FLOAT32_t dst_v[2][4];
load_bias_vec<bias_mode, 4>::impl(
dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 4>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[3][5];
GI_FLOAT32_t src_v[3][5];
load_vec<5>(src_v[0], input);
COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]);
load_vec<5>(src_v[1], input + IW * 4);
@@ -338,12 +348,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2(
size_t iw = ow - ow_start;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
GI_FLOAT32_t dst_v[2];
load_bias_vec<bias_mode, 1>::impl(
&dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[3][2];
GI_FLOAT32_t src_v[3][2];
load_vec<2>(src_v[0], input);
compute_vec<2>(dst_v[0], &src_v[0][0], &kernel[0]);
load_vec<2>(src_v[1], input + IW * 4);
@@ -363,10 +373,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2(
size_t iw = ow - ow_start;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[1][4];
GI_FLOAT32_t dst_v[1][4];
load_bias_vec<bias_mode, 4>::impl(
dst_v[0], init, bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][5];
GI_FLOAT32_t src_v[2][5];
load_vec<5>(src_v[0], input);
COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]);
load_vec<5>(src_v[1], input + IW * 4);
@@ -379,10 +389,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2(
size_t iw = ow - ow_start;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
GI_FLOAT32_t dst_v;
load_bias_vec<bias_mode, 1>::impl(
&dst_v, init, bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][2];
GI_FLOAT32_t src_v[2][2];
load_vec<2>(src_v[0], input);
compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]);
load_vec<2>(src_v[1], input + IW * 4);
@@ -405,12 +415,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
return;
}

float32x4_t kernel[9];
GI_FLOAT32_t kernel[9];
load_vec<9>(kernel, filter);
Op op;
float32x4_t init = vdupq_n_f32(0.f);
GI_FLOAT32_t init = GiZeroFloat32();
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
init = GiLoadFloat32(bias);
}
size_t oh_start = PH;
size_t ow_start = PW;
@@ -428,12 +438,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][4];
GI_FLOAT32_t dst_v[2][4];
load_bias_vec<bias_mode, 4>::impl(
dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 4>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[2][6];
GI_FLOAT32_t src_v[2][6];
load_vec<6>(src_v[0], input);
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[0][1], &src_v[0][1], &kernel[0]);
@@ -472,12 +482,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
GI_FLOAT32_t dst_v[2];
load_bias_vec<bias_mode, 1>::impl(
&dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[2][3];
GI_FLOAT32_t src_v[2][3];
load_vec<3>(src_v[0], input);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + IW * 4);
@@ -500,10 +510,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[4];
GI_FLOAT32_t dst_v[4];
load_bias_vec<bias_mode, 4>::impl(
&dst_v[0], init, bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][6];
GI_FLOAT32_t src_v[2][6];
load_vec<6>(src_v[0], input);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[1], &src_v[0][1], &kernel[0]);
@@ -526,10 +536,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
GI_FLOAT32_t dst_v;
load_bias_vec<bias_mode, 1>::impl(
&dst_v, init, bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[3][3];
GI_FLOAT32_t src_v[3][3];
load_vec<3>(src_v[0], input);
compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + IW * 4);
@@ -553,9 +563,9 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
}

Op op;
float32x4_t init = vdupq_n_f32(0.f);
GI_FLOAT32_t init = GiZeroFloat32();
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
init = GiLoadFloat32(bias);
}
size_t oh_start = PH;
size_t ow_start = PW;
@@ -564,7 +574,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(
src, bias, dst, 5, 1, IH, IW, OH, OW, PH, PW,
reinterpret_cast<const float32x4_t*>(filter), init);
reinterpret_cast<const GI_FLOAT32_t*>(filter), init);
}
size_t oh = oh_start;
for (; oh + 1 < oh_end; oh += 2) {
@@ -574,13 +584,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][2];
GI_FLOAT32_t dst_v[2][2];
load_bias_vec<bias_mode, 2>::impl(
dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 2>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][6];
GI_FLOAT32_t kernel[2][5];
GI_FLOAT32_t src_v[2][6];
#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \
load_vec<5>(kernel0, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
@@ -613,13 +623,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][1];
GI_FLOAT32_t dst_v[2][1];
load_bias_vec<bias_mode, 1>::impl(
dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][5];
GI_FLOAT32_t kernel[2][5];
GI_FLOAT32_t src_v[2][5];
#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \
load_vec<5>(kernel0, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
@@ -652,11 +662,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[1][2];
GI_FLOAT32_t dst_v[1][2];
load_bias_vec<bias_mode, 2>::impl(
dst_v[0], init, bias + oh * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][6];
GI_FLOAT32_t kernel[2][5];
GI_FLOAT32_t src_v[2][6];
#define COMPUTE_5X5_2(i, dst, src, kernel) \
load_vec<5>(kernel, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
@@ -679,11 +689,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
size_t iw = ow - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
GI_FLOAT32_t dst_v;
load_bias_vec<bias_mode, 1>::impl(
&dst_v, init, bias + oh * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][5];
GI_FLOAT32_t kernel[2][5];
GI_FLOAT32_t src_v[2][5];
#define COMPUTE_5X5_1(i, dst, src, kernel) \
load_vec<5>(kernel, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \
@@ -709,12 +719,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
float32x4_t kernel[4];
GI_FLOAT32_t kernel[4];
load_vec<4>(kernel, filter);
Op op;
float32x4_t init = vdupq_n_f32(0.f);
GI_FLOAT32_t init = GiZeroFloat32();
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
init = GiLoadFloat32(bias);
}
size_t oh_start = (PH + 1) / 2;
size_t ow_start = (PW + 1) / 2;
@@ -737,10 +747,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2(
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[4];
GI_FLOAT32_t dst_v[4];
load_bias_vec<bias_mode, 4>::impl(
&dst_v[0], init, bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][8];
GI_FLOAT32_t src_v[2][8];
load_vec<8>(src_v[0], input);
COMPUTE_2X2(dst_v, src_v[0], &kernel[0]);
load_vec<8>(src_v[1], input + IW * 4);
@@ -753,10 +763,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2(
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
GI_FLOAT32_t dst_v;
load_bias_vec<bias_mode, 1>::impl(
&dst_v, init, bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[2][2];
GI_FLOAT32_t src_v[2][2];
load_vec<2>(src_v[0], input);
compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]);
load_vec<2>(src_v[1], input + IW * 4);
@@ -773,12 +783,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3(
const float* src, const float* filter, const float* bias, float* dst,
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
float32x4_t kernel[9];
GI_FLOAT32_t kernel[9];
load_vec<9>(kernel, filter);
Op op;
float32x4_t init = vdupq_n_f32(0.f);
GI_FLOAT32_t init = GiZeroFloat32();
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
init = GiLoadFloat32(bias);
}
size_t oh_start = (PH + 1) / 2;
size_t ow_start = (PW + 1) / 2;
@@ -799,12 +809,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3(
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][2];
GI_FLOAT32_t dst_v[2][2];
load_bias_vec<bias_mode, 2>::impl(
dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 2>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[2][5];
GI_FLOAT32_t src_v[2][5];
load_vec<5>(src_v[0], input);
compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[0]);
@@ -830,12 +840,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3(
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
GI_FLOAT32_t dst_v[2];
load_bias_vec<bias_mode, 1>::impl(
&dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t src_v[2][3];
GI_FLOAT32_t src_v[2][3];
load_vec<3>(src_v[0], input);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + IW * 4);
@@ -859,10 +869,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3(
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
GI_FLOAT32_t dst_v[2];
load_bias_vec<bias_mode, 2>::impl(
&dst_v[0], init, bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[3][5];
GI_FLOAT32_t src_v[3][5];
load_vec<5>(src_v[0], input);
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]);
compute_vec<3>(dst_v[1], &src_v[0][2], &kernel[0]);
@@ -878,10 +888,10 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3(
size_t iw = ow * 2 - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
GI_FLOAT32_t dst_v;
load_bias_vec<bias_mode, 1>::impl(
&dst_v, init, bias + oh * OW * 4 + ow * 4);
float32x4_t src_v[3][3];
GI_FLOAT32_t src_v[3][3];
load_vec<3>(src_v[0], input);
compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]);
load_vec<3>(src_v[1], input + IW * 4);
@@ -899,9 +909,9 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5(
const size_t IH, const size_t IW, const size_t OH, const size_t OW,
const size_t PH, const size_t PW) {
Op op;
float32x4_t init = vdupq_n_f32(0.f);
GI_FLOAT32_t init = GiZeroFloat32();
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
init = vld1q_f32(bias);
init = GiLoadFloat32(bias);
}
constexpr size_t stride = 2;
size_t oh_start = (PH + stride - 1) / stride;
@@ -911,7 +921,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5(
if (PH || PW) {
PaddingCompute<bias_mode, Op>::compute(
src, bias, dst, 5, stride, IH, IW, OH, OW, PH, PW,
reinterpret_cast<const float32x4_t*>(filter), init);
reinterpret_cast<const GI_FLOAT32_t*>(filter), init);
}
size_t oh = oh_start;
for (; oh + 1 < oh_end; oh += 2) {
@@ -921,13 +931,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5(
size_t iw = ow * stride - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2][2];
GI_FLOAT32_t dst_v[2][2];
load_bias_vec<bias_mode, 2>::impl(
dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 2>::impl(
dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t kernel[3][5];
float32x4_t src_v[2][7];
GI_FLOAT32_t kernel[3][5];
GI_FLOAT32_t src_v[2][7];
#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \
load_vec<5>(kernel0, filter + i * 5 * 4); \
load_vec<7>(src, input + i * IW * 4); \
@@ -965,13 +975,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5(
size_t iw = ow * stride - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v[2];
GI_FLOAT32_t dst_v[2];
load_bias_vec<bias_mode, 1>::impl(
&dst_v[0], init, bias + oh * OW * 4 + ow * 4);
load_bias_vec<bias_mode, 1>::impl(
&dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4);
float32x4_t kernel[3][5];
float32x4_t src_v[2][5];
GI_FLOAT32_t kernel[3][5];
GI_FLOAT32_t src_v[2][5];
#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \
load_vec<5>(kernel0, filter + i * 5 * 4); \
load_vec<5>(src, input + i * IW * 4); \
@@ -1010,11 +1020,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5(
size_t iw = ow * stride - PW;
const float* input = src + ih * IW * 4 + iw * 4;
float* output = dst + oh * OW * 4 + ow * 4;
float32x4_t dst_v;
GI_FLOAT32_t dst_v;
load_bias_vec<bias_mode, 1>::impl(
&dst_v, init, bias + oh * OW * 4 + ow * 4);
float32x4_t kernel[2][5];
float32x4_t src_v[2][5];
GI_FLOAT32_t kernel[2][5];
GI_FLOAT32_t src_v[2][5];
#define COMPUTE_5X5_1(i, dst, src, kernel) \
load_vec<5>(kernel, filter + i * 5 * 4); \
load_vec<6>(src, input + i * IW * 4); \

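The hunks above are a mechanical substitution: every NEON type and intrinsic becomes its GI (general intrinsic) counterpart, so the same kernel source builds on ARM and on plain-C fallback targets. A minimal sketch of the mapping assumed here; the real definitions live in src/fallback/general_intrinsic/gi_float.h and their non-NEON branches differ in detail:

#if defined(__ARM_NEON)
#include <arm_neon.h>
typedef float32x4_t GI_FLOAT32_t;
static inline GI_FLOAT32_t GiLoadFloat32(const float* p) { return vld1q_f32(p); }
static inline void GiStoreFloat32(float* p, GI_FLOAT32_t v) { vst1q_f32(p, v); }
static inline GI_FLOAT32_t GiBroadcastFloat32(float x) { return vdupq_n_f32(x); }
static inline GI_FLOAT32_t GiZeroFloat32() { return GiBroadcastFloat32(0.f); }
// a + b * c, computed per lane
static inline GI_FLOAT32_t GiMlaqFloat32(
        GI_FLOAT32_t a, GI_FLOAT32_t b, GI_FLOAT32_t c) {
    return vmlaq_f32(a, b, c);
}
#else  // scalar fallback: same names over four plain floats
typedef struct { float v[4]; } GI_FLOAT32_t;
static inline GI_FLOAT32_t GiLoadFloat32(const float* p) {
    GI_FLOAT32_t r;
    for (int i = 0; i < 4; ++i) r.v[i] = p[i];
    return r;
}
static inline void GiStoreFloat32(float* p, GI_FLOAT32_t x) {
    for (int i = 0; i < 4; ++i) p[i] = x.v[i];
}
// GiBroadcastFloat32, GiZeroFloat32, GiMlaqFloat32 follow the same pattern
#endif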
dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h → dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h
* \file dnn/src/fallback/conv_bias/gi/fp32/channel_wise_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,11 +12,11 @@

#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace channel_wise_nchw44_float {

#define KERN(stride, i) \
@@ -37,7 +37,7 @@ KERN(stride2, 5)
#undef KERN

} // namespace channel_wise_nchw44_float
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen
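The KERN macro body is elided by the hunk context above; it declares one do_conv_kern_<stride>_<i>x<i> template per (stride, filter-size) pair, matching the definitions in the .cpp. A plausible expansion of KERN(stride1, 3), with the signature taken from those definitions (hypothetical in its exact form):

// Hypothetical expansion of KERN(stride1, 3); the real macro body is not
// shown in this hunk, but the .cpp definitions above use this signature:
template <BiasMode bias_mode, typename Op>
void do_conv_kern_stride1_3x3(
        const float* src, const float* filter, const float* bias, float* dst,
        const size_t IH, const size_t IW, const size_t OH, const size_t OW,
        const size_t PH, const size_t PW);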

dnn/src/arm_common/conv_bias/fp32/direct.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/direct.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/direct.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,18 +9,18 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/direct.h"
#include "src/fallback/conv_bias/gi/fp32/direct.h"
#include <cstring>
#include "include/megdnn/oprs.h"
#include "midout.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
MIDOUT_DECL(megdnn_arm_conv_f32)
#include "src/fallback/conv_bias/gi/postprocess_helper.h"
#include "src/fallback/general_intrinsic/gi_float.h"
MIDOUT_DECL(megdnn_gi_conv_f32)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
using namespace fp32;
using namespace conv_bias;

@@ -34,65 +34,65 @@ struct do_pixel_proxy {
const int ow);
};

#define cb_load(i) data = vld1q_lane_f32(dst + i, data, i);
#define LOAD_OUT \
if (width < 4) { \
auto load_less_4 = [](float* dst, float32x4_t& data) { \
if (width == 1u) { \
UNROLL_CALL_NOWRAPPER(1, cb_load); \
} else if (width == 2u) { \
UNROLL_CALL_NOWRAPPER(2, cb_load); \
} else if (width == 3u) { \
UNROLL_CALL_NOWRAPPER(3, cb_load); \
} \
}; \
if (height >= 1) \
load_less_4(dst + 0 * OW, out0); \
if (height >= 2) \
load_less_4(dst + 1 * OW, out1); \
if (height >= 3) \
load_less_4(dst + 2 * OW, out2); \
if (height >= 4) \
load_less_4(dst + 3 * OW, out3); \
} else { \
if (height > 0) \
out0 = vld1q_f32(dst + 0 * OW); \
if (height > 1) \
out1 = vld1q_f32(dst + 1 * OW); \
if (height > 2) \
out2 = vld1q_f32(dst + 2 * OW); \
if (height > 3) \
out3 = vld1q_f32(dst + 3 * OW); \
}
#define cb_store(i) vst1q_lane_f32(dst + i, data, i);
#define STORE_OUT \
#define cb_load(i) data = GiLd1qLaneFloat32(dst + i, data, i);
#define LOAD_OUT \
if (width < 4) { \
auto store_less_4 = [](float* dst, float32x4_t& data) { \
auto load_less_4 = [](float* dst, GI_FLOAT32_t& data) { \
if (width == 1u) { \
UNROLL_CALL_NOWRAPPER(1, cb_store); \
UNROLL_CALL_NOWRAPPER(1, cb_load); \
} else if (width == 2u) { \
UNROLL_CALL_NOWRAPPER(2, cb_store); \
UNROLL_CALL_NOWRAPPER(2, cb_load); \
} else if (width == 3u) { \
UNROLL_CALL_NOWRAPPER(3, cb_store); \
UNROLL_CALL_NOWRAPPER(3, cb_load); \
} \
}; \
if (height >= 1) \
store_less_4(dst + 0 * OW, out0); \
load_less_4(dst + 0 * OW, out0); \
if (height >= 2) \
store_less_4(dst + 1 * OW, out1); \
load_less_4(dst + 1 * OW, out1); \
if (height >= 3) \
store_less_4(dst + 2 * OW, out2); \
load_less_4(dst + 2 * OW, out2); \
if (height >= 4) \
store_less_4(dst + 3 * OW, out3); \
load_less_4(dst + 3 * OW, out3); \
} else { \
if (height >= 1) \
vst1q_f32(dst + 0 * OW, out0); \
if (height >= 2) \
vst1q_f32(dst + 1 * OW, out1); \
if (height >= 3) \
vst1q_f32(dst + 2 * OW, out2); \
if (height >= 4) \
vst1q_f32(dst + 3 * OW, out3); \
if (height > 0) \
out0 = GiLoadFloat32(dst + 0 * OW); \
if (height > 1) \
out1 = GiLoadFloat32(dst + 1 * OW); \
if (height > 2) \
out2 = GiLoadFloat32(dst + 2 * OW); \
if (height > 3) \
out3 = GiLoadFloat32(dst + 3 * OW); \
}
#define cb_store(i) GiStoreLane##i##Float32(dst + i, data);
#define STORE_OUT \
if (width < 4) { \
auto store_less_4 = [](float* dst, GI_FLOAT32_t& data) { \
if (width == 1u) { \
UNROLL_CALL_NOWRAPPER(1, cb_store); \
} else if (width == 2u) { \
UNROLL_CALL_NOWRAPPER(2, cb_store); \
} else if (width == 3u) { \
UNROLL_CALL_NOWRAPPER(3, cb_store); \
} \
}; \
if (height >= 1) \
store_less_4(dst + 0 * OW, out0); \
if (height >= 2) \
store_less_4(dst + 1 * OW, out1); \
if (height >= 3) \
store_less_4(dst + 2 * OW, out2); \
if (height >= 4) \
store_less_4(dst + 3 * OW, out3); \
} else { \
if (height >= 1) \
GiStoreFloat32(dst + 0 * OW, out0); \
if (height >= 2) \
GiStoreFloat32(dst + 1 * OW, out1); \
if (height >= 3) \
GiStoreFloat32(dst + 2 * OW, out2); \
if (height >= 4) \
GiStoreFloat32(dst + 3 * OW, out3); \
}
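The lane store is the one change above that is not a pure rename: NEON's vst1q_lane_f32 takes the lane as a trailing immediate, while the GI layer needs a compile-time lane even on non-NEON backends, so the index is pasted into the function name. Assuming UNROLL_CALL_NOWRAPPER(n, cb) invokes cb(0) through cb(n-1), the width == 3 store expands to:

// #define cb_store(i) GiStoreLane##i##Float32(dst + i, data);
GiStoreLane0Float32(dst + 0, data);
GiStoreLane1Float32(dst + 1, data);
GiStoreLane2Float32(dst + 2, data);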

template <int height, int width>
@@ -104,33 +104,33 @@ struct do_pixel_proxy<1, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
inp = GiLoadFloat32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);
out0 = GiMlaqFloat32(out0, inp, kr0);

if (height > 1)
inp = vld1q_f32(src_dd + 1 * IW);
inp = GiLoadFloat32(src_dd + 1 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);
out1 = GiMlaqFloat32(out1, inp, kr0);

if (height > 2)
inp = vld1q_f32(src_dd + 2 * IW);
inp = GiLoadFloat32(src_dd + 2 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);
out2 = GiMlaqFloat32(out2, inp, kr0);

if (height > 3)
inp = vld1q_f32(src_dd + 3 * IW);
inp = GiLoadFloat32(src_dd + 3 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);
out3 = GiMlaqFloat32(out3, inp, kr0);
}
STORE_OUT;
}
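Each do_pixel_proxy<FH, height, width> accumulates one FH-row slice of the filter into a tile of up to four output rows, four lanes wide; GiMlaqFloat32(a, b, c) is a per-lane a + b * c. A scalar sketch of what the <1, height, width> specialization above computes (variable names follow the surrounding code; width < 4 only trims the lanes loaded and stored at the right edge):

for (int h = 0; h < height; ++h)            // output rows in the tile
    for (int w = 0; w < 4; ++w)             // the four vector lanes
        for (int fw = 0; fw < FW; ++fw)     // single filter row, FH == 1
            dst[(oh + h) * OW + ow + w] +=
                    filter[fw] * src[(ih + h) * IW + iw + fw + w];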
@@ -145,45 +145,45 @@ struct do_pixel_proxy<2, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]);
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
inp = GiLoadFloat32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);
out0 = GiMlaqFloat32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
inp = GiLoadFloat32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
out0 = GiMlaqFloat32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);
out1 = GiMlaqFloat32(out1, inp, kr0);

if (height > 1)
inp = vld1q_f32(src_dd + 2 * IW);
inp = GiLoadFloat32(src_dd + 2 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
out1 = GiMlaqFloat32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);
out2 = GiMlaqFloat32(out2, inp, kr0);

if (height > 2)
inp = vld1q_f32(src_dd + 3 * IW);
inp = GiLoadFloat32(src_dd + 3 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
out2 = GiMlaqFloat32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);
out3 = GiMlaqFloat32(out3, inp, kr0);

if (height > 3)
inp = vld1q_f32(src_dd + 4 * IW);
inp = GiLoadFloat32(src_dd + 4 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);
out3 = GiMlaqFloat32(out3, inp, kr1);
}
STORE_OUT;
}
@@ -198,57 +198,57 @@ struct do_pixel_proxy<3, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]);
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]);
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
inp = GiLoadFloat32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);
out0 = GiMlaqFloat32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
inp = GiLoadFloat32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
out0 = GiMlaqFloat32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);
out1 = GiMlaqFloat32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
inp = GiLoadFloat32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
out0 = GiMlaqFloat32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
out1 = GiMlaqFloat32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);
out2 = GiMlaqFloat32(out2, inp, kr0);

if (height > 1)
inp = vld1q_f32(src_dd + 3 * IW);
inp = GiLoadFloat32(src_dd + 3 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
out1 = GiMlaqFloat32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
out2 = GiMlaqFloat32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);
out3 = GiMlaqFloat32(out3, inp, kr0);

if (height > 2)
inp = vld1q_f32(src_dd + 4 * IW);
inp = GiLoadFloat32(src_dd + 4 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
out2 = GiMlaqFloat32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);
out3 = GiMlaqFloat32(out3, inp, kr1);

if (height > 3)
inp = vld1q_f32(src_dd + 5 * IW);
inp = GiLoadFloat32(src_dd + 5 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);
out3 = GiMlaqFloat32(out3, inp, kr2);
}
STORE_OUT;
}
@@ -263,69 +263,69 @@ struct do_pixel_proxy<4, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr3 = vdupq_n_f32(filter[3 * FW + fw]);
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]);
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]);
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]);
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
inp = GiLoadFloat32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);
out0 = GiMlaqFloat32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
inp = GiLoadFloat32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
out0 = GiMlaqFloat32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);
out1 = GiMlaqFloat32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
inp = GiLoadFloat32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
out0 = GiMlaqFloat32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
out1 = GiMlaqFloat32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);
out2 = GiMlaqFloat32(out2, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 3 * IW);
inp = GiLoadFloat32(src_dd + 3 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr3);
out0 = GiMlaqFloat32(out0, inp, kr3);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
out1 = GiMlaqFloat32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
out2 = GiMlaqFloat32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);
out3 = GiMlaqFloat32(out3, inp, kr0);

if (height > 1)
inp = vld1q_f32(src_dd + 4 * IW);
inp = GiLoadFloat32(src_dd + 4 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr3);
out1 = GiMlaqFloat32(out1, inp, kr3);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
out2 = GiMlaqFloat32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);
out3 = GiMlaqFloat32(out3, inp, kr1);

if (height > 2)
inp = vld1q_f32(src_dd + 5 * IW);
inp = GiLoadFloat32(src_dd + 5 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr3);
out2 = GiMlaqFloat32(out2, inp, kr3);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);
out3 = GiMlaqFloat32(out3, inp, kr2);

if (height > 3)
inp = vld1q_f32(src_dd + 6 * IW);
inp = GiLoadFloat32(src_dd + 6 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr3);
out3 = GiMlaqFloat32(out3, inp, kr3);
}
STORE_OUT;
}
@@ -340,81 +340,81 @@ struct do_pixel_proxy<5, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp;
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr3 = vdupq_n_f32(filter[3 * FW + fw]);
kr4 = vdupq_n_f32(filter[4 * FW + fw]);
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]);
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]);
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]);
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]);
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
inp = GiLoadFloat32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);
out0 = GiMlaqFloat32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
inp = GiLoadFloat32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
out0 = GiMlaqFloat32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);
out1 = GiMlaqFloat32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
inp = GiLoadFloat32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
out0 = GiMlaqFloat32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
out1 = GiMlaqFloat32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);
out2 = GiMlaqFloat32(out2, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 3 * IW);
inp = GiLoadFloat32(src_dd + 3 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr3);
out0 = GiMlaqFloat32(out0, inp, kr3);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
out1 = GiMlaqFloat32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
out2 = GiMlaqFloat32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);
out3 = GiMlaqFloat32(out3, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 4 * IW);
inp = GiLoadFloat32(src_dd + 4 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr4);
out0 = GiMlaqFloat32(out0, inp, kr4);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr3);
out1 = GiMlaqFloat32(out1, inp, kr3);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
out2 = GiMlaqFloat32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);
out3 = GiMlaqFloat32(out3, inp, kr1);

if (height > 1)
inp = vld1q_f32(src_dd + 5 * IW);
inp = GiLoadFloat32(src_dd + 5 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr4);
out1 = GiMlaqFloat32(out1, inp, kr4);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr3);
out2 = GiMlaqFloat32(out2, inp, kr3);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);
out3 = GiMlaqFloat32(out3, inp, kr2);

if (height > 2)
inp = vld1q_f32(src_dd + 6 * IW);
inp = GiLoadFloat32(src_dd + 6 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr4);
out2 = GiMlaqFloat32(out2, inp, kr4);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr3);
out3 = GiMlaqFloat32(out3, inp, kr3);

if (height > 3)
inp = vld1q_f32(src_dd + 7 * IW);
inp = GiLoadFloat32(src_dd + 7 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr4);
out3 = GiMlaqFloat32(out3, inp, kr4);
}
STORE_OUT;
}
@@ -429,94 +429,94 @@ struct do_pixel_proxy<6, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5,
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5,
inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr3 = vdupq_n_f32(filter[3 * FW + fw]);
kr4 = vdupq_n_f32(filter[4 * FW + fw]);
kr5 = vdupq_n_f32(filter[5 * FW + fw]);
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]);
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]);
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]);
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]);
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]);
kr5 = GiBroadcastFloat32(filter[5 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
inp = GiLoadFloat32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);
out0 = GiMlaqFloat32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
inp = GiLoadFloat32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
out0 = GiMlaqFloat32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);
out1 = GiMlaqFloat32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
inp = GiLoadFloat32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
out0 = GiMlaqFloat32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
out1 = GiMlaqFloat32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);
out2 = GiMlaqFloat32(out2, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 3 * IW);
inp = GiLoadFloat32(src_dd + 3 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr3);
out0 = GiMlaqFloat32(out0, inp, kr3);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
out1 = GiMlaqFloat32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
out2 = GiMlaqFloat32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);
out3 = GiMlaqFloat32(out3, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 4 * IW);
inp = GiLoadFloat32(src_dd + 4 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr4);
out0 = GiMlaqFloat32(out0, inp, kr4);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr3);
out1 = GiMlaqFloat32(out1, inp, kr3);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
out2 = GiMlaqFloat32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);
out3 = GiMlaqFloat32(out3, inp, kr1);

if (height > 0)
inp = vld1q_f32(src_dd + 5 * IW);
inp = GiLoadFloat32(src_dd + 5 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr5);
out0 = GiMlaqFloat32(out0, inp, kr5);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr4);
out1 = GiMlaqFloat32(out1, inp, kr4);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr3);
out2 = GiMlaqFloat32(out2, inp, kr3);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);
out3 = GiMlaqFloat32(out3, inp, kr2);

if (height > 1)
inp = vld1q_f32(src_dd + 6 * IW);
inp = GiLoadFloat32(src_dd + 6 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr5);
out1 = GiMlaqFloat32(out1, inp, kr5);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr4);
out2 = GiMlaqFloat32(out2, inp, kr4);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr3);
out3 = GiMlaqFloat32(out3, inp, kr3);

if (height > 2)
inp = vld1q_f32(src_dd + 7 * IW);
inp = GiLoadFloat32(src_dd + 7 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr5);
out2 = GiMlaqFloat32(out2, inp, kr5);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr4);
out3 = GiMlaqFloat32(out3, inp, kr4);

if (height > 3)
inp = vld1q_f32(src_dd + 8 * IW);
inp = GiLoadFloat32(src_dd + 8 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr5);
out3 = GiMlaqFloat32(out3, inp, kr5);
}
STORE_OUT;
}
@@ -531,106 +531,106 @@ struct do_pixel_proxy<7, height, width> {
(void)IH;
(void)OH;
const int ih = oh, iw = ow;
float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5,
GI_FLOAT32_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5,
kr6, inp;
src += ih * IW + iw;
dst += oh * OW + ow;
LOAD_OUT;
for (int fw = 0; fw < FW; ++fw) {
const float* src_dd = src + fw;
kr0 = vdupq_n_f32(filter[0 * FW + fw]);
kr1 = vdupq_n_f32(filter[1 * FW + fw]);
kr2 = vdupq_n_f32(filter[2 * FW + fw]);
kr3 = vdupq_n_f32(filter[3 * FW + fw]);
kr4 = vdupq_n_f32(filter[4 * FW + fw]);
kr5 = vdupq_n_f32(filter[5 * FW + fw]);
kr6 = vdupq_n_f32(filter[6 * FW + fw]);
kr0 = GiBroadcastFloat32(filter[0 * FW + fw]);
kr1 = GiBroadcastFloat32(filter[1 * FW + fw]);
kr2 = GiBroadcastFloat32(filter[2 * FW + fw]);
kr3 = GiBroadcastFloat32(filter[3 * FW + fw]);
kr4 = GiBroadcastFloat32(filter[4 * FW + fw]);
kr5 = GiBroadcastFloat32(filter[5 * FW + fw]);
kr6 = GiBroadcastFloat32(filter[6 * FW + fw]);

if (height > 0)
inp = vld1q_f32(src_dd + 0 * IW);
inp = GiLoadFloat32(src_dd + 0 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr0);
out0 = GiMlaqFloat32(out0, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 1 * IW);
inp = GiLoadFloat32(src_dd + 1 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr1);
out0 = GiMlaqFloat32(out0, inp, kr1);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr0);
out1 = GiMlaqFloat32(out1, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 2 * IW);
inp = GiLoadFloat32(src_dd + 2 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr2);
out0 = GiMlaqFloat32(out0, inp, kr2);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr1);
out1 = GiMlaqFloat32(out1, inp, kr1);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr0);
out2 = GiMlaqFloat32(out2, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 3 * IW);
inp = GiLoadFloat32(src_dd + 3 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr3);
out0 = GiMlaqFloat32(out0, inp, kr3);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr2);
out1 = GiMlaqFloat32(out1, inp, kr2);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr1);
out2 = GiMlaqFloat32(out2, inp, kr1);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr0);
out3 = GiMlaqFloat32(out3, inp, kr0);

if (height > 0)
inp = vld1q_f32(src_dd + 4 * IW);
inp = GiLoadFloat32(src_dd + 4 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr4);
out0 = GiMlaqFloat32(out0, inp, kr4);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr3);
out1 = GiMlaqFloat32(out1, inp, kr3);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr2);
out2 = GiMlaqFloat32(out2, inp, kr2);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr1);
out3 = GiMlaqFloat32(out3, inp, kr1);

if (height > 0)
inp = vld1q_f32(src_dd + 5 * IW);
inp = GiLoadFloat32(src_dd + 5 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr5);
out0 = GiMlaqFloat32(out0, inp, kr5);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr4);
out1 = GiMlaqFloat32(out1, inp, kr4);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr3);
out2 = GiMlaqFloat32(out2, inp, kr3);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr2);
out3 = GiMlaqFloat32(out3, inp, kr2);

if (height > 0)
inp = vld1q_f32(src_dd + 6 * IW);
inp = GiLoadFloat32(src_dd + 6 * IW);
if (height > 0)
out0 = vmlaq_f32(out0, inp, kr6);
out0 = GiMlaqFloat32(out0, inp, kr6);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr5);
out1 = GiMlaqFloat32(out1, inp, kr5);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr4);
out2 = GiMlaqFloat32(out2, inp, kr4);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr3);
out3 = GiMlaqFloat32(out3, inp, kr3);

if (height > 1)
inp = vld1q_f32(src_dd + 7 * IW);
inp = GiLoadFloat32(src_dd + 7 * IW);
if (height > 1)
out1 = vmlaq_f32(out1, inp, kr6);
out1 = GiMlaqFloat32(out1, inp, kr6);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr5);
out2 = GiMlaqFloat32(out2, inp, kr5);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr4);
out3 = GiMlaqFloat32(out3, inp, kr4);

if (height > 2)
inp = vld1q_f32(src_dd + 8 * IW);
inp = GiLoadFloat32(src_dd + 8 * IW);
if (height > 2)
out2 = vmlaq_f32(out2, inp, kr6);
out2 = GiMlaqFloat32(out2, inp, kr6);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr5);
out3 = GiMlaqFloat32(out3, inp, kr5);

if (height > 3)
inp = vld1q_f32(src_dd + 9 * IW);
inp = GiLoadFloat32(src_dd + 9 * IW);
if (height > 3)
out3 = vmlaq_f32(out3, inp, kr6);
out3 = GiMlaqFloat32(out3, inp, kr6);
}
STORE_OUT;
}
@@ -836,31 +836,31 @@ void conv_bias::kern_direct(
} while (0)
switch (FH) {
case 1:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(0)) { GAO(1); }
MIDOUT_END();
break;
case 2:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(1)) { GAO(2); }
MIDOUT_END();
break;
case 3:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(2)) { GAO(3); }
MIDOUT_END();
break;
case 4:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(3)) { GAO(4); }
MIDOUT_END();
break;
case 5:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(4)) { GAO(5); }
MIDOUT_END();
break;
case 6:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(5)) { GAO(6); }
MIDOUT_END();
break;
case 7:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(6)) { GAO(7); }
MIDOUT_END();
break;
}
@@ -872,31 +872,31 @@ void conv_bias::kern_direct(
} while (0)
switch (FH) {
case 1:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(0)) { GAO(1); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(0)) { GAO(1); }
MIDOUT_END();
break;
case 2:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(1)) { GAO(2); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(1)) { GAO(2); }
MIDOUT_END();
break;
case 3:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(2)) { GAO(3); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(2)) { GAO(3); }
MIDOUT_END();
break;
case 4:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(3)) { GAO(4); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(3)) { GAO(4); }
MIDOUT_END();
break;
case 5:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(4)) { GAO(5); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(4)) { GAO(5); }
MIDOUT_END();
break;
case 6:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(5)) { GAO(6); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(5)) { GAO(6); }
MIDOUT_END();
break;
case 7:
MIDOUT_BEGIN(megdnn_arm_conv_f32, midout_iv(6)) { GAO(7); }
MIDOUT_BEGIN(megdnn_gi_conv_f32, midout_iv(6)) { GAO(7); }
MIDOUT_END();
break;
}

dnn/src/arm_common/conv_bias/fp32/direct.h → dnn/src/fallback/conv_bias/gi/fp32/direct.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/direct.h
* \file dnn/src/fallback/conv_bias/gi/fp32/direct.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -13,7 +13,7 @@
#include <cstddef>

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace fp32 {
namespace conv_bias {

@@ -23,7 +23,7 @@ void kern_direct(

} // namespace conv_bias
} // namespace fp32
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,12 +11,12 @@
* implied.
*/

#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/general_intrinsic/gi_float.h"
namespace megdnn {
namespace arm_common {
namespace fallback {
namespace conv_bias {
template <>
void pack_src_fp32_nchw44<1>(
@@ -51,23 +51,23 @@ static inline void odd_even_split_iw8_even(
const int src_offset = src_idx * ic_step;
const int even_offset = iw_idx / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
float32x4_t temp[8];
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]);
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]);
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]);
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]);
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]);
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]);
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]);
GI_FLOAT32_t temp[8];
temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step);
temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step);
temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step);
temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step);
temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step);
temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step);
temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step);
temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step);
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[0]);
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[2]);
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[4]);
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[6]);
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[1]);
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[3]);
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[5]);
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[7]);
}

static inline void odd_even_split_iw8_odd(
@@ -77,23 +77,23 @@ static inline void odd_even_split_iw8_odd(
const int src_offset = src_idx * ic_step;
const int even_offset = (iw_idx + 1) / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
float32x4_t temp[8];
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]);
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]);
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]);
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]);
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]);
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]);
GI_FLOAT32_t temp[8];
temp[0] = GiLoadFloat32(sptr + src_offset + 0 * ic_step);
temp[1] = GiLoadFloat32(sptr + src_offset + 1 * ic_step);
temp[2] = GiLoadFloat32(sptr + src_offset + 2 * ic_step);
temp[3] = GiLoadFloat32(sptr + src_offset + 3 * ic_step);
temp[4] = GiLoadFloat32(sptr + src_offset + 4 * ic_step);
temp[5] = GiLoadFloat32(sptr + src_offset + 5 * ic_step);
temp[6] = GiLoadFloat32(sptr + src_offset + 6 * ic_step);
temp[7] = GiLoadFloat32(sptr + src_offset + 7 * ic_step);
GiStoreFloat32(sptr_base + odd_offset + 0 * ic_step, temp[0]);
GiStoreFloat32(sptr_base + odd_offset + 1 * ic_step, temp[2]);
GiStoreFloat32(sptr_base + odd_offset + 2 * ic_step, temp[4]);
GiStoreFloat32(sptr_base + odd_offset + 3 * ic_step, temp[6]);
GiStoreFloat32(sptr_base + even_offset + 0 * ic_step, temp[1]);
GiStoreFloat32(sptr_base + even_offset + 1 * ic_step, temp[3]);
GiStoreFloat32(sptr_base + even_offset + 2 * ic_step, temp[5]);
GiStoreFloat32(sptr_base + even_offset + 3 * ic_step, temp[7]);
}
} // namespace

@@ -104,7 +104,7 @@ void pack_src_fp32_nchw44<2>(
const int pad_top, const int pad_bottom, const int ic, const int ic_stride) {
constexpr int ic_step = 4;
int odd_start = megdnn::div_ceil(iw2, 2);
float32x4_t zero_v = vdupq_n_f32(0.f);
GI_FLOAT32_t zero_v = GiZeroFloat32();
MEGDNN_MARK_USED_VAR(ph);
bool even_start = pw % 2 == 0;
rep_step(ic_idx, ic, ic_step) {
@@ -115,9 +115,10 @@ void pack_src_fp32_nchw44<2>(
int iw_idx = 0;
rep(idx, pw) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
GiStoreFloat32(sptr_base + iw_idx / 2 * ic_step, zero_v);
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v);
GiStoreFloat32(
sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v);
}
++iw_idx;
}
@@ -136,21 +137,22 @@ void pack_src_fp32_nchw44<2>(
}
for (; src_idx < iw; ++src_idx) {
if (iw_idx % 2 == 0) {
vst1q_f32(
GiStoreFloat32(
sptr_base + iw_idx / 2 * ic_step,
vld1q_f32(sptr + src_idx * ic_step));
GiLoadFloat32(sptr + src_idx * ic_step));
} else {
vst1q_f32(
GiStoreFloat32(
sptr_base + (odd_start + iw_idx / 2) * ic_step,
vld1q_f32(sptr + src_idx * ic_step));
GiLoadFloat32(sptr + src_idx * ic_step));
}
++iw_idx;
}
rep(idx, pad_right) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
GiStoreFloat32(sptr_base + iw_idx / 2 * ic_step, zero_v);
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v);
GiStoreFloat32(
sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v);
}
++iw_idx;
}
@@ -163,7 +165,7 @@ void pack_src_fp32_nchw44<2>(
}

} // namespace conv_bias
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen
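The odd/even split above repacks a padded stride-2 row so that even source columns sit contiguously at the front of the buffer and odd columns start at odd_start = div_ceil(iw2, 2); a stride-2 kernel can then read each phase as a dense run. Layout sketch for one packed row, each cell an ic_step = 4 float group:

// source row (iw2 = 8, odd_start = 4):  s0 s1 s2 s3 s4 s5 s6 s7
// packed row:                           s0 s2 s4 s6 | s1 s3 s5 s7
//                                       even half     odd half at odd_start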

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BIAS(2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_NO_BIAS(2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BIAS(2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_NO_BIAS(2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BIAS(3);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(3);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_NO_BIAS(3);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BIAS(3);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(3);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_NO_BIAS(3);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BIAS(5);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(5);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_NO_BIAS(5);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BIAS(5);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(5);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_NO_BIAS(5);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BIAS(7);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_BROADCAST_CHANNEL_BIAS(7);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1_NO_BIAS(7);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BIAS(7);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_BROADCAST_CHANNEL_BIAS(7);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2_NO_BIAS(7);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,16 +12,15 @@
*/

#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/intrinsic_helper.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
namespace {

template <
@@ -39,13 +38,13 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> {
};

#define cb2(step, lane, ow_block) \
c[0][step] = vfmaq_laneq_f32( \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \
c[1][step] = vfmaq_laneq_f32( \
c[1][step] = GiSimdFmaLane( \
c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane);

#define cb(step, lane, ow_block) \
c[0][step] = vfmaq_laneq_f32( \
#define cb(step, lane, ow_block) \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane);

#define SHIFT_CAL_HELPER(ow_block, remain_w) \
@@ -122,7 +121,7 @@ public:
template <
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block,
int ow_block>
struct KerNeonXXs1Nchw44FP32 {
struct KerGiXXs1Nchw44FP32 {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -130,7 +129,7 @@ struct KerNeonXXs1Nchw44FP32 {
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block>
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
static void impl(
const float32_t* src_ptr_origin, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -147,20 +146,20 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][ow_block];
GI_FLOAT32_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
float32x4_t src[ow_block];
float32x4_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
@@ -172,7 +171,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block>
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
static void impl(
const float32_t* src_ptr_origin, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -189,24 +188,24 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][ow_block];
GI_FLOAT32_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
float32x4_t src[ow_block];
float32x4_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
@@ -217,7 +216,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block>
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
static void impl(
const float32_t* src_ptr_origin, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -234,36 +233,36 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][ow_block];
GI_FLOAT32_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
float32x4_t src[ow_block];
float32x4_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step);
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step);
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step);
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step);
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
@@ -275,7 +274,7 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block>
struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
struct KerGiXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
static void impl(
const float32_t* src_ptr_origin, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -292,46 +291,46 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][ow_block];
GI_FLOAT32_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
float32x4_t src[ow_block];
float32x4_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][ic_step];
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<ic_step, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr + (ow_block)*ic_step);
load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * ic_step);
load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step);
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * ic_step);
load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step);
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[3] = GiLoadFloat32(src_ptr + (ow_block + 3) * ic_step);
load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step);
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[4] = GiLoadFloat32(src_ptr + (ow_block + 4) * ic_step);
load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<5, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step);
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[5] = GiLoadFloat32(src_ptr + (ow_block + 5) * ic_step);
load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<6, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
@@ -352,10 +351,10 @@ void conv_bias::conv_direct_fp32_nchw44(
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
#if MEGDNN_ARMV7
constexpr int big_oc_step = 4;
#else
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
@@ -381,9 +380,9 @@ void conv_bias::conv_direct_fp32_nchw44(
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = KerNeonXXs1Nchw44FP32< \
kern_big_oc_remain = KerGiXXs1Nchw44FP32< \
bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \
kern_small_oc_remain = KerNeonXXs1Nchw44FP32< \
kern_small_oc_remain = KerGiXXs1Nchw44FP32< \
bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \
break;

@@ -402,7 +401,7 @@ void conv_bias::conv_direct_fp32_nchw44(
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
const int bias_offset =
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx;
KerNeonXXs1Nchw44FP32<
KerGiXXs1Nchw44FP32<
bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>::
impl(src + src_offset, filter + weight_offset,
bias + bias_offset, dst + dst_offset, ic, ih, iw,
@@ -434,7 +433,7 @@ void conv_bias::conv_direct_fp32_nchw44(
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
const int bias_offset =
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx;
KerNeonXXs1Nchw44FP32<
KerGiXXs1Nchw44FP32<
bias_mode, Op, ow_step, filter_size, oc_step, ow_step>::
impl(src + src_offset, filter + weight_offset,
bias + bias_offset, dst + dst_offset, ic, ih, iw,
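
Beyond the path moves, the substantive change in this header is the swap from raw NEON names (float32x4_t, vld1q_f32, vfmaq_laneq_f32) to the gi layer (GI_FLOAT32_t, GiLoadFloat32, GiSimdFmaLane), which lets the same kernel body build on non-ARM targets. A minimal sketch of how such a layer can be shaped, assuming NEON on AArch64 and a scalar fallback elsewhere (the real definitions live in MegEngine's gi headers and may differ):

#if defined(__aarch64__)
#include <arm_neon.h>
typedef float32x4_t GI_FLOAT32_t;
#define GiLoadFloat32(ptr) vld1q_f32(ptr)
// lane must be a compile-time constant for the NEON intrinsic
#define GiSimdFmaLane(acc, w, s, lane) vfmaq_laneq_f32(acc, w, s, lane)
#else
struct GI_FLOAT32_t {
    float v[4];
};
static inline GI_FLOAT32_t GiLoadFloat32(const float* ptr) {
    return {{ptr[0], ptr[1], ptr[2], ptr[3]}};
}
// acc[i] += w[i] * s[lane]: the broadcast-lane FMA used by the kernels above
static inline GI_FLOAT32_t GiSimdFmaLane(
        GI_FLOAT32_t acc, GI_FLOAT32_t w, GI_FLOAT32_t s, int lane) {
    for (int i = 0; i < 4; ++i)
        acc.v[i] += w.v[i] * s.v[lane];
    return acc;
}
#endif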

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,16 +12,15 @@
*/

#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/intrinsic_helper.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
namespace {

template <
@@ -39,13 +38,13 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, 0, T, T2, T3, T4> {
};

#define cb2(step, lane, ow_block) \
c[0][step] = vfmaq_laneq_f32( \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \
c[1][step] = vfmaq_laneq_f32( \
c[1][step] = GiSimdFmaLane( \
c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane);

#define cb(step, lane, ow_block) \
c[0][step] = vfmaq_laneq_f32( \
#define cb(step, lane, ow_block) \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane);

#define SHIFT_CAL_HELPER(ow_block, remain_w) \
@@ -122,7 +121,7 @@ public:
template <
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block,
int ow_block>
struct KerNeonXXs2Nchw44FP32 {
struct KerGiXXs2Nchw44FP32 {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -130,7 +129,7 @@ struct KerNeonXXs2Nchw44FP32 {
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block>
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
static void impl(
const float32_t* src_ptr_origin, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -147,36 +146,36 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][ow_block];
GI_FLOAT32_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic;

float32x4_t src[ow_block];
float32x4_t weight[c_dim][4];
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][4];
/////////row 0/////////////
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh;
/////////row 1/////////////
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
@@ -188,7 +187,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block>
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
static void impl(
const float32_t* src_ptr_origin, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -205,62 +204,62 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][ow_block];
GI_FLOAT32_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic;
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic;

float32x4_t src[ow_block];
float32x4_t weight[c_dim][4];
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][4];
/////////row 0/////////////
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);

src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh;
/////////row 1/////////////
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh;
//////////row 2/////////////
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);

load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);

load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src_ptr += ld_src_iw;
@@ -272,7 +271,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block>
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
static void impl(
const float32_t* src_ptr_origin, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -289,7 +288,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][ow_block];
GI_FLOAT32_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
@@ -297,28 +296,28 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic;

for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
float32x4_t src[ow_block];
float32x4_t weight[c_dim][4];
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][4];
// even element
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len);
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len);
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
// odd element
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len);
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len);
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);

@@ -337,7 +336,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
* src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9
**/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, int ow_block>
struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
struct KerGiXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
static void impl(
const float32_t* src_ptr_origin, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -354,7 +353,7 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
const int ld_src_ic = ih * iw;
const int ld_src_iw = iw * oc_step;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][ow_block];
GI_FLOAT32_t c[c_dim][ow_block];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
@@ -362,36 +361,36 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic;

for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) {
float32x4_t src[ow_block];
float32x4_t weight[c_dim][4];
GI_FLOAT32_t src[ow_block];
GI_FLOAT32_t weight[c_dim][4];
// even element
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len);
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[1] = GiLoadFloat32(src_ptr + (ow_block + 1) * simd_len);
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len);
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[2] = GiLoadFloat32(src_ptr + (ow_block + 2) * simd_len);
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight);
// odd element
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
load_helper<ow_block, 0, simd_len, 0, Vld1qF32S>(src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len);
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[0] = GiLoadFloat32(src_ptr_odd + ow_block * simd_len);
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight);
src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len);
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>(
src[1] = GiLoadFloat32(src_ptr_odd + (ow_block + 1) * simd_len);
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight);

@@ -414,10 +413,10 @@ void conv_bias::conv_direct_fp32_nchw44(
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
#if MEGDNN_ARMV7
constexpr int big_oc_step = 4;
#else
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
@@ -444,9 +443,9 @@ void conv_bias::conv_direct_fp32_nchw44(
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = KerNeonXXs2Nchw44FP32< \
kern_big_oc_remain = KerGiXXs2Nchw44FP32< \
bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \
kern_small_oc_remain = KerNeonXXs2Nchw44FP32< \
kern_small_oc_remain = KerGiXXs2Nchw44FP32< \
bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \
break;

@@ -469,7 +468,7 @@ void conv_bias::conv_direct_fp32_nchw44(
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
const int bias_offset =
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx;
KerNeonXXs2Nchw44FP32<
KerGiXXs2Nchw44FP32<
bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>::
impl(src + src_offset, filter + weight_offset,
bias + bias_offset, dst + dst_offset, ic, ih, iw,
@@ -510,7 +509,7 @@ void conv_bias::conv_direct_fp32_nchw44(
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
const int bias_offset =
bias_mode == BiasMode::BIAS ? dst_offset : oc_idx;
KerNeonXXs2Nchw44FP32<
KerGiXXs2Nchw44FP32<
bias_mode, Op, ow_step, filter_size, oc_step, ow_step>::
impl(src + src_offset, filter + weight_offset,
bias + bias_offset, dst + dst_offset, ic, ih, iw,
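
The stride-2 kernels above read even and odd input columns through two pointers (src_ptr and src_ptr_odd), relying on the pre-pack described by the "src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9" comment: even columns first, odd columns after, so a stride-2 convolution becomes two contiguous stride-1 sweeps. A rough sketch of that repack for one NCHW44 row (hypothetical helper; the real pack routine may differ in details):

// pack one input row of `iw` NCHW44 column groups (4 floats each) so even
// columns come first and odd columns follow: 0, 2, 4, ..., 1, 3, 5, ...
static void pack_row_even_odd(const float* src, float* dst, int iw) {
    const int c_step = 4;             // NCHW44 keeps 4 channels interleaved
    const int n_even = (iw + 1) / 2;  // even columns fill the first half
    for (int x = 0; x < iw; ++x) {
        const int slot = (x % 2 == 0) ? (x / 2) : (n_even + x / 2);
        for (int c = 0; c < c_step; ++c)
            dst[slot * c_step + c] = src[x * c_step + c];
    }
}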

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(2, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(2, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(2, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(2, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(2, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(3, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(3, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(3, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(3, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(3, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(5, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(5, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(5, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(5, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(5, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(7, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(7, 1);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BIAS(7, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_broadcast_channel_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_BROADCAST_CHANNEL_BIAS(7, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp View File

@@ -1,6 +1,6 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp
* dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2_no_bias.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,6 +10,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
#include "src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV_NO_BIAS(7, 2);
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h → dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h View File

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h
* \file dnn/src/fallback/conv_bias/gi/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,20 +11,19 @@
*/
#pragma once
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/intrinsic_helper.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"
#if MEGDNN_ARMV7
#include "src/armv7/matrix_mul/asm/common.h"
#endif

using namespace megdnn;
using namespace arm_common;
using namespace fallback;

namespace {
/**
@@ -50,15 +49,15 @@ struct ShiftCalHelper<src_idx, weight_idx, c_dim, stride, 0, T, T2, T3> {
};

#define cb(step) \
c[0][step] = vfmaq_laneq_f32( \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4); \
c[1][step] = vfmaq_laneq_f32( \
c[1][step] = GiSimdFmaLane( \
c[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4);

#define cb2(step) \
c[0][step] = vfmaq_laneq_f32( \
c[0][step] = GiSimdFmaLane( \
c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4);
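
For reference, the macros above substitute GiSimdFmaLane one-for-one for NEON's vfmaq_laneq_f32, so the port assumes the following lane-FMA semantics; the scalar model below is illustrative and is not part of the GI API:

// Scalar model of GiSimdFmaLane(d, a, v, lane): every element of d accumulates
// a[i] * v[lane], i.e. one lane of v is broadcast across a four-wide FMA.
static inline void simd_fma_lane_ref(float d[4], const float a[4], const float v[4], int lane) {
    for (int i = 0; i < 4; ++i) {
        d[i] += a[i] * v[lane];
    }
}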

@@ -127,7 +126,7 @@ public:
template <
BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block,
int stride, int ow_block, int tag = CpuTag::DEFAULT_CPU_TAG>
struct KerNeonXXs2NchwNchw44FP32 {
struct KerGiXXs2NchwNchw44FP32 {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -136,8 +135,7 @@ struct KerNeonXXs2NchwNchw44FP32 {
template <
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride,
int ow_block>
struct KerNeonXXs2NchwNchw44FP32<
bias_mode, Op, remain_w, 7, oc_block, stride, ow_block> {
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, ow_block> {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -154,16 +152,16 @@ struct KerNeonXXs2NchwNchw44FP32<
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
GI_FLOAT32_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \
@@ -186,8 +184,7 @@ struct KerNeonXXs2NchwNchw44FP32<
template <
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride,
int ow_block>
struct KerNeonXXs2NchwNchw44FP32<
bias_mode, Op, remain_w, 5, oc_block, stride, ow_block> {
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, ow_block> {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -204,16 +201,16 @@ struct KerNeonXXs2NchwNchw44FP32<
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
GI_FLOAT32_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \
@@ -233,8 +230,7 @@ struct KerNeonXXs2NchwNchw44FP32<
template <
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride,
int ow_block>
struct KerNeonXXs2NchwNchw44FP32<
bias_mode, Op, remain_w, 3, oc_block, stride, ow_block> {
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, ow_block> {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -251,32 +247,32 @@ struct KerNeonXXs2NchwNchw44FP32<
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
GI_FLOAT32_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight);
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight);
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight);

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight);
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight);
cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight);

// row 2
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(
src, src_ptr + 2 * iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight);
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight);
@@ -292,7 +288,7 @@ struct KerNeonXXs2NchwNchw44FP32<
#if MEGDNN_ARMV7

template <BiasMode bias_mode, typename Op>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> {
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -310,7 +306,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> {
const int ld_src_ic_skip_bytes =
iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[1][8];
GI_FLOAT32_t c[1][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
const int img_stride = ih * iw;
constexpr int filter_stride = filter_size * filter_size * oc_step;
@@ -464,8 +460,7 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::A7_TAG> {
};

template <BiasMode bias_mode, typename Op>
struct KerNeonXXs2NchwNchw44FP32<
bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU_TAG> {
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU_TAG> {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -483,7 +478,7 @@ struct KerNeonXXs2NchwNchw44FP32<
const int ld_src_ic_skip_bytes =
iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[1][8];
GI_FLOAT32_t c[1][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
/**
* c q8-q15
@@ -626,8 +621,7 @@ struct KerNeonXXs2NchwNchw44FP32<
template <
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride,
int ow_block>
struct KerNeonXXs2NchwNchw44FP32<
bias_mode, Op, remain_w, 2, oc_block, stride, ow_block> {
struct KerGiXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, ow_block> {
static void impl(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw,
@@ -644,22 +638,22 @@ struct KerNeonXXs2NchwNchw44FP32<
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
GI_FLOAT32_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
GI_FLOAT32_t src[src_reg_size];
GI_FLOAT32_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight);
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight);

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
load_helper<src_reg_size, 0, simd_len, 0, Vld1qF32S>(src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1qF32S>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight);
cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight);
@@ -711,9 +705,9 @@ struct ConvDirectFp32NchwNchw44 {
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \
kern_big_oc_remain = KerGiXXs2NchwNchw44FP32< \
bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \
kern_small_oc_remain = KerNeonXXs2NchwNchw44FP32< \
kern_small_oc_remain = KerGiXXs2NchwNchw44FP32< \
bias_mode, Op, step, filter_size, oc_step, stride, ow_step>::impl; \
break;

@@ -731,7 +725,7 @@ struct ConvDirectFp32NchwNchw44 {
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<
KerGiXXs2NchwNchw44FP32<
bias_mode, Op, ow_step, filter_size, big_oc_step, stride,
ow_step>::
impl(src + src_offset, filter + weight_offset,
@@ -760,7 +754,7 @@ struct ConvDirectFp32NchwNchw44 {
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<
KerGiXXs2NchwNchw44FP32<
bias_mode, Op, ow_step, filter_size, oc_step, stride,
ow_step>::
impl(src + src_offset, filter + weight_offset,
@@ -819,7 +813,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> {
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \
kern_big_oc_remain = KerGiXXs2NchwNchw44FP32< \
bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \
break;

@@ -849,7 +843,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> {
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<
KerGiXXs2NchwNchw44FP32<
bias_mode, Op, ow_step, filter_size, big_oc_step,
stride, ow_step, CpuTag::A7_TAG>::
impl(src + src_offset, filter + weight_offset,
@@ -878,7 +872,7 @@ struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> {
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<
KerGiXXs2NchwNchw44FP32<
bias_mode, Op, ow_step, filter_size, big_oc_step,
stride, ow_step>::
impl(src + src_offset, filter + weight_offset,

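The switch (ow_remain) blocks above bind the runtime tail width to the compile-time remain_w template argument, so tail columns run through a fully unrolled kernel specialization instead of per-element branching. A minimal sketch of that dispatch pattern, with stub names standing in for the real kernels:

// Illustrative only: mirrors the cb()/UNROLL_CALL_RAW expansion above.
using KernFn = void (*)(const float* src, const float* weight, const float* bias,
                        float* dst, int ic, int ih, int iw, int ld_dst_oc);
template <int remain_w>
static void ker_stub(const float*, const float*, const float*, float*, int, int, int, int) {
    // stand-in for KerGiXXs2NchwNchw44FP32<..., remain_w, ...>::impl
}
static KernFn select_tail_kernel(int ow_remain) {
    switch (ow_remain) {
        case 1: return ker_stub<1>;
        case 2: return ker_stub<2>;
        case 3: return ker_stub<3>;
        // ... one case per possible tail width, up to ow_step - 1
        default: return nullptr;  // ow_remain == 0: no tail kernel is needed
    }
}
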
+ 723
- 0
dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp View File

@@ -0,0 +1,723 @@
/**
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include <algorithm>

#include "src/fallback/conv_bias/gi/fp32/do_conv_stride1.h"
#include "src/fallback/conv_bias/gi/postprocess_helper.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/general_intrinsic/gi_float.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_conv_bias_f32_convs1)

using namespace megdnn;
using namespace fallback;
using namespace fp32;
using namespace conv_stride1;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;

void conv_stride1::do_conv_2x2_stride1(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
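    // The row pointers advance OW floats per output row; tail_step skips the
    // remaining IW - OW floats to reach the start of the next input row.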
const size_t tail_step = IW - OW;
//! unroll of 2
size_t ic = 0;
for (; ic + 1 < IC; ic += 2) {
const float* src_ptr = src + IW * IH * ic;
const float* src_ptr1 = src_ptr + IW * IH;
float* outptr = dst;

const float* r00 = src_ptr;
const float* r01 = src_ptr + IW;
const float* r10 = src_ptr1;
const float* r11 = src_ptr1 + IW;

const float* k0 = filter + ic * 4;
const float* k1 = k0 + 4;

GI_FLOAT32_t _k0 = GiLoadFloat32(k0);
GI_FLOAT32_t _k1 = GiLoadFloat32(k1);
rep(h, OH) {
int width = OW >> 2;

rep(i, width) {
GI_FLOAT32_t _r000 = GiLoadFloat32(r00);
GI_FLOAT32_t _r010 = GiLoadFloat32(r01);
GI_FLOAT32_t _r001 = GiLoadFloat32(r00 + 1);
GI_FLOAT32_t _r011 = GiLoadFloat32(r01 + 1);

GI_FLOAT32_t _r100 = GiLoadFloat32(r10);
GI_FLOAT32_t _r110 = GiLoadFloat32(r11);
GI_FLOAT32_t _r101 = GiLoadFloat32(r10 + 1);
GI_FLOAT32_t _r111 = GiLoadFloat32(r11 + 1);

GI_FLOAT32_t _sum = GiLoadFloat32(outptr);

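                // Per-lane FMA: LowHalf lane n is assumed to read _k0[n] (the 2x2
                // filter's row-0 taps) and HighHalf lane n to read _k0[2 + n] (the
                // row-1 taps), mirroring vmlaq_lane_f32 on vget_low/vget_high.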
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r000, _k0, 0);
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r001, _k0, 1);
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r010, _k0, 0);
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r011, _k0, 1);

_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r100, _k1, 0);
_sum = GiVmlaqLaneFloat32LowHalf(_sum, _r101, _k1, 1);
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r110, _k1, 0);
_sum = GiMlaqLaneFloat32HighHalf(_sum, _r111, _k1, 1);

GiStoreFloat32(outptr, _sum);

r00 += 4;
r01 += 4;
r10 += 4;
r11 += 4;
outptr += 4;
}

r00 += tail_step;
r01 += tail_step;
r10 += tail_step;
r11 += tail_step;
}
}
for (; ic < IC; ic++) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;

const float* k0 = filter + ic * 4;

GI_FLOAT32_t _k0 = GiBroadcastFloat32(k0[0]);
GI_FLOAT32_t _k1 = GiBroadcastFloat32(k0[1]);
GI_FLOAT32_t _k2 = GiBroadcastFloat32(k0[2]);
GI_FLOAT32_t _k3 = GiBroadcastFloat32(k0[3]);
rep(h, OH) {
int width = OW >> 2;

rep(i, width) {
GI_FLOAT32_t _r00 = GiLoadFloat32(r0);
GI_FLOAT32_t _r10 = GiLoadFloat32(r1);
GI_FLOAT32_t _r01 = GiLoadFloat32(r0 + 1);
GI_FLOAT32_t _r11 = GiLoadFloat32(r1 + 1);

GI_FLOAT32_t _sum = GiLoadFloat32(outptr);
GI_FLOAT32_t _sum2;

_sum = GiMlaqFloat32(_sum, _r00, _k0);
_sum2 = GiMultiplyFloat32(_r01, _k1);
_sum = GiMlaqFloat32(_sum, _r10, _k2);
_sum2 = GiMlaqFloat32(_sum2, _r11, _k3);

_sum = GiAddFloat32(_sum, _sum2);

GiStoreFloat32(outptr, _sum);

r0 += 4;
r1 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
}
}
}

void conv_stride1::do_conv_3x3_stride1(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;
float* outptr2 = outptr + OW;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;

const float* k0 = filter;
const float* k1 = filter + 3;
const float* k2 = filter + 5;

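        // GiExtqFloat32(a, b, n) takes four lanes starting at lane n of a:b
        // (vextq-style). Loading at the overlapping offsets 3 and 5 keeps all
        // nine 3x3 taps lane-addressable without reading past filter[8]; only
        // lanes 0..2 of _k6789 are consumed below.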
GI_FLOAT32_t _k0123 = GiLoadFloat32(k0);
GI_FLOAT32_t _k3456 = GiLoadFloat32(k1);
GI_FLOAT32_t _k5678 = GiLoadFloat32(k2);
GI_FLOAT32_t _k6789 = GiExtqFloat32(_k5678, _k5678, 1);

size_t h = 0;
for (; h + 1 < OH; h += 2) {
int width = OW >> 2;

rep(i, width) {
GI_FLOAT32_t _sum1 = GiLoadFloat32(outptr);
GI_FLOAT32_t _sum2 = GiBroadcastFloat32(0.f);
GI_FLOAT32_t _sum3 = GiLoadFloat32(outptr2);
GI_FLOAT32_t _sum4 = GiBroadcastFloat32(0.f);

GI_FLOAT32_t _r00 = GiLoadFloat32(r0);
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 4);
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r00n, 1);
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r00n, 2);

GI_FLOAT32_t _r10 = GiLoadFloat32(r1);
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 4);
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r10n, 1);
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r10n, 2);

GI_FLOAT32_t _r20 = GiLoadFloat32(r2);
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 4);
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r20n, 1);
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r20n, 2);

GI_FLOAT32_t _r30 = GiLoadFloat32(r3);
GI_FLOAT32_t _r30n = GiLoadFloat32LowHalf(r3 + 4);
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r30n, 1);
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r30n, 2);

_sum1 = GiSimdFmaLane(_sum1, _r00, _k0123, 0);
_sum2 = GiSimdFmaLane(_sum2, _r01, _k0123, 1);
_sum1 = GiSimdFmaLane(_sum1, _r02, _k0123, 2);
_sum2 = GiSimdFmaLane(_sum2, _r10, _k3456, 0);
_sum1 = GiSimdFmaLane(_sum1, _r11, _k3456, 1);
_sum2 = GiSimdFmaLane(_sum2, _r12, _k3456, 2);
_sum1 = GiSimdFmaLane(_sum1, _r20, _k6789, 0);
_sum2 = GiSimdFmaLane(_sum2, _r21, _k6789, 1);
_sum1 = GiSimdFmaLane(_sum1, _r22, _k6789, 2);

_sum3 = GiSimdFmaLane(_sum3, _r10, _k0123, 0);
_sum4 = GiSimdFmaLane(_sum4, _r11, _k0123, 1);
_sum3 = GiSimdFmaLane(_sum3, _r12, _k0123, 2);
_sum4 = GiSimdFmaLane(_sum4, _r20, _k3456, 0);
_sum3 = GiSimdFmaLane(_sum3, _r21, _k3456, 1);
_sum4 = GiSimdFmaLane(_sum4, _r22, _k3456, 2);
_sum3 = GiSimdFmaLane(_sum3, _r30, _k6789, 0);
_sum4 = GiSimdFmaLane(_sum4, _r31, _k6789, 1);
_sum3 = GiSimdFmaLane(_sum3, _r32, _k6789, 2);

_sum1 = GiAddFloat32(_sum1, _sum2);
_sum3 = GiAddFloat32(_sum3, _sum4);

GiStoreFloat32(outptr, _sum1);
GiStoreFloat32(outptr2, _sum3);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
outptr += 4;
outptr2 += 4;
}

r0 += tail_step + IW;
r1 += tail_step + IW;
r2 += tail_step + IW;
r3 += tail_step + IW;

outptr += OW;
outptr2 += OW;
}

for (; h < OH; h++) {
int width = OW >> 2;

rep(i, width) {
GI_FLOAT32_t _sum1 = GiLoadFloat32(outptr);
GI_FLOAT32_t _sum2 = GiBroadcastFloat32(0.f);

GI_FLOAT32_t _r00 = GiLoadFloat32(r0);
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 4);
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r00n, 1);
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r00n, 2);

GI_FLOAT32_t _r10 = GiLoadFloat32(r1);
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 4);
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r10n, 1);
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r10n, 2);

GI_FLOAT32_t _r20 = GiLoadFloat32(r2);
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 4);
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r20n, 1);
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r20n, 2);

_sum1 = GiSimdFmaLane(_sum1, _r00, _k0123, 0);
_sum2 = GiSimdFmaLane(_sum2, _r01, _k0123, 1);
_sum1 = GiSimdFmaLane(_sum1, _r02, _k0123, 2);
_sum2 = GiSimdFmaLane(_sum2, _r10, _k3456, 0);
_sum1 = GiSimdFmaLane(_sum1, _r11, _k3456, 1);
_sum2 = GiSimdFmaLane(_sum2, _r12, _k3456, 2);
_sum1 = GiSimdFmaLane(_sum1, _r20, _k6789, 0);
_sum2 = GiSimdFmaLane(_sum2, _r21, _k6789, 1);
_sum1 = GiSimdFmaLane(_sum1, _r22, _k6789, 2);

_sum1 = GiAddFloat32(_sum1, _sum2);

GiStoreFloat32(outptr, _sum1);

r0 += 4;
r1 += 4;
r2 += 4;
outptr += 4;
}
r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
}

filter += 9;
}
}

void conv_stride1::do_conv_5x5_stride1(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;
float* outptr2 = outptr + OW;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;

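        // Preload all 25 taps of the 5x5 filter into seven vector registers; the
        // final tap is broadcast so every multiply below is a lane-addressed FMA.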
GI_FLOAT32_t _k0123 = GiLoadFloat32(filter);
GI_FLOAT32_t _k4567 = GiLoadFloat32(filter + 4);
GI_FLOAT32_t _k891011 = GiLoadFloat32(filter + 8);
GI_FLOAT32_t _k12131415 = GiLoadFloat32(filter + 12);
GI_FLOAT32_t _k16171819 = GiLoadFloat32(filter + 16);
GI_FLOAT32_t _k20212223 = GiLoadFloat32(filter + 20);
GI_FLOAT32_t _k24242424 = GiBroadcastFloat32(filter[24]);

size_t h = 0;
for (; h + 1 < OH; h += 2) {
int width = OW >> 2;

rep(i, width) {
GI_FLOAT32_t _sum = GiLoadFloat32(outptr);
GI_FLOAT32_t _sum2 = GiLoadFloat32(outptr2);

GI_FLOAT32_t _r00 = GiLoadFloat32(r0);
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4);
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1);
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2);
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3);

GI_FLOAT32_t _r10 = GiLoadFloat32(r1);
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4);
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1);
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2);
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3);

GI_FLOAT32_t _r20 = GiLoadFloat32(r2);
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4);
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1);
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2);
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3);

GI_FLOAT32_t _r30 = GiLoadFloat32(r3);
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4);
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1);
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2);
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3);

GI_FLOAT32_t _r40 = GiLoadFloat32(r4);
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4);
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1);
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2);
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3);

GI_FLOAT32_t _r50 = GiLoadFloat32(r5);
GI_FLOAT32_t _r54 = GiLoadFloat32(r5 + 4);
GI_FLOAT32_t _r51 = GiExtqFloat32(_r50, _r54, 1);
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r54, 2);
GI_FLOAT32_t _r53 = GiExtqFloat32(_r50, _r54, 3);

_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0);
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1);
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2);
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3);
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0);

_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1);
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2);
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3);
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0);
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1);

_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2);
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3);
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0);
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1);
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2);

_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3);
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0);
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1);
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2);
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3);

_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0);
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1);
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2);
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3);
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0);

_sum2 = GiSimdFmaLane(_sum2, _r10, _k0123, 0);
_sum2 = GiSimdFmaLane(_sum2, _r11, _k0123, 1);
_sum2 = GiSimdFmaLane(_sum2, _r12, _k0123, 2);
_sum2 = GiSimdFmaLane(_sum2, _r13, _k0123, 3);
_sum2 = GiSimdFmaLane(_sum2, _r14, _k4567, 0);

_sum2 = GiSimdFmaLane(_sum2, _r20, _k4567, 1);
_sum2 = GiSimdFmaLane(_sum2, _r21, _k4567, 2);
_sum2 = GiSimdFmaLane(_sum2, _r22, _k4567, 3);
_sum2 = GiSimdFmaLane(_sum2, _r23, _k891011, 0);
_sum2 = GiSimdFmaLane(_sum2, _r24, _k891011, 1);

_sum2 = GiSimdFmaLane(_sum2, _r30, _k891011, 2);
_sum2 = GiSimdFmaLane(_sum2, _r31, _k891011, 3);
_sum2 = GiSimdFmaLane(_sum2, _r32, _k12131415, 0);
_sum2 = GiSimdFmaLane(_sum2, _r33, _k12131415, 1);
_sum2 = GiSimdFmaLane(_sum2, _r34, _k12131415, 2);

_sum2 = GiSimdFmaLane(_sum2, _r40, _k12131415, 3);
_sum2 = GiSimdFmaLane(_sum2, _r41, _k16171819, 0);
_sum2 = GiSimdFmaLane(_sum2, _r42, _k16171819, 1);
_sum2 = GiSimdFmaLane(_sum2, _r43, _k16171819, 2);
_sum2 = GiSimdFmaLane(_sum2, _r44, _k16171819, 3);

_sum2 = GiSimdFmaLane(_sum2, _r50, _k20212223, 0);
_sum2 = GiSimdFmaLane(_sum2, _r51, _k20212223, 1);
_sum2 = GiSimdFmaLane(_sum2, _r52, _k20212223, 2);
_sum2 = GiSimdFmaLane(_sum2, _r53, _k20212223, 3);
_sum2 = GiSimdFmaLane(_sum2, _r54, _k24242424, 0);

GiStoreFloat32(outptr, _sum);
GiStoreFloat32(outptr2, _sum2);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
r5 += 4;
outptr += 4;
outptr2 += 4;
}

r0 += tail_step + IW;
r1 += tail_step + IW;
r2 += tail_step + IW;
r3 += tail_step + IW;
r4 += tail_step + IW;
r5 += tail_step + IW;

outptr += OW;
outptr2 += OW;
}

for (; h < OH; h++) {
int width = OW >> 2;

rep(i, width) {
GI_FLOAT32_t _sum = GiLoadFloat32(outptr);

GI_FLOAT32_t _r00 = GiLoadFloat32(r0);
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4);
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1);
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2);
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3);

GI_FLOAT32_t _r10 = GiLoadFloat32(r1);
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4);
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1);
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2);
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3);

GI_FLOAT32_t _r20 = GiLoadFloat32(r2);
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4);
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1);
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2);
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3);

GI_FLOAT32_t _r30 = GiLoadFloat32(r3);
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4);
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1);
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2);
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3);

GI_FLOAT32_t _r40 = GiLoadFloat32(r4);
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4);
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1);
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2);
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3);

_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0);
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1);
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2);
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3);
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0);

_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1);
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2);
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3);
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0);
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1);

_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2);
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3);
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0);
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1);
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2);

_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3);
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0);
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1);
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2);
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3);

_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0);
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1);
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2);
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3);
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0);

GiStoreFloat32(outptr, _sum);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
}

filter += 25;
}
}

void conv_stride1::do_conv_7x7_stride1(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - OW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;
const float* r6 = src_ptr + IW * 6;

const float* k0 = filter;
const float* k1 = filter + 7;
const float* k2 = filter + 14;
const float* k3 = filter + 21;
const float* k4 = filter + 28;
const float* k5 = filter + 35;
const float* k6 = filter + 42;

for (size_t i = 0; i < OH; i++) {
int width = OW >> 2;

rep(i, width) {
GI_FLOAT32_t _sum = GiLoadFloat32(outptr);

GI_FLOAT32_t _k0123 = GiLoadFloat32(k0);
GI_FLOAT32_t _k4567 = GiLoadFloat32(k0 + 4);

GI_FLOAT32_t _r00 = GiLoadFloat32(r0); // 0 1 2 3
GI_FLOAT32_t _r04 = GiLoadFloat32(r0 + 4); // 4 5 6 7
GI_FLOAT32_t _r00n = GiLoadFloat32(r0 + 8); // 8 9 10 11
GI_FLOAT32_t _r01 = GiExtqFloat32(_r00, _r04, 1); // 1 2 3 4
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r04, 2); // 2 3 4 5
GI_FLOAT32_t _r03 = GiExtqFloat32(_r00, _r04, 3); // 3 4 5 6
GI_FLOAT32_t _r05 = GiExtqFloat32(_r04, _r00n, 1); // 5 6 7 8
GI_FLOAT32_t _r06 = GiExtqFloat32(_r04, _r00n, 2); // 6 7 8 9

_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0);
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1);
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2);
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3);
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0);
_sum = GiSimdFmaLane(_sum, _r05, _k4567, 1);
_sum = GiSimdFmaLane(_sum, _r06, _k4567, 2);

GI_FLOAT32_t _k78910 = GiLoadFloat32(k1);
GI_FLOAT32_t _k11121314 = GiLoadFloat32(k1 + 4);

GI_FLOAT32_t _r10 = GiLoadFloat32(r1);
GI_FLOAT32_t _r14 = GiLoadFloat32(r1 + 4);
GI_FLOAT32_t _r10n = GiLoadFloat32(r1 + 8);
GI_FLOAT32_t _r11 = GiExtqFloat32(_r10, _r14, 1);
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r14, 2);
GI_FLOAT32_t _r13 = GiExtqFloat32(_r10, _r14, 3);
GI_FLOAT32_t _r15 = GiExtqFloat32(_r14, _r10n, 1);
GI_FLOAT32_t _r16 = GiExtqFloat32(_r14, _r10n, 2);

_sum = GiSimdFmaLane(_sum, _r10, _k78910, 0);
_sum = GiSimdFmaLane(_sum, _r11, _k78910, 1);
_sum = GiSimdFmaLane(_sum, _r12, _k78910, 2);
_sum = GiSimdFmaLane(_sum, _r13, _k78910, 3);
_sum = GiSimdFmaLane(_sum, _r14, _k11121314, 0);
_sum = GiSimdFmaLane(_sum, _r15, _k11121314, 1);
_sum = GiSimdFmaLane(_sum, _r16, _k11121314, 2);

GI_FLOAT32_t _k14151617 = GiLoadFloat32(k2);
GI_FLOAT32_t _k18192021 = GiLoadFloat32(k2 + 4);

GI_FLOAT32_t _r20 = GiLoadFloat32(r2);
GI_FLOAT32_t _r24 = GiLoadFloat32(r2 + 4);
GI_FLOAT32_t _r20n = GiLoadFloat32(r2 + 8);
GI_FLOAT32_t _r21 = GiExtqFloat32(_r20, _r24, 1);
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r24, 2);
GI_FLOAT32_t _r23 = GiExtqFloat32(_r20, _r24, 3);
GI_FLOAT32_t _r25 = GiExtqFloat32(_r24, _r20n, 1);
GI_FLOAT32_t _r26 = GiExtqFloat32(_r24, _r20n, 2);

_sum = GiSimdFmaLane(_sum, _r20, _k14151617, 0);
_sum = GiSimdFmaLane(_sum, _r21, _k14151617, 1);
_sum = GiSimdFmaLane(_sum, _r22, _k14151617, 2);
_sum = GiSimdFmaLane(_sum, _r23, _k14151617, 3);
_sum = GiSimdFmaLane(_sum, _r24, _k18192021, 0);
_sum = GiSimdFmaLane(_sum, _r25, _k18192021, 1);
_sum = GiSimdFmaLane(_sum, _r26, _k18192021, 2);

GI_FLOAT32_t _k21222324 = GiLoadFloat32(k3);
GI_FLOAT32_t _k25262728 = GiLoadFloat32(k3 + 4);

GI_FLOAT32_t _r30 = GiLoadFloat32(r3);
GI_FLOAT32_t _r34 = GiLoadFloat32(r3 + 4);
GI_FLOAT32_t _r30n = GiLoadFloat32(r3 + 8);
GI_FLOAT32_t _r31 = GiExtqFloat32(_r30, _r34, 1);
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r34, 2);
GI_FLOAT32_t _r33 = GiExtqFloat32(_r30, _r34, 3);
GI_FLOAT32_t _r35 = GiExtqFloat32(_r34, _r30n, 1);
GI_FLOAT32_t _r36 = GiExtqFloat32(_r34, _r30n, 2);

_sum = GiSimdFmaLane(_sum, _r30, _k21222324, 0);
_sum = GiSimdFmaLane(_sum, _r31, _k21222324, 1);
_sum = GiSimdFmaLane(_sum, _r32, _k21222324, 2);
_sum = GiSimdFmaLane(_sum, _r33, _k21222324, 3);
_sum = GiSimdFmaLane(_sum, _r34, _k25262728, 0);
_sum = GiSimdFmaLane(_sum, _r35, _k25262728, 1);
_sum = GiSimdFmaLane(_sum, _r36, _k25262728, 2);

GI_FLOAT32_t _k28293031 = GiLoadFloat32(k4);
GI_FLOAT32_t _k32333435 = GiLoadFloat32(k4 + 4);

GI_FLOAT32_t _r40 = GiLoadFloat32(r4);
GI_FLOAT32_t _r44 = GiLoadFloat32(r4 + 4);
GI_FLOAT32_t _r40n = GiLoadFloat32(r4 + 8);
GI_FLOAT32_t _r41 = GiExtqFloat32(_r40, _r44, 1);
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r44, 2);
GI_FLOAT32_t _r43 = GiExtqFloat32(_r40, _r44, 3);
GI_FLOAT32_t _r45 = GiExtqFloat32(_r44, _r40n, 1);
GI_FLOAT32_t _r46 = GiExtqFloat32(_r44, _r40n, 2);

_sum = GiSimdFmaLane(_sum, _r40, _k28293031, 0);
_sum = GiSimdFmaLane(_sum, _r41, _k28293031, 1);
_sum = GiSimdFmaLane(_sum, _r42, _k28293031, 2);
_sum = GiSimdFmaLane(_sum, _r43, _k28293031, 3);
_sum = GiSimdFmaLane(_sum, _r44, _k32333435, 0);
_sum = GiSimdFmaLane(_sum, _r45, _k32333435, 1);
_sum = GiSimdFmaLane(_sum, _r46, _k32333435, 2);

GI_FLOAT32_t _k35363738 = GiLoadFloat32(k5);
GI_FLOAT32_t _k39404142 = GiLoadFloat32(k5 + 4);

GI_FLOAT32_t _r50 = GiLoadFloat32(r5);
GI_FLOAT32_t _r54 = GiLoadFloat32(r5 + 4);
GI_FLOAT32_t _r50n = GiLoadFloat32(r5 + 8);
GI_FLOAT32_t _r51 = GiExtqFloat32(_r50, _r54, 1);
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r54, 2);
GI_FLOAT32_t _r53 = GiExtqFloat32(_r50, _r54, 3);
GI_FLOAT32_t _r55 = GiExtqFloat32(_r54, _r50n, 1);
GI_FLOAT32_t _r56 = GiExtqFloat32(_r54, _r50n, 2);

_sum = GiSimdFmaLane(_sum, _r50, _k35363738, 0);
_sum = GiSimdFmaLane(_sum, _r51, _k35363738, 1);
_sum = GiSimdFmaLane(_sum, _r52, _k35363738, 2);
_sum = GiSimdFmaLane(_sum, _r53, _k35363738, 3);
_sum = GiSimdFmaLane(_sum, _r54, _k39404142, 0);
_sum = GiSimdFmaLane(_sum, _r55, _k39404142, 1);
_sum = GiSimdFmaLane(_sum, _r56, _k39404142, 2);

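                // The last filter row holds taps 42..48 only: splice k46..k48 from
                // a low-half load plus a single-lane insert so the read never goes
                // past filter[48]; only lanes 0..2 of _k46474849 are consumed.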
GI_FLOAT32_t _k42434445 = GiLoadFloat32(k6);
GI_FLOAT32_t _k46474849 =
GiLd1qLaneFloat32(k6 + 4 + 2, GiLoadFloat32LowHalf(k6 + 4), 2);

GI_FLOAT32_t _r60 = GiLoadFloat32(r6);
GI_FLOAT32_t _r64 = GiLoadFloat32(r6 + 4);
GI_FLOAT32_t _r60n = GiLoadFloat32(r6 + 8);
GI_FLOAT32_t _r61 = GiExtqFloat32(_r60, _r64, 1);
GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r64, 2);
GI_FLOAT32_t _r63 = GiExtqFloat32(_r60, _r64, 3);
GI_FLOAT32_t _r65 = GiExtqFloat32(_r64, _r60n, 1);
GI_FLOAT32_t _r66 = GiExtqFloat32(_r64, _r60n, 2);

_sum = GiSimdFmaLane(_sum, _r60, _k42434445, 0);
_sum = GiSimdFmaLane(_sum, _r61, _k42434445, 1);
_sum = GiSimdFmaLane(_sum, _r62, _k42434445, 2);
_sum = GiSimdFmaLane(_sum, _r63, _k42434445, 3);
_sum = GiSimdFmaLane(_sum, _r64, _k46474849, 0);
_sum = GiSimdFmaLane(_sum, _r65, _k46474849, 1);
_sum = GiSimdFmaLane(_sum, _r66, _k46474849, 2);

GiStoreFloat32(outptr, _sum);

r0 += 4;
r1 += 4;
r2 += 4;
r3 += 4;
r4 += 4;
r5 += 4;
r6 += 4;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
r5 += tail_step;
r6 += tail_step;
}
filter += 49;
}
}

#include "src/common/simd_macro/epilogue.h"
// vim: syntax=cpp.doxygen

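Note that every kernel in this file accumulates into dst (each inner loop starts from GiLoadFloat32(outptr)), so the caller must seed the output with zeros or the bias beforehand. A minimal sketch of that contract for one output channel, assuming OW is a multiple of 4 as the vector loop requires; the driver name is hypothetical:

// Hypothetical driver: seed dst, then let the kernel accumulate across IC.
static void run_conv_3x3_stride1_one_oc(
        const float* src, const float* filter, float bias, float* dst,
        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
    for (size_t i = 0; i < OH * OW; ++i) {
        dst[i] = bias;  // the kernels read-modify-write the output
    }
    megdnn::fallback::fp32::conv_stride1::do_conv_3x3_stride1(
            src, filter, dst, IH, IW, OH, OW, IC);
}
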
dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h → dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h View File

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride1.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -13,7 +13,7 @@
#include <cstddef>

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace fp32 {
namespace conv_stride1 {

@@ -31,7 +31,7 @@ void do_conv_7x7_stride1(
size_t OH, size_t OW, size_t IC);
} // namespace conv_stride1
} // namespace fp32
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 503
- 0
dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp View File

@@ -0,0 +1,503 @@
/**
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include <algorithm>

#include "./do_conv_stride2.h"
#include "midout.h"
#include "src/fallback/conv_bias/gi/postprocess_helper.h"
#include "src/fallback/general_intrinsic/gi_float.h"

MIDOUT_DECL(megdnn_fallback_conv_bias_f32_convs2)

using namespace megdnn;
using namespace fallback;
using namespace fp32;
using namespace conv_stride2;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;

void conv_stride2::do_conv_2x2_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
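    // Stride 2: one output row consumes 2 * OW input columns and the next output
    // row starts two input rows down, hence tail_step = 2 * IW - 2 * OW below.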
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;

const float* k0 = filter;

GI_FLOAT32_t _k0123 = GiLoadFloat32(k0);
rep(h, OH) {
int nn = OW >> 2;

rep(i, nn) {
GI_FLOAT32_t _outp = GiLoadFloat32(outptr);

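                // GiLd2qFloat32 de-interleaves eight floats: val[0] holds the even
                // columns and val[1] the odd ones, i.e. exactly the stride-2 taps.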
GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0);

GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6
GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7

_outp = GiSimdFmaLane(_outp, _r00, _k0123, 0);
_outp = GiSimdFmaLane(_outp, _r01, _k0123, 1);

GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1);

GI_FLOAT32_t _r10 = _r1.val[0];
GI_FLOAT32_t _r11 = _r1.val[1];

_outp = GiSimdFmaLane(_outp, _r10, _k0123, 2);
_outp = GiSimdFmaLane(_outp, _r11, _k0123, 3);

GiStoreFloat32(outptr, _outp);

r0 += 8;
r1 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
}

filter += 4;
}
}

void conv_stride2::do_conv_3x3_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;

const float* k0 = filter;
const float* k1 = filter + 3;
const float* k2 = filter + 5;

GI_FLOAT32_t _k0123 = GiLoadFloat32(k0);
GI_FLOAT32_t _k3456 = GiLoadFloat32(k1);
GI_FLOAT32_t _k5678 = GiLoadFloat32(k2);
GI_FLOAT32_t _k6789 = GiExtqFloat32(_k5678, _k5678, 1);
rep(h, OH) {
int nn = OW >> 2;

rep(i, nn) {
GI_FLOAT32_t _outp = GiLoadFloat32(outptr);

GI_FLOAT32_V2_t _r0 = GiLd2qFloat32(r0);
GI_FLOAT32_V2_t _r0n = GiLd2qFloat32(r0 + 8);

GI_FLOAT32_t _r00 = _r0.val[0]; // 0 2 4 6
GI_FLOAT32_t _r01 = _r0.val[1]; // 1 3 5 7
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0n.val[0], 1); // 2 4 6 8

_outp = GiSimdFmaLane(_outp, _r00, _k0123, 0);
_outp = GiSimdFmaLane(_outp, _r01, _k0123, 1);
_outp = GiSimdFmaLane(_outp, _r02, _k0123, 2);

GI_FLOAT32_V2_t _r1 = GiLd2qFloat32(r1);
GI_FLOAT32_V2_t _r1n = GiLd2qFloat32(r1 + 8);

GI_FLOAT32_t _r10 = _r1.val[0];
GI_FLOAT32_t _r11 = _r1.val[1];
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1n.val[0], 1);

_outp = GiSimdFmaLane(_outp, _r10, _k3456, 0);
_outp = GiSimdFmaLane(_outp, _r11, _k3456, 1);
_outp = GiSimdFmaLane(_outp, _r12, _k3456, 2);

GI_FLOAT32_V2_t _r2 = GiLd2qFloat32(r2);
GI_FLOAT32_V2_t _r2n = GiLd2qFloat32(r2 + 8);

GI_FLOAT32_t _r20 = _r2.val[0];
GI_FLOAT32_t _r21 = _r2.val[1];
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2n.val[0], 1);

_outp = GiSimdFmaLane(_outp, _r20, _k6789, 0);
_outp = GiSimdFmaLane(_outp, _r21, _k6789, 1);
_outp = GiSimdFmaLane(_outp, _r22, _k6789, 2);

GiStoreFloat32(outptr, _outp);

r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
}

filter += 9;
}
}

void conv_stride2::do_conv_5x5_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;

GI_FLOAT32_t _k0123 = GiLoadFloat32(filter);
GI_FLOAT32_t _k4567 = GiLoadFloat32(filter + 4);
GI_FLOAT32_t _k891011 = GiLoadFloat32(filter + 8);
GI_FLOAT32_t _k12131415 = GiLoadFloat32(filter + 12);
GI_FLOAT32_t _k16171819 = GiLoadFloat32(filter + 16);
GI_FLOAT32_t _k20212223 = GiLoadFloat32(filter + 20);
GI_FLOAT32_t _k24242424 = GiBroadcastFloat32(filter[24]);

for (size_t i = 0; i < OH; i++) {
int nn = OW >> 2;

rep(i, nn) {
GI_FLOAT32_t _sum = GiLoadFloat32(outptr);

GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0);
GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8);
GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14
GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15
GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6
GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8
GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9
GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10

GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1);
GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8);
GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0];
GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1];
GI_FLOAT32_t _r10 = _r10_02461357.val[0];
GI_FLOAT32_t _r11 = _r10_02461357.val[1];
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1);
GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1);
GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2);

GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2);
GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8);
GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0];
GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1];
GI_FLOAT32_t _r20 = _r20_02461357.val[0];
GI_FLOAT32_t _r21 = _r20_02461357.val[1];
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1);
GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1);
GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2);

GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3);
GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8);
GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0];
GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1];
GI_FLOAT32_t _r30 = _r30_02461357.val[0];
GI_FLOAT32_t _r31 = _r30_02461357.val[1];
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1);
GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1);
GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2);

GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4);
GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8);
GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0];
GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1];
GI_FLOAT32_t _r40 = _r40_02461357.val[0];
GI_FLOAT32_t _r41 = _r40_02461357.val[1];
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1);
GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1);
GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2);

_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0);
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1);
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2);
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3);
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0);

_sum = GiSimdFmaLane(_sum, _r10, _k4567, 1);
_sum = GiSimdFmaLane(_sum, _r11, _k4567, 2);
_sum = GiSimdFmaLane(_sum, _r12, _k4567, 3);
_sum = GiSimdFmaLane(_sum, _r13, _k891011, 0);
_sum = GiSimdFmaLane(_sum, _r14, _k891011, 1);

_sum = GiSimdFmaLane(_sum, _r20, _k891011, 2);
_sum = GiSimdFmaLane(_sum, _r21, _k891011, 3);
_sum = GiSimdFmaLane(_sum, _r22, _k12131415, 0);
_sum = GiSimdFmaLane(_sum, _r23, _k12131415, 1);
_sum = GiSimdFmaLane(_sum, _r24, _k12131415, 2);

_sum = GiSimdFmaLane(_sum, _r30, _k12131415, 3);
_sum = GiSimdFmaLane(_sum, _r31, _k16171819, 0);
_sum = GiSimdFmaLane(_sum, _r32, _k16171819, 1);
_sum = GiSimdFmaLane(_sum, _r33, _k16171819, 2);
_sum = GiSimdFmaLane(_sum, _r34, _k16171819, 3);

_sum = GiSimdFmaLane(_sum, _r40, _k20212223, 0);
_sum = GiSimdFmaLane(_sum, _r41, _k20212223, 1);
_sum = GiSimdFmaLane(_sum, _r42, _k20212223, 2);
_sum = GiSimdFmaLane(_sum, _r43, _k20212223, 3);
_sum = GiSimdFmaLane(_sum, _r44, _k24242424, 0);

GiStoreFloat32(outptr, _sum);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
r4 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
}

filter += 25;
}
}

void conv_stride2::do_conv_7x7_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;

rep(ic, IC) {
const float* src_ptr = src + IW * IH * ic;
float* outptr = dst;

const float* r0 = src_ptr;
const float* r1 = src_ptr + IW;
const float* r2 = src_ptr + IW * 2;
const float* r3 = src_ptr + IW * 3;
const float* r4 = src_ptr + IW * 4;
const float* r5 = src_ptr + IW * 5;
const float* r6 = src_ptr + IW * 6;

const float* k0 = filter;
const float* k1 = filter + 7;
const float* k2 = filter + 14;
const float* k3 = filter + 21;
const float* k4 = filter + 28;
const float* k5 = filter + 35;
const float* k6 = filter + 42;

for (size_t i = 0; i < OH; i++) {
int nn = OW >> 2;

rep(i, nn) {
GI_FLOAT32_t _sum = GiLoadFloat32(outptr);

GI_FLOAT32_t _k0123 = GiLoadFloat32(k0);
GI_FLOAT32_t _k4567 = GiLoadFloat32(k0 + 4);

GI_FLOAT32_V2_t _r00_02461357 = GiLd2qFloat32(r0);
GI_FLOAT32_V2_t _r00nx2 = GiLd2qFloat32(r0 + 8);
GI_FLOAT32_t _r0_8101214 = _r00nx2.val[0]; // 8 10 12 14
GI_FLOAT32_t _r0_9111315 = _r00nx2.val[1]; // 9 11 13 15
GI_FLOAT32_t _r00 = _r00_02461357.val[0]; // 0 2 4 6
GI_FLOAT32_t _r01 = _r00_02461357.val[1]; // 1 3 5 7
GI_FLOAT32_t _r02 = GiExtqFloat32(_r00, _r0_8101214, 1); // 2 4 6 8
GI_FLOAT32_t _r03 = GiExtqFloat32(_r01, _r0_9111315, 1); // 3 5 7 9
GI_FLOAT32_t _r04 = GiExtqFloat32(_r00, _r0_8101214, 2); // 4 6 8 10
GI_FLOAT32_t _r05 = GiExtqFloat32(_r01, _r0_9111315, 2); // 5 7 9 11
GI_FLOAT32_t _r06 = GiExtqFloat32(_r00, _r0_8101214, 3); // 6 8 10 12

_sum = GiSimdFmaLane(_sum, _r00, _k0123, 0);
_sum = GiSimdFmaLane(_sum, _r01, _k0123, 1);
_sum = GiSimdFmaLane(_sum, _r02, _k0123, 2);
_sum = GiSimdFmaLane(_sum, _r03, _k0123, 3);
_sum = GiSimdFmaLane(_sum, _r04, _k4567, 0);
_sum = GiSimdFmaLane(_sum, _r05, _k4567, 1);
_sum = GiSimdFmaLane(_sum, _r06, _k4567, 2);

GI_FLOAT32_t _k78910 = GiLoadFloat32(k1);
GI_FLOAT32_t _k11121314 = GiLoadFloat32(k1 + 4);

GI_FLOAT32_V2_t _r10_02461357 = GiLd2qFloat32(r1);
GI_FLOAT32_V2_t _r10nx2 = GiLd2qFloat32(r1 + 8);
GI_FLOAT32_t _r1_8101214 = _r10nx2.val[0];
GI_FLOAT32_t _r1_9111315 = _r10nx2.val[1];
GI_FLOAT32_t _r10 = _r10_02461357.val[0];
GI_FLOAT32_t _r11 = _r10_02461357.val[1];
GI_FLOAT32_t _r12 = GiExtqFloat32(_r10, _r1_8101214, 1);
GI_FLOAT32_t _r13 = GiExtqFloat32(_r11, _r1_9111315, 1);
GI_FLOAT32_t _r14 = GiExtqFloat32(_r10, _r1_8101214, 2);
GI_FLOAT32_t _r15 = GiExtqFloat32(_r11, _r1_9111315, 2);
GI_FLOAT32_t _r16 = GiExtqFloat32(_r10, _r1_8101214, 3);

_sum = GiSimdFmaLane(_sum, _r10, _k78910, 0);
_sum = GiSimdFmaLane(_sum, _r11, _k78910, 1);
_sum = GiSimdFmaLane(_sum, _r12, _k78910, 2);
_sum = GiSimdFmaLane(_sum, _r13, _k78910, 3);
_sum = GiSimdFmaLane(_sum, _r14, _k11121314, 0);
_sum = GiSimdFmaLane(_sum, _r15, _k11121314, 1);
_sum = GiSimdFmaLane(_sum, _r16, _k11121314, 2);

GI_FLOAT32_t _k14151617 = GiLoadFloat32(k2);
GI_FLOAT32_t _k18192021 = GiLoadFloat32(k2 + 4);

GI_FLOAT32_V2_t _r20_02461357 = GiLd2qFloat32(r2);
GI_FLOAT32_V2_t _r20nx2 = GiLd2qFloat32(r2 + 8);
GI_FLOAT32_t _r2_8101214 = _r20nx2.val[0];
GI_FLOAT32_t _r2_9111315 = _r20nx2.val[1];
GI_FLOAT32_t _r20 = _r20_02461357.val[0];
GI_FLOAT32_t _r21 = _r20_02461357.val[1];
GI_FLOAT32_t _r22 = GiExtqFloat32(_r20, _r2_8101214, 1);
GI_FLOAT32_t _r23 = GiExtqFloat32(_r21, _r2_9111315, 1);
GI_FLOAT32_t _r24 = GiExtqFloat32(_r20, _r2_8101214, 2);
GI_FLOAT32_t _r25 = GiExtqFloat32(_r21, _r2_9111315, 2);
GI_FLOAT32_t _r26 = GiExtqFloat32(_r20, _r2_8101214, 3);

_sum = GiSimdFmaLane(_sum, _r20, _k14151617, 0);
_sum = GiSimdFmaLane(_sum, _r21, _k14151617, 1);
_sum = GiSimdFmaLane(_sum, _r22, _k14151617, 2);
_sum = GiSimdFmaLane(_sum, _r23, _k14151617, 3);
_sum = GiSimdFmaLane(_sum, _r24, _k18192021, 0);
_sum = GiSimdFmaLane(_sum, _r25, _k18192021, 1);
_sum = GiSimdFmaLane(_sum, _r26, _k18192021, 2);

GI_FLOAT32_t _k21222324 = GiLoadFloat32(k3);
GI_FLOAT32_t _k25262728 = GiLoadFloat32(k3 + 4);

GI_FLOAT32_V2_t _r30_02461357 = GiLd2qFloat32(r3);
GI_FLOAT32_V2_t _r30nx2 = GiLd2qFloat32(r3 + 8);
GI_FLOAT32_t _r3_8101214 = _r30nx2.val[0];
GI_FLOAT32_t _r3_9111315 = _r30nx2.val[1];
GI_FLOAT32_t _r30 = _r30_02461357.val[0];
GI_FLOAT32_t _r31 = _r30_02461357.val[1];
GI_FLOAT32_t _r32 = GiExtqFloat32(_r30, _r3_8101214, 1);
GI_FLOAT32_t _r33 = GiExtqFloat32(_r31, _r3_9111315, 1);
GI_FLOAT32_t _r34 = GiExtqFloat32(_r30, _r3_8101214, 2);
GI_FLOAT32_t _r35 = GiExtqFloat32(_r31, _r3_9111315, 2);
GI_FLOAT32_t _r36 = GiExtqFloat32(_r30, _r3_8101214, 3);

_sum = GiSimdFmaLane(_sum, _r30, _k21222324, 0);
_sum = GiSimdFmaLane(_sum, _r31, _k21222324, 1);
_sum = GiSimdFmaLane(_sum, _r32, _k21222324, 2);
_sum = GiSimdFmaLane(_sum, _r33, _k21222324, 3);
_sum = GiSimdFmaLane(_sum, _r34, _k25262728, 0);
_sum = GiSimdFmaLane(_sum, _r35, _k25262728, 1);
_sum = GiSimdFmaLane(_sum, _r36, _k25262728, 2);

GI_FLOAT32_t _k28293031 = GiLoadFloat32(k4);
GI_FLOAT32_t _k32333435 = GiLoadFloat32(k4 + 4);

GI_FLOAT32_V2_t _r40_02461357 = GiLd2qFloat32(r4);
GI_FLOAT32_V2_t _r40nx2 = GiLd2qFloat32(r4 + 8);
GI_FLOAT32_t _r4_8101214 = _r40nx2.val[0];
GI_FLOAT32_t _r4_9111315 = _r40nx2.val[1];
GI_FLOAT32_t _r40 = _r40_02461357.val[0];
GI_FLOAT32_t _r41 = _r40_02461357.val[1];
GI_FLOAT32_t _r42 = GiExtqFloat32(_r40, _r4_8101214, 1);
GI_FLOAT32_t _r43 = GiExtqFloat32(_r41, _r4_9111315, 1);
GI_FLOAT32_t _r44 = GiExtqFloat32(_r40, _r4_8101214, 2);
GI_FLOAT32_t _r45 = GiExtqFloat32(_r41, _r4_9111315, 2);
GI_FLOAT32_t _r46 = GiExtqFloat32(_r40, _r4_8101214, 3);

_sum = GiSimdFmaLane(_sum, _r40, _k28293031, 0);
_sum = GiSimdFmaLane(_sum, _r41, _k28293031, 1);
_sum = GiSimdFmaLane(_sum, _r42, _k28293031, 2);
_sum = GiSimdFmaLane(_sum, _r43, _k28293031, 3);
_sum = GiSimdFmaLane(_sum, _r44, _k32333435, 0);
_sum = GiSimdFmaLane(_sum, _r45, _k32333435, 1);
_sum = GiSimdFmaLane(_sum, _r46, _k32333435, 2);

GI_FLOAT32_t _k35363738 = GiLoadFloat32(k5);
GI_FLOAT32_t _k39404142 = GiLoadFloat32(k5 + 4);

GI_FLOAT32_V2_t _r50_02461357 = GiLd2qFloat32(r5);
GI_FLOAT32_V2_t _r50nx2 = GiLd2qFloat32(r5 + 8);
GI_FLOAT32_t _r5_8101214 = _r50nx2.val[0];
GI_FLOAT32_t _r5_9111315 = _r50nx2.val[1];
GI_FLOAT32_t _r50 = _r50_02461357.val[0];
GI_FLOAT32_t _r51 = _r50_02461357.val[1];
GI_FLOAT32_t _r52 = GiExtqFloat32(_r50, _r5_8101214, 1);
GI_FLOAT32_t _r53 = GiExtqFloat32(_r51, _r5_9111315, 1);
GI_FLOAT32_t _r54 = GiExtqFloat32(_r50, _r5_8101214, 2);
GI_FLOAT32_t _r55 = GiExtqFloat32(_r51, _r5_9111315, 2);
GI_FLOAT32_t _r56 = GiExtqFloat32(_r50, _r5_8101214, 3);

_sum = GiSimdFmaLane(_sum, _r50, _k35363738, 0);
_sum = GiSimdFmaLane(_sum, _r51, _k35363738, 1);
_sum = GiSimdFmaLane(_sum, _r52, _k35363738, 2);
_sum = GiSimdFmaLane(_sum, _r53, _k35363738, 3);
_sum = GiSimdFmaLane(_sum, _r54, _k39404142, 0);
_sum = GiSimdFmaLane(_sum, _r55, _k39404142, 1);
_sum = GiSimdFmaLane(_sum, _r56, _k39404142, 2);

GI_FLOAT32_t _k42434445 = GiLoadFloat32(k6);
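// row 6 holds taps 42..48; loading from k6 + 3 keeps the vector in bounds
// (a load from k6 + 4 would read past the 49-tap filter) and places taps
// 46..48 in lanes 1..3, which the last three FMAs below consume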
GI_FLOAT32_t _k45464748 = GiLoadFloat32(k6 + 3);

GI_FLOAT32_V2_t _r60_02461357 = GiLd2qFloat32(r6);
GI_FLOAT32_V2_t _r60nx2 = GiLd2qFloat32(r6 + 8);
GI_FLOAT32_t _r6_8101214 = _r60nx2.val[0];
GI_FLOAT32_t _r6_9111315 = _r60nx2.val[1];
GI_FLOAT32_t _r60 = _r60_02461357.val[0];
GI_FLOAT32_t _r61 = _r60_02461357.val[1];
GI_FLOAT32_t _r62 = GiExtqFloat32(_r60, _r6_8101214, 1);
GI_FLOAT32_t _r63 = GiExtqFloat32(_r61, _r6_9111315, 1);
GI_FLOAT32_t _r64 = GiExtqFloat32(_r60, _r6_8101214, 2);
GI_FLOAT32_t _r65 = GiExtqFloat32(_r61, _r6_9111315, 2);
GI_FLOAT32_t _r66 = GiExtqFloat32(_r60, _r6_8101214, 3);

_sum = GiSimdFmaLane(_sum, _r60, _k42434445, 0);
_sum = GiSimdFmaLane(_sum, _r61, _k42434445, 1);
_sum = GiSimdFmaLane(_sum, _r62, _k42434445, 2);
_sum = GiSimdFmaLane(_sum, _r63, _k42434445, 3);
_sum = GiSimdFmaLane(_sum, _r64, _k45464748, 1);
_sum = GiSimdFmaLane(_sum, _r65, _k45464748, 2);
_sum = GiSimdFmaLane(_sum, _r66, _k45464748, 3);

GiStoreFloat32(outptr, _sum);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
r4 += 8;
r5 += 8;
r6 += 8;
outptr += 4;
}

r0 += tail_step;
r1 += tail_step;
r2 += tail_step;
r3 += tail_step;
r4 += tail_step;
r5 += tail_step;
r6 += tail_step;
}
filter += 49;
}
}
// vim: syntax=cpp.doxygen
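
The stride-2 addressing above is easiest to verify against a scalar reference. Below is a minimal sketch (a hypothetical standalone helper, not part of this commit) of what one 7-tap filter row contributes to four adjacent outputs; the even/odd split from GiLd2qFloat32 plus the GiExtqFloat32 shifts supply exactly the r[2 * x + tap] operands.

    // Scalar reference for one 7-tap row at stride 2: output x reads inputs
    // r[2 * x + tap]; _r00/_r01 hold the even/odd input lanes and
    // _r02.._r06 are the shifted windows produced by GiExtqFloat32.
    static void conv_row7_stride2_ref(
            const float* r, const float* k, float* out4) {
        for (int x = 0; x < 4; ++x) {
            float s = 0.f;
            for (int tap = 0; tap < 7; ++tap) {
                s += r[2 * x + tap] * k[tap];
            }
            out4[x] += s;
        }
    }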

dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h → dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h
* \file dnn/src/fallback/conv_bias/gi/fp32/do_conv_stride2.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -13,7 +13,7 @@
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace fp32 {
namespace conv_stride2 {
void do_conv_2x2_stride2(
@@ -30,7 +30,7 @@ void do_conv_7x7_stride2(
size_t OH, size_t OW, size_t IC);
} // namespace conv_stride2
} // namespace fp32
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp → dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,21 +11,21 @@
*/

#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/block_helper.h"
#include "src/fallback/conv_bias/gi/fp32/algos.h"
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h"

#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

#include "midout.h"

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
using conv_fun = std::function<void(
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids,
const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1)
MIDOUT_DECL(megdnn_fallback_conv_bias_fp32_nchw44_stride1)
namespace {

static inline size_t get_perthread_cache_bytes(
@@ -156,7 +156,7 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable(
size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_fp32_nchw44_stride1,
megdnn_fallback_conv_bias_fp32_nchw44_stride1,
midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
@@ -175,7 +175,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_k
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \
MIDOUT_BEGIN( \
megdnn_arm_common_conv_bias_fp32_nchw44_stride1, \
megdnn_fallback_conv_bias_fp32_nchw44_stride1, \
midout_iv(#filter #bias_mode #stride #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \

dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h → dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -10,10 +10,10 @@
* implied.
*/

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace arm_common {
namespace fallback {
namespace conv_bias {

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
@@ -28,5 +28,5 @@ void pack_src_fp32_nchw44(
const int pad_top, const int pad_bottom, const int ic, const int ic_stride);

} // namespace conv_bias
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp → dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp

@@ -1,6 +1,6 @@
/**
* \file
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -12,21 +12,21 @@
*/

#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/gi/fp32/algos.h"
#include "src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"

#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using namespace fallback;
using conv_fun = std::function<void(
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids,
const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44)
MIDOUT_DECL(megdnn_fallback_conv_bias_fp32_nchw_nchw44)
namespace {
static inline int block_helper(
const int nthread, const int amount, const int per_unit_bytes) {
@@ -195,7 +195,7 @@ bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable(
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_fp32_nchw_nchw44,
megdnn_fallback_conv_bias_fp32_nchw_nchw44,
midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
@@ -214,7 +214,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectNCHWNCHW44::
// shape runtime
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \
MIDOUT_BEGIN( \
megdnn_arm_common_conv_bias_fp32_nchw_nchw44, \
megdnn_fallback_conv_bias_fp32_nchw_nchw44, \
midout_iv(#stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \

dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h → dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h
* \file dnn/src/fallback/conv_bias/gi/fp32/f32_direct_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,15 +11,14 @@
*/
#pragma once
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_helper/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/gi/intrinsic_helper.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/elemwise_helper/elemwise_op.h"
namespace megdnn {
namespace arm_common {
namespace fallback {
namespace fp32_direct_nchw_nchw44 {

static inline void pack_weight_fp32_nchw_nchw44(
@@ -34,8 +33,8 @@ static inline void pack_weight_fp32_nchw_nchw44(
for (int kh_idx = 0; kh_idx < kh; ++kh_idx) {
for (int kw_idx = 0; kw_idx < kw; ++kw_idx) {
for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
float32x4_t vsrc = vld1q_f32(in_ptr_oc);
vst1q_f32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc);
GI_FLOAT32_t vsrc = GiLoadFloat32(in_ptr_oc);
GiStoreFloat32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc);
in_ptr_oc += oc_step;
}
dst_ptr_oc += oc_step;
@@ -51,6 +50,6 @@ void conv_direct_fp32_nchw_nchw44(
const int, const int);
} // namespace fp32_direct_nchw_nchw44

} // namespace arm_common
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
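
The mechanical vld1q_f32/vst1q_f32 to GiLoadFloat32/GiStoreFloat32 swap above works because the GI ("general intrinsic") layer is a thin typedef-plus-inline-function shim over whatever SIMD the target provides. A rough sketch of the idea with simplified names; the real gi_float.h covers more backends and far more operations.

    #if defined(__ARM_NEON)
    #include <arm_neon.h>
    typedef float32x4_t GI_FLOAT32_t;  // maps straight onto a NEON register
    static inline GI_FLOAT32_t GiLoadFloat32(const float* p) { return vld1q_f32(p); }
    static inline void GiStoreFloat32(float* p, GI_FLOAT32_t v) { vst1q_f32(p, v); }
    #else
    typedef struct { float v[4]; } GI_FLOAT32_t;  // portable scalar fallback
    static inline GI_FLOAT32_t GiLoadFloat32(const float* p) {
        GI_FLOAT32_t r;
        for (int i = 0; i < 4; ++i) r.v[i] = p[i];
        return r;
    }
    static inline void GiStoreFloat32(float* p, GI_FLOAT32_t v) {
        for (int i = 0; i < 4; ++i) p[i] = v.v[i];
    }
    #endif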

dnn/src/arm_common/conv_bias/fp32/filter_transform.h → dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/filter_transform.h
* \file dnn/src/fallback/conv_bias/gi/fp32/filter_transform.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,14 +11,13 @@

#pragma once
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/utils.h"

namespace megdnn {
namespace arm_common {
namespace fallback {

template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT>
struct FilterTransform6X3 {
@@ -65,8 +64,8 @@ struct FilterTransform6X3 {
Vector<float, 4> g1 = Vector<float, 4>::load(fptr + 3);

Vector<float, 4> g2 = Vector<float, 4>::load(fptr + 6 - 1);
float32x4_t zeros = vdupq_n_f32(0.0f);
g2.value = vextq_f32(g2.value, zeros, 1);
GI_FLOAT32_t zeros = GiZeroFloat32();
g2.value = GiExtqFloat32(g2.value, zeros, 1);

#define cb(i) Vector<float, 4> wd##i;
UNROLL_CALL_NOWRAPPER(8, cb);
@@ -106,7 +105,6 @@ struct FilterTransform6X3 {
}

#else

#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \
@@ -128,7 +126,7 @@ struct FilterTransform6X3 {
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \
mid_buf1 += 8; \
} while (0);
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx)
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value)

float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
@@ -154,7 +152,7 @@ struct FilterTransform6X3 {
#undef FILTER_TRANSFORM
#undef GET_VECTOR_ELEM

} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen
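
The GET_VECTOR_ELEM redefinition above swaps NEON's vgetq_lane_f32(v, idx), which takes the lane as an immediate argument, for a family of per-lane GI functions; the macro token-pastes idx into the function name so the lane index stays a compile-time constant on every backend. A minimal sketch of the mechanism, assuming the NEON backend and showing only the lane-2 instance:

    #include <arm_neon.h>
    typedef float32x4_t GI_FLOAT32_t;
    // one extract function per lane: the immediate is baked into the name
    static inline float GiExtractLane2Float32(GI_FLOAT32_t v) {
        return vgetq_lane_f32(v, 2);
    }
    #define CONCAT(a, idx) a##idx
    #define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value)
    // GET_VECTOR_ELEM(wd, 3, 2) expands to GiExtractLane2Float32(wd3.value)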

+ 196
- 0
dnn/src/fallback/conv_bias/gi/fp32/helper.h

@@ -0,0 +1,196 @@
/**
* \file dnn/src/fallback/conv_bias/gi/fp32/helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once
#include "src/common/unroll_macro.h"
#include "src/fallback/general_intrinsic/gi_float.h"

namespace megdnn {
namespace fallback {
inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) {
GI_FLOAT32_V2_t a0, a1;
a0.val[0] = GiLoadFloat32(src + 0 * lda);
a0.val[1] = GiLoadFloat32(src + 1 * lda);
a1.val[0] = GiLoadFloat32(src + 2 * lda);
a1.val[1] = GiLoadFloat32(src + 3 * lda);
GI_FLOAT32_V2_t b0 = GiZipqFloat32(a0.val[0], a1.val[0]);
GI_FLOAT32_V2_t b1 = GiZipqFloat32(a0.val[1], a1.val[1]);
GI_FLOAT32_V2_t c0 = GiZipqFloat32(b0.val[0], b1.val[0]);
GI_FLOAT32_V2_t c1 = GiZipqFloat32(b0.val[1], b1.val[1]);
GiStoreFloat32(dst + 0 * ldb, c0.val[0]);
GiStoreFloat32(dst + 1 * ldb, c0.val[1]);
GiStoreFloat32(dst + 2 * ldb, c1.val[0]);
GiStoreFloat32(dst + 3 * ldb, c1.val[1]);
}
} // namespace fallback
} // namespace megdnn
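
The zip-based routine above computes a plain 4x4 transpose. A scalar reference, assuming lda and ldb are row strides in floats:

    // dst(j, i) = src(i, j): the first GiZipqFloat32 round interleaves rows
    // (0,2) and (1,3); the second round finishes moving element (i, j) into
    // row j, lane i, which is what this loop states directly.
    static void transpose_4x4_ref(const float* src, float* dst, int lda, int ldb) {
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 4; ++j)
                dst[j * ldb + i] = src[i * lda + j];
    }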

#define MATRIX_MUL4x4(sum, a, b) \
sum##0 = GiMlaqLowLaneFloat32(sum##0, b##0, a##0, 0); \
sum##0 = GiMlaqLowLaneFloat32(sum##0, b##1, a##0, 1); \
sum##0 = GiMlaqHighLaneFloat32(sum##0, b##2, a##0, 2); \
sum##0 = GiMlaqHighLaneFloat32(sum##0, b##3, a##0, 3); \
sum##1 = GiMlaqLowLaneFloat32(sum##1, b##0, a##1, 0); \
sum##1 = GiMlaqLowLaneFloat32(sum##1, b##1, a##1, 1); \
sum##1 = GiMlaqHighLaneFloat32(sum##1, b##2, a##1, 2); \
sum##1 = GiMlaqHighLaneFloat32(sum##1, b##3, a##1, 3); \
sum##2 = GiMlaqLowLaneFloat32(sum##2, b##0, a##2, 0); \
sum##2 = GiMlaqLowLaneFloat32(sum##2, b##1, a##2, 1); \
sum##2 = GiMlaqHighLaneFloat32(sum##2, b##2, a##2, 2); \
sum##2 = GiMlaqHighLaneFloat32(sum##2, b##3, a##2, 3); \
sum##3 = GiMlaqLowLaneFloat32(sum##3, b##0, a##3, 0); \
sum##3 = GiMlaqLowLaneFloat32(sum##3, b##1, a##3, 1); \
sum##3 = GiMlaqHighLaneFloat32(sum##3, b##2, a##3, 2); \
sum##3 = GiMlaqHighLaneFloat32(sum##3, b##3, a##3, 3);
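
With each register holding one 4-wide row, MATRIX_MUL4x4(sum, a, b) accumulates the row-major product sum += a * b; judging from the lane arguments, GiMlaqLowLaneFloat32 and GiMlaqHighLaneFloat32 are FMA-by-lane forms for lanes 0..1 and 2..3 respectively. A scalar reference of what the macro computes:

    // Row r of sum accumulates sum[r][c] += a[r][k] * b[k][c]: each lane k
    // of a's row scales the whole k-th row of b (the FMA-by-lane pattern).
    static void matrix_mul4x4_ref(
            float sum[4][4], const float a[4][4], const float b[4][4]) {
        for (int r = 0; r < 4; ++r)
            for (int k = 0; k < 4; ++k)
                for (int c = 0; c < 4; ++c)
                    sum[r][c] += a[r][k] * b[k][c];
    }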

#define CONCAT(a, idx) a##idx

#if MEGDNN_AARCH64
//! ret and a are type Vector<float, 8>
#define TRANSPOSE_8x8(a, ret) \
do { \
auto b0 = GiZipqFloat32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \
auto b1 = GiZipqFloat32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \
auto b2 = GiZipqFloat32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \
auto b3 = GiZipqFloat32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \
auto b4 = GiZipqFloat32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \
auto b5 = GiZipqFloat32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \
auto b6 = GiZipqFloat32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \
auto b7 = GiZipqFloat32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b2.val[0]))); \
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b4.val[0]), \
GiReinterpretqFloat32ToS64(b6.val[0]))); \
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b2.val[0]))); \
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b4.val[0]), \
GiReinterpretqFloat32ToS64(b6.val[0]))); \
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b2.val[1]))); \
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b4.val[1]), \
GiReinterpretqFloat32ToS64(b6.val[1]))); \
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b2.val[1]))); \
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b4.val[1]), \
GiReinterpretqFloat32ToS64(b6.val[1]))); \
CONCAT(ret, 4).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b1.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 4).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b5.val[0]), \
GiReinterpretqFloat32ToS64(b7.val[0]))); \
CONCAT(ret, 5).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b1.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 5).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b5.val[0]), \
GiReinterpretqFloat32ToS64(b7.val[0]))); \
CONCAT(ret, 6).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b1.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 6).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b5.val[1]), \
GiReinterpretqFloat32ToS64(b7.val[1]))); \
CONCAT(ret, 7).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b1.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 7).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b5.val[1]), \
GiReinterpretqFloat32ToS64(b7.val[1]))); \
} while (0);

#define TRANSPOSE_8x3(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1])));

#define TRANSPOSE_8x4(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 0).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 1).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[0]), \
GiReinterpretqFloat32ToS64(b1.val[0]))); \
CONCAT(ret, 1).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[0]), \
GiReinterpretqFloat32ToS64(b3.val[0]))); \
CONCAT(ret, 2).value.val[0] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 2).value.val[1] = GiReinterpretqS64ToFloat32(GiZip1qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1]))); \
CONCAT(ret, 3).value.val[0] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b0.val[1]), \
GiReinterpretqFloat32ToS64(b1.val[1]))); \
CONCAT(ret, 3).value.val[1] = GiReinterpretqS64ToFloat32(GiZip2qS64( \
GiReinterpretqFloat32ToS64(b2.val[1]), \
GiReinterpretqFloat32ToS64(b3.val[1])));

#else
#define TRANSPOSE_8x4(a, ret) \
auto b0 = GiZipqFloat32(CONCAT(a, 0).value, CONCAT(a, 1).value); \
auto b1 = GiZipqFloat32(CONCAT(a, 2).value, CONCAT(a, 3).value); \
auto b2 = GiZipqFloat32(CONCAT(a, 4).value, CONCAT(a, 5).value); \
auto b3 = GiZipqFloat32(CONCAT(a, 6).value, CONCAT(a, 7).value); \
CONCAT(ret, 0).value.val[0] = \
GiCombineFloat32(GiGetLowFloat32(b0.val[0]), GiGetLowFloat32(b1.val[0])); \
CONCAT(ret, 1).value.val[0] = GiCombineFloat32( \
GiGetHighFloat32(b0.val[0]), GiGetHighFloat32(b1.val[0])); \
CONCAT(ret, 2).value.val[0] = \
GiCombineFloat32(GiGetLowFloat32(b0.val[1]), GiGetLowFloat32(b1.val[1])); \
CONCAT(ret, 3).value.val[0] = GiCombineFloat32( \
GiGetHighFloat32(b0.val[1]), GiGetHighFloat32(b1.val[1])); \
CONCAT(ret, 0).value.val[1] = \
GiCombineFloat32(GiGetLowFloat32(b2.val[0]), GiGetLowFloat32(b3.val[0])); \
CONCAT(ret, 1).value.val[1] = GiCombineFloat32( \
GiGetHighFloat32(b2.val[0]), GiGetHighFloat32(b3.val[0])); \
CONCAT(ret, 2).value.val[1] = \
GiCombineFloat32(GiGetLowFloat32(b2.val[1]), GiGetLowFloat32(b3.val[1])); \
CONCAT(ret, 3).value.val[1] = GiCombineFloat32( \
GiGetHighFloat32(b2.val[1]), GiGetHighFloat32(b3.val[1]));

#endif
// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/strategy.h → dnn/src/fallback/conv_bias/gi/fp32/strategy.h

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy.h
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -11,14 +11,15 @@

#pragma once

#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/fallback/conv_bias/gi/postprocess_helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, winograd_2x3_4x4_f)
MEGDNN_REG_WINOGRAD_STRATEGY(
float, float, float, float, 2, 3, 4, 4, winograd_gi_2x3_4x4_f)

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f)

@@ -37,7 +38,7 @@ MEGDNN_REG_WINOGRAD_STRATEGY(
MEGDNN_REG_WINOGRAD_STRATEGY(
float, float, float, float, 7, 3, 4, 4, winograd_F73_mk4_f_nchw44)
} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp → dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_2x3_4x4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,22 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23)
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F23)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
namespace {

struct InputTransform2X3 {
@@ -40,15 +39,15 @@ struct InputTransform2X3 {
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start;
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
auto v0 = vld1q_f32(input_ptr);
auto v1 = vld1q_f32(input_ptr + IW);
auto v2 = vld1q_f32(input_ptr + IW * 2);
auto v3 = vld1q_f32(input_ptr + IW * 3);
vst1q_f32(patch + ico * 4 * alpha + 0 * 4, v0);
vst1q_f32(patch + ico * 4 * alpha + 1 * 4, v1);
vst1q_f32(patch + ico * 4 * alpha + 2 * 4, v2);
vst1q_f32(patch + ico * 4 * alpha + 3 * 4, v3);
auto v0 = GiLoadFloat32(input_ptr);
auto v1 = GiLoadFloat32(input_ptr + IW);
auto v2 = GiLoadFloat32(input_ptr + IW * 2);
auto v3 = GiLoadFloat32(input_ptr + IW * 3);
GiStoreFloat32(patch + ico * 4 * alpha + 0 * 4, v0);
GiStoreFloat32(patch + ico * 4 * alpha + 1 * 4, v1);
GiStoreFloat32(patch + ico * 4 * alpha + 2 * 4, v2);
GiStoreFloat32(patch + ico * 4 * alpha + 3 * 4, v3);
input_ptr += IH * IW;
}
}
@@ -197,18 +196,18 @@ struct OutputTransform2X3 {
} // namespace

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f)
void winograd_2x3_4x4_f::filter(
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_gi_2x3_4x4_f)
void winograd_gi_2x3_4x4_f::filter(
const float* filter, float* filter_transform_buf, float* transform_mid_buf,
size_t OC, size_t IC, size_t oc_start, size_t oc_end) {
constexpr int alpha = 2 + 3 - 1;
//! G * g * GT
float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0},
GI_FLOAT32_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0},
g3{0, 0, 1, 0};
float32x4_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1},
GI_FLOAT32_t gt0{1, 0.5, 0.5, 0}, gt1{0, 0.5, -0.5, 0}, gt2{0, 0.5, 0.5, 1},
gt3{0, 0, 0, 0};
size_t OCB = OC / 4;
size_t ICB = IC / 4;
@@ -225,33 +224,33 @@ void winograd_2x3_4x4_f::filter(
//! 0.5 0.5 0.5 0 v10 v11 v12 0 0 0.5 -0.5 0
//! 0.5 -0.5 0.5 0 v20 v21 v22 0 0 0.5 0.5 1
//! 0 0 1 0 0 0 0 0 0 0 0 0
float32x4_t vf0 = vld1q_f32(filter_ptr);
float32x4_t vf1 = vld1q_f32(filter_ptr + 4);
float32x4_t vf2 = vdupq_n_f32(filter_ptr[8]);
float32x4_t v3(vdupq_n_f32(0));
auto vtmp = vextq_f32(vf1, vf2, 2);
vtmp = vsetq_lane_f32(0, vtmp, 3);
float32x4_t v2(vtmp);
vtmp = vextq_f32(vf0, vf1, 3);
vtmp = vsetq_lane_f32(0, vtmp, 3);
float32x4_t v1(vtmp);
vtmp = vsetq_lane_f32(0, vf0, 3);
float32x4_t v0(vtmp);
float32x4_t vsum0 = vdupq_n_f32(0), vsum1 = vdupq_n_f32(0),
vsum2 = vdupq_n_f32(0), vsum3 = vdupq_n_f32(0);
GI_FLOAT32_t vf0 = GiLoadFloat32(filter_ptr);
GI_FLOAT32_t vf1 = GiLoadFloat32(filter_ptr + 4);
GI_FLOAT32_t vf2 = GiBroadcastFloat32(filter_ptr[8]);
GI_FLOAT32_t v3(GiBroadcastFloat32(0));
auto vtmp = GiExtqFloat32(vf1, vf2, 2);
vtmp = GiSetqLaneFloat32(0, vtmp, 3);
GI_FLOAT32_t v2(vtmp);
vtmp = GiExtqFloat32(vf0, vf1, 3);
vtmp = GiSetqLaneFloat32(0, vtmp, 3);
GI_FLOAT32_t v1(vtmp);
vtmp = GiSetqLaneFloat32(0, vf0, 3);
GI_FLOAT32_t v0(vtmp);
GI_FLOAT32_t vsum0 = GiBroadcastFloat32(0), vsum1 = GiBroadcastFloat32(0),
vsum2 = GiBroadcastFloat32(0), vsum3 = GiBroadcastFloat32(0);

MATRIX_MUL4x4(vsum, g, v);

float32x4_t vres0 = vdupq_n_f32(0), vres1 = vdupq_n_f32(0),
vres2 = vdupq_n_f32(0), vres3 = vdupq_n_f32(0);
GI_FLOAT32_t vres0 = GiBroadcastFloat32(0), vres1 = GiBroadcastFloat32(0),
vres2 = GiBroadcastFloat32(0), vres3 = GiBroadcastFloat32(0);
MATRIX_MUL4x4(vres, vsum, gt);

vst1q_f32(transform_mid_buf, vres0);
vst1q_f32(transform_mid_buf + 4, vres1);
vst1q_f32(transform_mid_buf + 8, vres2);
vst1q_f32(transform_mid_buf + 12, vres3);
GiStoreFloat32(transform_mid_buf, vres0);
GiStoreFloat32(transform_mid_buf + 4, vres1);
GiStoreFloat32(transform_mid_buf + 8, vres2);
GiStoreFloat32(transform_mid_buf + 12, vres3);

size_t ocb = oc / 4;
size_t oc4 = oc % 4;
@@ -266,7 +265,7 @@ void winograd_2x3_4x4_f::filter(
}
}

void winograd_2x3_4x4_f::input(
void winograd_gi_2x3_4x4_f::input(
const float* input, float* input_transform_buf, float* transform_mid_buf,
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx,
size_t nr_units_in_tile) {
@@ -304,7 +303,7 @@ void winograd_2x3_4x4_f::input(
}
}

void winograd_2x3_4x4_f::output(
void winograd_gi_2x3_4x4_f::output(
const float* output_transform_buf, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx,
@@ -322,8 +321,8 @@ void winograd_2x3_4x4_f::output(
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F23, cb, float, float, bmode,
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp32_F23, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
@@ -333,7 +332,7 @@ void winograd_2x3_4x4_f::output(
}

} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen
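
For reference, the constants in g0..g3 (rows of G, zero-padded to four lanes) and gt0..gt3 (rows of G^T) above are the standard Winograd F(2x2, 3x3) transforms. The algebra the filter, input, and output stages implement, in LaTeX form:

    % Winograd F(2x2, 3x3): output tile Y from filter g and input tile d
    Y = A^T \left[ (G g G^T) \odot (B^T d B) \right] A
    G = \begin{pmatrix} 1 & 0 & 0 \\ \tfrac12 & \tfrac12 & \tfrac12 \\
        \tfrac12 & -\tfrac12 & \tfrac12 \\ 0 & 0 & 1 \end{pmatrix}, \quad
    B^T = \begin{pmatrix} 1 & 0 & -1 & 0 \\ 0 & 1 & 1 & 0 \\
        0 & -1 & 1 & 0 \\ 0 & 1 & 0 & -1 \end{pmatrix}, \quad
    A^T = \begin{pmatrix} 1 & 1 & 1 & 0 \\ 0 & 1 & -1 & -1 \end{pmatrix}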

dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp → dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_4x5.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,22 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F45)
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F45)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
namespace {

struct FilterTransform4X5 {
@@ -126,9 +125,9 @@ struct FilterTransform4X5 {
#undef cb

FILTER_TRANSFORM(g, Gg)
float32x4x2_t vgr;
float32x4_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3};
float32x4_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7};
GI_FLOAT32_V2_t vgr;
GI_FLOAT32_t vgr0 = {Ggr0, Ggr1, Ggr2, Ggr3};
GI_FLOAT32_t vgr1 = {Ggr4, Ggr5, Ggr6, Ggr7};
vgr.val[0] = vgr0; //{Ggr0, Ggr1, Ggr2, Ggr3};
vgr.val[1] = vgr1; //{Ggr4, Ggr5, Ggr6, Ggr7};
Vector<float, 8> Ggt4(vgr);
@@ -167,8 +166,10 @@ struct InputTransform4X5 {
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
} while (0)

#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1])
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0])

template <bool inner>
static void transform(
@@ -345,22 +346,22 @@ struct OutputTransform4X5 {
#undef cb

if (oh_start + 4 <= OH && ow_start + 4 <= OW) {
float32x4_t bias0;
GI_FLOAT32_t bias0;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias0 = vdupq_n_f32(bias[oc]);
bias0 = GiBroadcastFloat32(bias[oc]);
}
rep(i, 4) {
size_t oh = oh_start + i;
float32x4_t item0 = vld1q_f32(mid_buf1);
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1);

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
item0 = vaddq_f32(item0, bias0);
item0 = GiAddFloat32(item0, bias0);
} else if (bmode == BiasMode::BIAS) {
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
item0 = vaddq_f32(item0, bias0);
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start);
item0 = GiAddFloat32(item0, bias0);
}
item0 = op(item0);
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0);
mid_buf1 += 4;
}
} else {
@@ -388,7 +389,7 @@ struct OutputTransform4X5 {
} // namespace

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f)
@@ -448,8 +449,8 @@ void winograd_4x5_1x1_f::output(
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F45, cb, float, float, bmode,
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp32_F45, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
@@ -459,7 +460,7 @@ void winograd_4x5_1x1_f::output(
}

} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp → dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_5x4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,22 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F54)
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F54)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
namespace {

struct FilterTransform5X4 {
@@ -94,7 +93,6 @@ struct FilterTransform5X4 {
transform_mid_buf[j * alpha + i];
}
#else

#define cb(i) \
do { \
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \
@@ -117,7 +115,7 @@ struct FilterTransform5X4 {
mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \
mid_buf1 += 8; \
} while (0);
#define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx)
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i).value)

float* mid_buf1 = transform_mid_buf;
UNROLL_CALL_NOWRAPPER(8, cb);
@@ -154,8 +152,10 @@ struct InputTransform5X4 {
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
} while (0)

#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1])
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0])

template <bool inner>
static void transform(
@@ -348,29 +348,29 @@ struct OutputTransform5X4 {
#undef cb

if (oh_start + 5 <= OH && ow_start + 5 <= OW) {
float32x4_t bias0;
GI_FLOAT32_t bias0;
float32_t bias1;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias0 = vdupq_n_f32(bias[oc]);
bias0 = GiBroadcastFloat32(bias[oc]);
bias1 = bias[oc];
}
rep(i, 5) {
size_t oh = oh_start + i;
float32x4_t item0 = vld1q_f32(mid_buf1);
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1);
float32_t item1 = mid_buf1[4];

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
item0 = vaddq_f32(item0, bias0);
item0 = GiAddFloat32(item0, bias0);
item1 = item1 + bias1;
} else if (bmode == BiasMode::BIAS) {
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start);
bias1 = bias[oc * OH * OW + oh * OW + ow_start + 4];
item0 = vaddq_f32(item0, bias0);
item0 = GiAddFloat32(item0, bias0);
item1 = item1 + bias1;
}
item0 = op(item0);
item1 = op(item1);
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0);
output[oc * OH * OW + oh * OW + ow_start + 4] = item1;

mid_buf1 += 5;
@@ -400,7 +400,7 @@ struct OutputTransform5X4 {
} // namespace

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f)
@@ -461,8 +461,8 @@ void winograd_5x4_1x1_f::output(
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F54, cb, float, float, bmode,
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp32_F54, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
@@ -472,7 +472,7 @@ void winograd_5x4_1x1_f::output(
}

} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen
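
OutputTransform5X4 above handles each 5-wide output row as one 4-lane vector plus a scalar tail (item0/item1), which keeps every load and store in bounds without overwrite tricks. The pattern in plain form, as a hypothetical helper:

    // 5-element row = 4-lane vector part + 1 scalar tail; mirrors the
    // GiLoadFloat32/GiAddFloat32 pair plus the item1/bias1 scalars above.
    static void add_bias_row5(float* row, const float* bias_row) {
        for (int i = 0; i < 4; ++i) row[i] += bias_row[i];  // vector lanes 0..3
        row[4] += bias_row[4];                              // scalar tail
    }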

dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp → dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,22 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/filter_transform.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63)
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
namespace {

/**
@@ -57,8 +56,10 @@ namespace {
wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \
} while (0);

#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx)
#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx)
#define GET_VECTOR_HIGH_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[1])
#define GET_VECTOR_LOW_ELEM(s, i, idx) \
GiExtractLane##idx##Float32(CONCAT(s, i).value.val[0])
struct InputTransform6X3 {
template <bool inner>
static void transform(
@@ -271,31 +272,31 @@ struct OutputTransform6X3 {
#undef cb

if (oh_start + 6 <= OH && ow_start + 6 <= OW) {
float32x4_t bias0;
GI_FLOAT32_t bias0;
float32x2_t bias1;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias0 = vdupq_n_f32(bias[oc]);
bias1 = vdup_n_f32(bias[oc]);
bias0 = GiBroadcastFloat32(bias[oc]);
bias1 = GiDupFloat32(bias[oc]);
}
rep(i, 6) {
size_t oh = oh_start + i;
float32x4_t item0 = vld1q_f32(mid_buf1);
float32x2_t item1 = vld1_f32(mid_buf1 + 4);
GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1);
float32x2_t item1 = GiLdFloat32(mid_buf1 + 4);

if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
item0 = vaddq_f32(item0, bias0);
item1 = vadd_f32(item1, bias1);
item0 = GiAddFloat32(item0, bias0);
item1 = GiAddDFloat32(item1, bias1);
} else if (bmode == BiasMode::BIAS) {
bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start);
bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start + 4);
item0 = vaddq_f32(item0, bias0);
item1 = vadd_f32(item1, bias1);
bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start);
bias1 = GiLdFloat32(bias + oc * OH * OW + oh * OW + ow_start + 4);
item0 = GiAddFloat32(item0, bias0);
item1 = GiAddDFloat32(item1, bias1);
}
item0 = op(item0);
item1 = vset_lane_f32(op(vget_lane_f32(item1, 0)), item1, 0);
item1 = vset_lane_f32(op(vget_lane_f32(item1, 1)), item1, 1);
vst1q_f32(output + oc * OH * OW + oh * OW + ow_start, item0);
vst1_f32(output + oc * OH * OW + oh * OW + ow_start + 4, item1);
item1 = GiSetLaneFloat32(op(GiGetLaneFloat32(item1, 0)), item1, 0);
item1 = GiSetLaneFloat32(op(GiGetLaneFloat32(item1, 1)), item1, 1);
GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0);
GiSt1Float32(output + oc * OH * OW + oh * OW + ow_start + 4, item1);

mid_buf1 += 6;
}
@@ -325,7 +326,7 @@ struct OutputTransform6X3 {
} // namespace

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f)
@@ -385,8 +386,8 @@ void winograd_6x3_1x1_f::output(
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F63, cb, float, float, bmode,
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp32_F63, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
@@ -396,7 +397,7 @@ void winograd_6x3_1x1_f::output(
}

} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp → dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_6x3_4x4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,22 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/filter_transform.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_4x4)
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63_4x4)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;

namespace {

@@ -41,16 +40,16 @@ struct InputTransform6X3 {
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start;
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
#define cb(i) \
auto v##i##0 = vld1q_f32(input_ptr + IW * i); \
auto v##i##1 = vld1q_f32(input_ptr + IW * i + 4);
#define cb(i) \
auto v##i##0 = GiLoadFloat32(input_ptr + IW * i); \
auto v##i##1 = GiLoadFloat32(input_ptr + IW * i + 4);

UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) \
vst1q_f32(patch + ico * 8 * alpha + i * 8, v##i##0); \
vst1q_f32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1);
#define cb(i) \
GiStoreFloat32(patch + ico * 8 * alpha + i * 8, v##i##0); \
GiStoreFloat32(patch + ico * 8 * alpha + i * 8 + 4, v##i##1);

UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
@@ -255,7 +254,7 @@ struct OutputTransform6X3 {
} // namespace

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f)
@@ -323,8 +322,8 @@ void winograd_6x3_4x4_f::output(
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F63_4x4, cb, float, float, bmode,
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp32_F63_4x4, cb, float, float, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
@@ -334,7 +333,7 @@ void winograd_6x3_4x4_f::output(
}

} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp → dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f23_mk4_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,22 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4)
MIDOUT_DECL(megdnn_fallback_winograd_nchw44_fp32_F23_mk4)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;
namespace {

constexpr size_t alpha = 2 + 3 - 1;
@@ -72,8 +71,9 @@ struct InputTransformF23_NCHW44 {
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size);
vst1q_f32(patchT + iho * alpha * pack_size + iwo * pack_size, src);
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size);
GiStoreFloat32(
patchT + iho * alpha * pack_size + iwo * pack_size, src);
}
}
#define cb(m, n) \
@@ -190,7 +190,7 @@ struct OutputTransformF23_NCHW44 {
} // namespace

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F23_mk4_f_nchw44)
@@ -313,14 +313,14 @@ void winograd_F23_mk4_f_nchw44::output(
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0,
"NCHW44 Winograd filter transform requires OC is times of 4");

DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode,
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode,
nonline_mode);
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen
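
The border branch in InputTransformF23_NCHW44 above zero-fills patchT and then copies only the rows and columns of the alpha x alpha window that actually fall inside the image, so zero padding needs no per-element branching. A sketch of the whole gather under that assumption; the ih0_act/iw0_act clamps are computed outside the hunk shown, and the helper below is hypothetical:

    #include <algorithm>
    #include <cstring>
    // pack_size = 4 floats per NCHW44 pack; the input row stride is IW * pack_size.
    static void gather_patch(
            const float* input_ptr, float* patchT, int ih_start, int iw_start,
            int IH, int IW, int alpha, int pack_size) {
        std::memset(patchT, 0, sizeof(float) * alpha * alpha * pack_size);
        const int ih0 = std::max(ih_start, 0);
        const int ih1 = std::min(ih_start + alpha, IH);
        const int iw0 = std::max(iw_start, 0);
        const int iw1 = std::min(iw_start + alpha, IW);
        for (int ih = ih0; ih < ih1; ++ih)
            for (int iw = iw0; iw < iw1; ++iw)
                std::memcpy(
                        patchT + ((ih - ih_start) * alpha + (iw - iw_start)) * pack_size,
                        input_ptr + (ih * IW + iw) * pack_size,
                        sizeof(float) * pack_size);
    }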

dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp → dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f63_mk4_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,22 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/filter_transform.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_mk4)
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F63_mk4)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;

namespace {

@@ -49,11 +48,11 @@ struct InputTransformF63_NCHW44 {
const float* input_ptr =
input + icb * IH * IW4 + ih_start * IW4 + iw4_start;
for (size_t ih = 0; ih < alpha; ih++) {
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i);
#define cb(i) auto v##i = GiLoadFloat32(input_ptr + pack_size * i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb

#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i);
#define cb(i) GiStoreFloat32(patchT + ih * pack_size * alpha + i * pack_size, v##i);
UNROLL_CALL_NOWRAPPER(8, cb);
#undef cb
input_ptr += IW4;
@@ -68,8 +67,9 @@ struct InputTransformF63_NCHW44 {
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size);
vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src);
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size);
GiStoreFloat32(
patchT + iho * pack_size * alpha + iwo * pack_size, src);
}
}
}
@@ -83,10 +83,10 @@ struct InputTransformF63_NCHW44 {
size_t ICB = IC / pack_size;
size_t icb = ic / pack_size;

float32x4_t d0, d1, d2, d3, d4, d5, d6, d7;
float32x4_t v0 = vld1q_f32(input_parameters + 0);
float32x4_t v1 = vld1q_f32(input_parameters + 4);
float32x4_t v2 = vld1q_f32(input_parameters + 8);
GI_FLOAT32_t d0, d1, d2, d3, d4, d5, d6, d7;
GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters + 0);
GI_FLOAT32_t v1 = GiLoadFloat32(input_parameters + 4);
GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8);

//! B
//! 1 0 0 0 0 0 0 0
@@ -98,57 +98,57 @@ struct InputTransformF63_NCHW44 {
//! -1 1 1 1 1 1 1 0
//! 0 0 0 0 0 0 0 1

#define cb(i) \
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \
auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \
auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \
auto t##i##1 = d6; \
auto t##i##2 = d6; \
auto t##i##3 = d6; \
auto t##i##4 = d6; \
auto t##i##5 = d6; \
auto t##i##6 = d6; \
t##i##0 = t##i##0 - d6; \
t##i##1 = t##i##1 + d1; \
t##i##2 = t##i##2 - d1; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \
t##i##7 = t##i##7 - d1; \
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \
t##i##1 = t##i##1 + d2; \
t##i##2 = t##i##2 + d2; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \
t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \
t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \
t##i##1 = t##i##1 + d5; \
t##i##2 = t##i##2 - d5; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0);
#define cb(i) \
d1 = GiLoadFloat32(patchT + i * alpha * pack_size + 1 * pack_size); \
d2 = GiLoadFloat32(patchT + i * alpha * pack_size + 2 * pack_size); \
d3 = GiLoadFloat32(patchT + i * alpha * pack_size + 3 * pack_size); \
d4 = GiLoadFloat32(patchT + i * alpha * pack_size + 4 * pack_size); \
d5 = GiLoadFloat32(patchT + i * alpha * pack_size + 5 * pack_size); \
d6 = GiLoadFloat32(patchT + i * alpha * pack_size + 6 * pack_size); \
auto t##i##0 = GiLoadFloat32(patchT + i * alpha * pack_size + 0 * pack_size); \
auto t##i##7 = GiLoadFloat32(patchT + i * alpha * pack_size + 7 * pack_size); \
auto t##i##1 = d6; \
auto t##i##2 = d6; \
auto t##i##3 = d6; \
auto t##i##4 = d6; \
auto t##i##5 = d6; \
auto t##i##6 = d6; \
t##i##0 = t##i##0 - d6; \
t##i##1 = t##i##1 + d1; \
t##i##2 = t##i##2 - d1; \
t##i##3 = GiSimdFmaLane(t##i##3, d1, v0, 2); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d1, v0, 2); \
t##i##5 = GiSimdFmaLane(t##i##5, d1, v1, 2); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d1, v1, 2); \
t##i##7 = t##i##7 - d1; \
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 0); \
t##i##1 = t##i##1 + d2; \
t##i##2 = t##i##2 + d2; \
t##i##3 = GiSimdFmaLane(t##i##3, d2, v0, 3); \
t##i##4 = GiSimdFmaLane(t##i##4, d2, v0, 3); \
t##i##5 = GiSimdFmaLane(t##i##5, d2, v1, 3); \
t##i##6 = GiSimdFmaLane(t##i##6, d2, v1, 3); \
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d3, v0, 1); \
t##i##2 = GiSimdFmaLane(t##i##2, d3, v0, 1); \
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d3, v1, 0); \
t##i##4 = GiSimdFmaLane(t##i##4, d3, v1, 0); \
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d3, v1, 0); \
t##i##6 = GiSimdFmaLane(t##i##6, d3, v1, 0); \
t##i##7 = GiSimdFmaLane(t##i##7, d3, v0, 0); \
t##i##0 = GiSimdFmaLane(t##i##0, d4, v0, 0); \
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d4, v0, 1); \
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d4, v0, 1); \
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d4, v1, 1); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d4, v1, 1); \
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d4, v2, 0); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d4, v2, 0); \
t##i##1 = t##i##1 + d5; \
t##i##2 = t##i##2 - d5; \
t##i##3 = GiSimdFmaLane(t##i##3, d5, v1, 2); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d5, v1, 2); \
t##i##5 = GiSimdFmaLane(t##i##5, d5, v0, 2); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d5, v0, 2); \
t##i##7 = GiFmsqLaneQFloat32(t##i##7, d5, v0, 0);
UNROLL_CALL_RAW(8, cb);
#undef cb
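// This hunk swaps the NEON lane intrinsics for their general-intrinsic (GI)
// counterparts one for one: vld1q_f32 -> GiLoadFloat32, vfmaq_laneq_f32 ->
// GiSimdFmaLane, vfmsq_laneq_f32 -> GiFmsqLaneQFloat32. Below is a naive
// scalar model of the two FMA helpers; an assumption for illustration only,
// the real definitions live in src/fallback/general_intrinsic/gi_float.h.
struct NaiveF32x4 {
    float v[4];
};
// assumed semantics of GiSimdFmaLane(a, b, c, lane): a + b * c[lane]
static inline NaiveF32x4 naive_fma_lane(
        NaiveF32x4 a, NaiveF32x4 b, NaiveF32x4 c, int lane) {
    for (int i = 0; i < 4; ++i)
        a.v[i] += b.v[i] * c.v[lane];
    return a;
}
// assumed semantics of GiFmsqLaneQFloat32(a, b, c, lane): a - b * c[lane]
static inline NaiveF32x4 naive_fms_lane(
        NaiveF32x4 a, NaiveF32x4 b, NaiveF32x4 c, int lane) {
    for (int i = 0; i < 4; ++i)
        a.v[i] -= b.v[i] * c.v[lane];
    return a;
}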

@@ -164,75 +164,75 @@ struct InputTransformF63_NCHW44 {
d0 = d0 - t6##i; \
d1 = d1 + t1##i; \
d2 = d2 - t1##i; \
d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \
d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \
d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \
d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \
d3 = GiSimdFmaLane(d3, t1##i, v0, 2); \
d4 = GiFmsqLaneQFloat32(d4, t1##i, v0, 2); \
d5 = GiSimdFmaLane(d5, t1##i, v1, 2); \
d6 = GiFmsqLaneQFloat32(d6, t1##i, v1, 2); \
d7 = d7 - t1##i; \
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \
d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 0); \
d1 = d1 + t2##i; \
d2 = d2 + t2##i; \
d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \
d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \
d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \
d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \
d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \
d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \
d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \
d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \
d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \
d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \
d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \
d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \
d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \
d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \
d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \
d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \
d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \
d3 = GiSimdFmaLane(d3, t2##i, v0, 3); \
d4 = GiSimdFmaLane(d4, t2##i, v0, 3); \
d5 = GiSimdFmaLane(d5, t2##i, v1, 3); \
d6 = GiSimdFmaLane(d6, t2##i, v1, 3); \
d1 = GiFmsqLaneQFloat32(d1, t3##i, v0, 1); \
d2 = GiSimdFmaLane(d2, t3##i, v0, 1); \
d3 = GiFmsqLaneQFloat32(d3, t3##i, v1, 0); \
d4 = GiSimdFmaLane(d4, t3##i, v1, 0); \
d5 = GiFmsqLaneQFloat32(d5, t3##i, v1, 0); \
d6 = GiSimdFmaLane(d6, t3##i, v1, 0); \
d7 = GiSimdFmaLane(d7, t3##i, v0, 0); \
d0 = GiSimdFmaLane(d0, t4##i, v0, 0); \
d1 = GiFmsqLaneQFloat32(d1, t4##i, v0, 1); \
d2 = GiFmsqLaneQFloat32(d2, t4##i, v0, 1); \
d3 = GiFmsqLaneQFloat32(d3, t4##i, v1, 1); \
d4 = GiFmsqLaneQFloat32(d4, t4##i, v1, 1); \
d5 = GiFmsqLaneQFloat32(d5, t4##i, v2, 0); \
d6 = GiFmsqLaneQFloat32(d6, t4##i, v2, 0); \
d1 = d1 + t5##i; \
d2 = d2 - t5##i; \
d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \
d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \
d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \
d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \
d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \
vst1q_f32( \
d3 = GiSimdFmaLane(d3, t5##i, v1, 2); \
d4 = GiFmsqLaneQFloat32(d4, t5##i, v1, 2); \
d5 = GiSimdFmaLane(d5, t5##i, v0, 2); \
d6 = GiFmsqLaneQFloat32(d6, t5##i, v0, 2); \
d7 = GiFmsqLaneQFloat32(d7, t5##i, v0, 0); \
GiStoreFloat32( \
input_transform_buf + \
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d0); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d1); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d2); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d3); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d4); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d5); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d6); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
@@ -347,7 +347,7 @@ struct OutputTransformF63_NCHW44 {
} // namespace

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk4_f_nchw44)
@@ -488,14 +488,14 @@ void winograd_F63_mk4_f_nchw44::output(
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0,
"NCHW44 Winograd filter transform requires OC is times of 4");

DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, bmode,
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp32_F63_mk4, cb, float, float, bmode,
nonline_mode);
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp → dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp View File

@@ -1,5 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp
* \file dnn/src/fallback/conv_bias/gi/fp32/strategy_f73_mk4_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,22 +9,21 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/filter_transform.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h"
#include "src/fallback/conv_bias/gi/fp32/helper.h"
#include "src/fallback/conv_bias/gi/fp32/strategy.h"
#include "src/fallback/conv_bias/gi/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/fallback/elemwise_helper/op_unary.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F73_mk4)
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F73_mk4)

using namespace megdnn;
using namespace arm_common;
using namespace fallback;

namespace {

@@ -51,11 +50,11 @@ struct InputTransformF73_NCHW44 {
const float* input_ptr =
input + icb * IH * IW4 + ih_start * IW4 + iw4_start;
for (size_t ih = 0; ih < alpha; ih++) {
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i);
#define cb(i) auto v##i = GiLoadFloat32(input_ptr + pack_size * i);
UNROLL_CALL_NOWRAPPER(9, cb);
#undef cb

#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i);
#define cb(i) GiStoreFloat32(patchT + ih * pack_size * alpha + i * pack_size, v##i);
UNROLL_CALL_NOWRAPPER(9, cb);
#undef cb
input_ptr += IW4;
@@ -70,8 +69,9 @@ struct InputTransformF73_NCHW44 {
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size);
vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src);
auto src = GiLoadFloat32(input_ptr + ih * IW4 + iw * pack_size);
GiStoreFloat32(
patchT + iho * pack_size * alpha + iwo * pack_size, src);
}
}
}
@@ -85,14 +85,14 @@ struct InputTransformF73_NCHW44 {
size_t ICB = IC / pack_size;
size_t icb = ic / pack_size;

float32x4_t d0, d1, d2, d3, d4, d5, d6, d7, d8;
float32x4_t v0 = vld1q_f32(input_parameters + 0);
float32x4_t v1 = vld1q_f32(input_parameters + 4);
float32x4_t v2 = vld1q_f32(input_parameters + 8);
float32x4_t v3 = vld1q_f32(input_parameters + 12);
float32x4_t v4 = vld1q_f32(input_parameters + 16);
float32x4_t v5 = vld1q_f32(input_parameters + 20);
float32x4_t v6 = vld1q_f32(input_parameters + 24);
GI_FLOAT32_t d0, d1, d2, d3, d4, d5, d6, d7, d8;
GI_FLOAT32_t v0 = GiLoadFloat32(input_parameters + 0);
GI_FLOAT32_t v1 = GiLoadFloat32(input_parameters + 4);
GI_FLOAT32_t v2 = GiLoadFloat32(input_parameters + 8);
GI_FLOAT32_t v3 = GiLoadFloat32(input_parameters + 12);
GI_FLOAT32_t v4 = GiLoadFloat32(input_parameters + 16);
GI_FLOAT32_t v5 = GiLoadFloat32(input_parameters + 20);
GI_FLOAT32_t v6 = GiLoadFloat32(input_parameters + 24);

//! B
//! 1.5 0 0 0 0 0 0 0 0
@@ -113,77 +113,77 @@ struct InputTransformF73_NCHW44 {
// 5.0f, 10.0f, 5.75f, 2.75f, v5
// 4.25f, 1.75f, 2.0f, 0.0f, v6

#define cb(i) \
d0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \
d7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \
auto t##i##8 = vld1q_f32(patchT + i * alpha * pack_size + 8 * pack_size); \
auto t##i##0 = d7; \
auto t##i##1 = d7; \
auto t##i##2 = d7; \
auto t##i##3 = d7; \
auto t##i##4 = d7; \
auto t##i##5 = d7; \
auto t##i##6 = d7; \
auto t##i##7 = d7; \
t##i##8 = vfmsq_laneq_f32(t##i##8, d7, v0, 0); \
t##i##0 = t##i##0 - d1; \
t##i##1 = vfmsq_laneq_f32(t##i##1, d1, v0, 0); \
t##i##2 = vfmaq_laneq_f32(t##i##2, d1, v0, 0); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d1, v0, 1); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d1, v0, 1); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d1, v0, 2); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d1, v0, 2); \
t##i##7 = t##i##7 - d1; \
t##i##8 = vfmaq_laneq_f32(t##i##8, d1, v0, 0); \
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 3); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d2, v1, 0); \
t##i##2 = vfmsq_laneq_f32(t##i##2, d2, v1, 1); \
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v1, 2); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d2, v1, 3); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d2, v2, 0); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d2, v2, 1); \
t##i##8 = t##i##8 - d2; \
t##i##0 = vfmaq_laneq_f32(t##i##0, d3, v2, 2); \
t##i##1 = vfmaq_laneq_f32(t##i##1, d3, v2, 3); \
t##i##2 = vfmsq_laneq_f32(t##i##2, d3, v3, 0); \
t##i##3 = vfmaq_laneq_f32(t##i##3, d3, v2, 0); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d3, v3, 1); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d3, v3, 2); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v3, 3); \
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v2, 2); \
t##i##8 = vfmsq_laneq_f32(t##i##8, d3, v0, 3); \
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 3); \
t##i##1 = vfmaq_laneq_f32(t##i##1, d4, v4, 0); \
t##i##2 = vfmaq_laneq_f32(t##i##2, d4, v4, 1); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v4, 2); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d4, v4, 3); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d4, v5, 0); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d4, v5, 1); \
t##i##8 = vfmaq_laneq_f32(t##i##8, d4, v2, 2); \
t##i##0 = vfmsq_laneq_f32(t##i##0, d5, v2, 2); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d5, v5, 2); \
t##i##2 = vfmsq_laneq_f32(t##i##2, d5, v5, 3); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d5, v6, 0); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d5, v6, 1); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d5, v5, 2); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v6, 0); \
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v2, 2); \
t##i##8 = vfmaq_laneq_f32(t##i##8, d5, v0, 3); \
t##i##0 = vfmsq_laneq_f32(t##i##0, d6, v0, 0); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d6, v1, 0); \
t##i##2 = vfmsq_laneq_f32(t##i##2, d6, v1, 1); \
t##i##3 = vfmaq_laneq_f32(t##i##3, d6, v1, 0); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d6, v3, 1); \
t##i##5 = t##i##5 - d6; \
t##i##6 = vfmsq_laneq_f32(t##i##6, d6, v6, 2); \
t##i##8 = vfmsq_laneq_f32(t##i##8, d6, v2, 2); \
t##i##0 = vfmaq_laneq_f32(t##i##0, d0, v0, 0);
#define cb(i) \
d0 = GiLoadFloat32(patchT + i * alpha * pack_size + 0 * pack_size); \
d1 = GiLoadFloat32(patchT + i * alpha * pack_size + 1 * pack_size); \
d2 = GiLoadFloat32(patchT + i * alpha * pack_size + 2 * pack_size); \
d3 = GiLoadFloat32(patchT + i * alpha * pack_size + 3 * pack_size); \
d4 = GiLoadFloat32(patchT + i * alpha * pack_size + 4 * pack_size); \
d5 = GiLoadFloat32(patchT + i * alpha * pack_size + 5 * pack_size); \
d6 = GiLoadFloat32(patchT + i * alpha * pack_size + 6 * pack_size); \
d7 = GiLoadFloat32(patchT + i * alpha * pack_size + 7 * pack_size); \
auto t##i##8 = GiLoadFloat32(patchT + i * alpha * pack_size + 8 * pack_size); \
auto t##i##0 = d7; \
auto t##i##1 = d7; \
auto t##i##2 = d7; \
auto t##i##3 = d7; \
auto t##i##4 = d7; \
auto t##i##5 = d7; \
auto t##i##6 = d7; \
auto t##i##7 = d7; \
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d7, v0, 0); \
t##i##0 = t##i##0 - d1; \
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d1, v0, 0); \
t##i##2 = GiSimdFmaLane(t##i##2, d1, v0, 0); \
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d1, v0, 1); \
t##i##4 = GiSimdFmaLane(t##i##4, d1, v0, 1); \
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d1, v0, 2); \
t##i##6 = GiSimdFmaLane(t##i##6, d1, v0, 2); \
t##i##7 = t##i##7 - d1; \
t##i##8 = GiSimdFmaLane(t##i##8, d1, v0, 0); \
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d2, v0, 3); \
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d2, v1, 0); \
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d2, v1, 1); \
t##i##3 = GiSimdFmaLane(t##i##3, d2, v1, 2); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d2, v1, 3); \
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d2, v2, 0); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d2, v2, 1); \
t##i##8 = t##i##8 - d2; \
t##i##0 = GiSimdFmaLane(t##i##0, d3, v2, 2); \
t##i##1 = GiSimdFmaLane(t##i##1, d3, v2, 3); \
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d3, v3, 0); \
t##i##3 = GiSimdFmaLane(t##i##3, d3, v2, 0); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d3, v3, 1); \
t##i##5 = GiSimdFmaLane(t##i##5, d3, v3, 2); \
t##i##6 = GiSimdFmaLane(t##i##6, d3, v3, 3); \
t##i##7 = GiSimdFmaLane(t##i##7, d3, v2, 2); \
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d3, v0, 3); \
t##i##0 = GiSimdFmaLane(t##i##0, d4, v0, 3); \
t##i##1 = GiSimdFmaLane(t##i##1, d4, v4, 0); \
t##i##2 = GiSimdFmaLane(t##i##2, d4, v4, 1); \
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d4, v4, 2); \
t##i##4 = GiSimdFmaLane(t##i##4, d4, v4, 3); \
t##i##5 = GiSimdFmaLane(t##i##5, d4, v5, 0); \
t##i##6 = GiSimdFmaLane(t##i##6, d4, v5, 1); \
t##i##8 = GiSimdFmaLane(t##i##8, d4, v2, 2); \
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d5, v2, 2); \
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d5, v5, 2); \
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d5, v5, 3); \
t##i##3 = GiFmsqLaneQFloat32(t##i##3, d5, v6, 0); \
t##i##4 = GiSimdFmaLane(t##i##4, d5, v6, 1); \
t##i##5 = GiFmsqLaneQFloat32(t##i##5, d5, v5, 2); \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d5, v6, 0); \
t##i##7 = GiFmsqLaneQFloat32(t##i##7, d5, v2, 2); \
t##i##8 = GiSimdFmaLane(t##i##8, d5, v0, 3); \
t##i##0 = GiFmsqLaneQFloat32(t##i##0, d6, v0, 0); \
t##i##1 = GiFmsqLaneQFloat32(t##i##1, d6, v1, 0); \
t##i##2 = GiFmsqLaneQFloat32(t##i##2, d6, v1, 1); \
t##i##3 = GiSimdFmaLane(t##i##3, d6, v1, 0); \
t##i##4 = GiFmsqLaneQFloat32(t##i##4, d6, v3, 1); \
t##i##5 = t##i##5 - d6; \
t##i##6 = GiFmsqLaneQFloat32(t##i##6, d6, v6, 2); \
t##i##8 = GiFmsqLaneQFloat32(t##i##8, d6, v2, 2); \
t##i##0 = GiSimdFmaLane(t##i##0, d0, v0, 0);

UNROLL_CALL_RAW(9, cb);
#undef cb
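// The transform above is pure lane-addressed FMA: the constants are packed
// four per vector (the v5/v6 rows are listed in the comment in this hunk) and
// each multiply-accumulate picks one out by (vector, lane). A standalone
// sketch of that convention, with hypothetical names:
static inline GI_FLOAT32_t example_lane_fma() {
    const float packed[8] = {5.0f,  10.0f, 5.75f, 2.75f,   // v5
                             4.25f, 1.75f, 2.0f,  0.0f};   // v6
    GI_FLOAT32_t sk_v5 = GiLoadFloat32(packed + 0);
    GI_FLOAT32_t sk_v6 = GiLoadFloat32(packed + 4);
    GI_FLOAT32_t acc = GiBroadcastFloat32(0.0f);
    GI_FLOAT32_t x = GiBroadcastFloat32(1.0f);
    acc = GiSimdFmaLane(acc, x, sk_v6, 1);       // acc += x * 1.75f
    acc = GiFmsqLaneQFloat32(acc, x, sk_v5, 0);  // acc -= x * 5.0f
    return acc;
}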
@@ -198,100 +198,100 @@ struct InputTransformF73_NCHW44 {
d5 = t7##i; \
d6 = t7##i; \
d7 = t7##i; \
d8 = vfmsq_laneq_f32(d8, t7##i, v0, 0); \
d8 = GiFmsqLaneQFloat32(d8, t7##i, v0, 0); \
d0 = d0 - t1##i; \
d1 = vfmsq_laneq_f32(d1, t1##i, v0, 0); \
d2 = vfmaq_laneq_f32(d2, t1##i, v0, 0); \
d3 = vfmsq_laneq_f32(d3, t1##i, v0, 1); \
d4 = vfmaq_laneq_f32(d4, t1##i, v0, 1); \
d5 = vfmsq_laneq_f32(d5, t1##i, v0, 2); \
d6 = vfmaq_laneq_f32(d6, t1##i, v0, 2); \
d1 = GiFmsqLaneQFloat32(d1, t1##i, v0, 0); \
d2 = GiSimdFmaLane(d2, t1##i, v0, 0); \
d3 = GiFmsqLaneQFloat32(d3, t1##i, v0, 1); \
d4 = GiSimdFmaLane(d4, t1##i, v0, 1); \
d5 = GiFmsqLaneQFloat32(d5, t1##i, v0, 2); \
d6 = GiSimdFmaLane(d6, t1##i, v0, 2); \
d7 = d7 - t1##i; \
d8 = vfmaq_laneq_f32(d8, t1##i, v0, 0); \
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 3); \
d1 = vfmsq_laneq_f32(d1, t2##i, v1, 0); \
d2 = vfmsq_laneq_f32(d2, t2##i, v1, 1); \
d3 = vfmaq_laneq_f32(d3, t2##i, v1, 2); \
d4 = vfmsq_laneq_f32(d4, t2##i, v1, 3); \
d5 = vfmsq_laneq_f32(d5, t2##i, v2, 0); \
d6 = vfmsq_laneq_f32(d6, t2##i, v2, 1); \
d8 = GiSimdFmaLane(d8, t1##i, v0, 0); \
d0 = GiFmsqLaneQFloat32(d0, t2##i, v0, 3); \
d1 = GiFmsqLaneQFloat32(d1, t2##i, v1, 0); \
d2 = GiFmsqLaneQFloat32(d2, t2##i, v1, 1); \
d3 = GiSimdFmaLane(d3, t2##i, v1, 2); \
d4 = GiFmsqLaneQFloat32(d4, t2##i, v1, 3); \
d5 = GiFmsqLaneQFloat32(d5, t2##i, v2, 0); \
d6 = GiFmsqLaneQFloat32(d6, t2##i, v2, 1); \
d8 = d8 - t2##i; \
d0 = vfmaq_laneq_f32(d0, t3##i, v2, 2); \
d1 = vfmaq_laneq_f32(d1, t3##i, v2, 3); \
d2 = vfmsq_laneq_f32(d2, t3##i, v3, 0); \
d3 = vfmaq_laneq_f32(d3, t3##i, v2, 0); \
d4 = vfmsq_laneq_f32(d4, t3##i, v3, 1); \
d5 = vfmaq_laneq_f32(d5, t3##i, v3, 2); \
d6 = vfmaq_laneq_f32(d6, t3##i, v3, 3); \
d7 = vfmaq_laneq_f32(d7, t3##i, v2, 2); \
d8 = vfmsq_laneq_f32(d8, t3##i, v0, 3); \
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 3); \
d1 = vfmaq_laneq_f32(d1, t4##i, v4, 0); \
d2 = vfmaq_laneq_f32(d2, t4##i, v4, 1); \
d3 = vfmsq_laneq_f32(d3, t4##i, v4, 2); \
d4 = vfmaq_laneq_f32(d4, t4##i, v4, 3); \
d5 = vfmaq_laneq_f32(d5, t4##i, v5, 0); \
d6 = vfmaq_laneq_f32(d6, t4##i, v5, 1); \
d8 = vfmaq_laneq_f32(d8, t4##i, v2, 2); \
d0 = vfmsq_laneq_f32(d0, t5##i, v2, 2); \
d1 = vfmsq_laneq_f32(d1, t5##i, v5, 2); \
d2 = vfmsq_laneq_f32(d2, t5##i, v5, 3); \
d3 = vfmsq_laneq_f32(d3, t5##i, v6, 0); \
d4 = vfmaq_laneq_f32(d4, t5##i, v6, 1); \
d5 = vfmsq_laneq_f32(d5, t5##i, v5, 2); \
d6 = vfmsq_laneq_f32(d6, t5##i, v6, 0); \
d7 = vfmsq_laneq_f32(d7, t5##i, v2, 2); \
d8 = vfmaq_laneq_f32(d8, t5##i, v0, 3); \
d0 = vfmsq_laneq_f32(d0, t6##i, v0, 0); \
d1 = vfmsq_laneq_f32(d1, t6##i, v1, 0); \
d2 = vfmsq_laneq_f32(d2, t6##i, v1, 1); \
d3 = vfmaq_laneq_f32(d3, t6##i, v1, 0); \
d4 = vfmsq_laneq_f32(d4, t6##i, v3, 1); \
d0 = GiSimdFmaLane(d0, t3##i, v2, 2); \
d1 = GiSimdFmaLane(d1, t3##i, v2, 3); \
d2 = GiFmsqLaneQFloat32(d2, t3##i, v3, 0); \
d3 = GiSimdFmaLane(d3, t3##i, v2, 0); \
d4 = GiFmsqLaneQFloat32(d4, t3##i, v3, 1); \
d5 = GiSimdFmaLane(d5, t3##i, v3, 2); \
d6 = GiSimdFmaLane(d6, t3##i, v3, 3); \
d7 = GiSimdFmaLane(d7, t3##i, v2, 2); \
d8 = GiFmsqLaneQFloat32(d8, t3##i, v0, 3); \
d0 = GiSimdFmaLane(d0, t4##i, v0, 3); \
d1 = GiSimdFmaLane(d1, t4##i, v4, 0); \
d2 = GiSimdFmaLane(d2, t4##i, v4, 1); \
d3 = GiFmsqLaneQFloat32(d3, t4##i, v4, 2); \
d4 = GiSimdFmaLane(d4, t4##i, v4, 3); \
d5 = GiSimdFmaLane(d5, t4##i, v5, 0); \
d6 = GiSimdFmaLane(d6, t4##i, v5, 1); \
d8 = GiSimdFmaLane(d8, t4##i, v2, 2); \
d0 = GiFmsqLaneQFloat32(d0, t5##i, v2, 2); \
d1 = GiFmsqLaneQFloat32(d1, t5##i, v5, 2); \
d2 = GiFmsqLaneQFloat32(d2, t5##i, v5, 3); \
d3 = GiFmsqLaneQFloat32(d3, t5##i, v6, 0); \
d4 = GiSimdFmaLane(d4, t5##i, v6, 1); \
d5 = GiFmsqLaneQFloat32(d5, t5##i, v5, 2); \
d6 = GiFmsqLaneQFloat32(d6, t5##i, v6, 0); \
d7 = GiFmsqLaneQFloat32(d7, t5##i, v2, 2); \
d8 = GiSimdFmaLane(d8, t5##i, v0, 3); \
d0 = GiFmsqLaneQFloat32(d0, t6##i, v0, 0); \
d1 = GiFmsqLaneQFloat32(d1, t6##i, v1, 0); \
d2 = GiFmsqLaneQFloat32(d2, t6##i, v1, 1); \
d3 = GiSimdFmaLane(d3, t6##i, v1, 0); \
d4 = GiFmsqLaneQFloat32(d4, t6##i, v3, 1); \
d5 = d5 - t6##i; \
d6 = vfmsq_laneq_f32(d6, t6##i, v6, 2); \
d8 = vfmsq_laneq_f32(d8, t6##i, v2, 2); \
d0 = vfmaq_laneq_f32(d0, t0##i, v0, 0); \
vst1q_f32( \
d6 = GiFmsqLaneQFloat32(d6, t6##i, v6, 2); \
d8 = GiFmsqLaneQFloat32(d8, t6##i, v2, 2); \
d0 = GiSimdFmaLane(d0, t0##i, v0, 0); \
GiStoreFloat32( \
input_transform_buf + \
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d0); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d1); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d2); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d3); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d4); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d5); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d6); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
d7); \
vst1q_f32( \
GiStoreFloat32( \
input_transform_buf + \
(8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \
@@ -413,7 +413,7 @@ struct OutputTransformF73_NCHW44 {
} // namespace

namespace megdnn {
namespace arm_common {
namespace fallback {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F73_mk4_f_nchw44)
@@ -554,14 +554,14 @@ void winograd_F73_mk4_f_nchw44::output(
OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0,
"NCHW44 Winograd filter transform requires OC is times of 4");

DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp32_F73_mk4, cb, float, float, bmode,
GI_DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_fallback_winograd_fp32_F73_mk4, cb, float, float, bmode,
nonline_mode);
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 413
- 0
dnn/src/fallback/conv_bias/gi/intrinsic_helper.h View File

@@ -0,0 +1,413 @@
/**
* \file dnn/src/fallback/conv_bias/gi/intrinsic_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/unroll_macro.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"

namespace megdnn {
namespace {

struct Vld1qF32S {
static GI_FORCEINLINE GI_FLOAT32_t impl(const float32_t* ptr) {
return GiLoadFloat32(ptr);
}
};

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wuninitialized"

#ifdef __GNUC__
#ifndef __has_warning
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#else
#if __has_warning("-Wmaybe-uninitialized")
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#endif
#endif

template <
int weight_number, int base_offset, int ptr_step, int oc_block, typename Func,
typename T, typename T2, typename... XT>
struct LoadHelper {
static GI_FORCEINLINE void impl(T& weight, T2 ptr, int oc_offset, XT... args);
};

#define WEIGHT_CB(step) \
src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...);

#define LOAD_HELPER(step) \
template < \
int base_offset, int ptr_step, typename Func, typename T, typename T2, \
typename... XT> \
struct LoadHelper<step, base_offset, ptr_step, 0, Func, T, T2, XT...> { \
static GI_FORCEINLINE void impl(T& src, T2 ptr, int, XT... args) { \
UNROLL_CALL_RAW(step, WEIGHT_CB); \
} \
}

LOAD_HELPER(1);
LOAD_HELPER(2);
LOAD_HELPER(3);
LOAD_HELPER(4);
LOAD_HELPER(5);
LOAD_HELPER(6);
LOAD_HELPER(7);
LOAD_HELPER(8);
LOAD_HELPER(9);
LOAD_HELPER(10);
LOAD_HELPER(11);
LOAD_HELPER(12);
LOAD_HELPER(13);
LOAD_HELPER(14);
LOAD_HELPER(15);
LOAD_HELPER(16);

#undef LOAD_HELPER
#undef WEIGHT_CB

///////////////////////////c_dim = 1/////////////////////////
#define WEIGHT_CB(step) src[0][step] = Func::impl(ptr + base_offset + step * ptr_step);

#define LOAD_HELPER(step) \
template <int base_offset, int ptr_step, typename Func, typename T, typename T2> \
struct LoadHelper<step, base_offset, ptr_step, 1, Func, T, T2> { \
static GI_FORCEINLINE void impl(T& src, T2 ptr, int) { \
UNROLL_CALL_RAW(step, WEIGHT_CB); \
} \
}

LOAD_HELPER(1);
LOAD_HELPER(2);
LOAD_HELPER(3);
LOAD_HELPER(4);
LOAD_HELPER(5);
LOAD_HELPER(6);
LOAD_HELPER(7);
LOAD_HELPER(8);
LOAD_HELPER(9);

#undef LOAD_HELPER
#undef WEIGHT_CB

/////////////////////////c_dim = 2///////////////////////////////
#define WEIGHT_CB(step) \
src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \
src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset);

#define LOAD_HELPER(step) \
template <int base_offset, int ptr_step, typename Func, typename T, typename T2> \
struct LoadHelper<step, base_offset, ptr_step, 2, Func, T, T2> { \
static GI_FORCEINLINE void impl(T& src, T2 ptr, int oc_offset) { \
UNROLL_CALL_RAW(step, WEIGHT_CB); \
} \
}

LOAD_HELPER(1);
LOAD_HELPER(2);
LOAD_HELPER(3);
LOAD_HELPER(4);
LOAD_HELPER(5);
LOAD_HELPER(6);
LOAD_HELPER(7);
LOAD_HELPER(8);

#undef LOAD_HELPER
#undef WEIGHT_CB

template <
int weight_number, int base_offset, int ptr_step, int c_dim, typename Func,
typename T, typename T2>
GI_FORCEINLINE void load_helper(T& weight, T2 ptr, int oc_offset) {
LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2>::impl(
weight, ptr, oc_offset);
}
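// Usage sketch for load_helper (hypothetical function; illustration only):
// weight_number = 4, base_offset = 0, ptr_step = 4 and c_dim = 1 expand to
// src[0][k] = GiLoadFloat32(ptr + k * 4) for k = 0..3 via the Vld1qF32S
// functor defined above.
static inline void example_load(const float* ptr) {
    GI_FLOAT32_t src[1][4];
    load_helper<4, 0, 4, 1, Vld1qF32S>(src, ptr, 0 /* oc_offset, unused */);
    (void)src;
}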

////////////////////Store_OCX_OW8_Remain/////////////////////////
template <int c_dim, int ow_remain, typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc);
};

template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));

op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));

op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op({{c[1][6], c[1][7]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24));

op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
op(c[1][6], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));

op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op({{c[1][4], c[1][5]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));

op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
op(c[1][4], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));

op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));

op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
op(c[1][2], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
op(c[0][0], reinterpret_cast<T3>(dst_ptr));
op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc));
}
};

template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24));
}
};

template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr));
}
};
template <typename Op, typename T, typename T2, typename T3>
struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> {
static GI_FORCEINLINE void impl(T& c, const Op& op, T2 dst_ptr, int) {
op(c[0][0], reinterpret_cast<T3>(dst_ptr));
}
};

template <int c_dim, int ow_remain, typename Op, typename T, typename T2>
GI_FORCEINLINE void store_ocx_ow8_remain_static(
T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) {
StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2>::impl(c, op, dst_ptr, ld_dst_oc);
}

#undef cb
#undef cb2
#undef cb_case
#undef cb_case2

#pragma GCC diagnostic pop

/////////////////////////init_ocx_ow8////////////////////

template <typename T>
struct GiLdqSimd;
template <>
struct GiLdqSimd<float> {
static constexpr int simd_len = 4;
};
template <int c_dim, BiasMode bias_mode, int ow_remain, typename T, typename T2>
struct InitOcxOw8 {
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step);
};
template <int c_dim, BiasMode bias_mode, typename T, typename T2>
struct InitOcxOw8<c_dim, bias_mode, 0, T, T2> {
static GI_FORCEINLINE void impl(T&, const T2*, int) {}
};

#define BIAS_INIT_NO_BIAS_C2(step) \
c[0][step] = GiBroadcastFloat32(static_cast<T2>(0)); \
c[1][step] = GiBroadcastFloat32(static_cast<T2>(0));
#define BIAS_INIT_NO_BIAS_C1(step) c[0][step] = GiBroadcastFloat32(static_cast<T2>(0));

#define BIAS_INIT_BROADCAST_C2(step) \
c[0][step] = GiLoadFloat32(bias_ptr); \
c[1][step] = GiLoadFloat32(bias_ptr + oc_step);
#define BIAS_INIT_BROADCAST_C1(step) c[0][step] = GiLoadFloat32(bias_ptr);

#define BIAS_INIT_BIAS_C2(step) \
c[0][step] = GiLoadFloat32(bias_ptr + step * simd_len); \
c[1][step] = GiLoadFloat32(bias_ptr + oc_step + step * simd_len);

#define BIAS_INIT_BIAS_C1(step) c[0][step] = GiLoadFloat32(bias_ptr + step * simd_len);

#define INSTANCE_InitOcxOw8(ow_remain, cdim) \
template <typename T, typename T2> \
struct InitOcxOw8<cdim, BiasMode::NO_BIAS, ow_remain, T, T2> { \
static GI_FORCEINLINE void impl(T& c, const T2*, int) { \
UNROLL_CALL_RAW(ow_remain, BIAS_INIT_NO_BIAS_C##cdim); \
} \
}; \
template <typename T, typename T2> \
struct InitOcxOw8<cdim, BiasMode::BROADCAST_CHANNEL_BIAS, ow_remain, T, T2> { \
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step) { \
(void)oc_step; \
UNROLL_CALL_RAW(ow_remain, BIAS_INIT_BROADCAST_C##cdim); \
} \
}; \
template <typename T, typename T2> \
struct InitOcxOw8<cdim, BiasMode::BIAS, ow_remain, T, T2> { \
static GI_FORCEINLINE void impl(T& c, const T2* bias_ptr, int oc_step) { \
constexpr int simd_len = GiLdqSimd<T2>::simd_len; \
(void)oc_step; \
UNROLL_CALL_RAW(ow_remain, BIAS_INIT_BIAS_C##cdim); \
} \
};
#define INSTANCE_InitOcxOw8_C(ow_remain) \
INSTANCE_InitOcxOw8(ow_remain, 2); \
INSTANCE_InitOcxOw8(ow_remain, 1);

INSTANCE_InitOcxOw8_C(1);
INSTANCE_InitOcxOw8_C(2);
INSTANCE_InitOcxOw8_C(3);
INSTANCE_InitOcxOw8_C(4);
INSTANCE_InitOcxOw8_C(5);
INSTANCE_InitOcxOw8_C(6);
INSTANCE_InitOcxOw8_C(7);
INSTANCE_InitOcxOw8_C(8);

#undef INSTANCE_InitOcxOw8
#undef INSTANCE_InitOcxOw8_C
#undef BIAS_INIT_BIAS_C1
#undef BIAS_INIT_BIAS_C2
#undef BIAS_INIT_BROADCAST_C1
#undef BIAS_INIT_BROADCAST_C2
#undef BIAS_INIT_NO_BIAS_C1
#undef BIAS_INIT_NO_BIAS_C2

template <int c_dim, BiasMode bias_mode, int ow_remain, typename T, typename T2>
GI_FORCEINLINE void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) {
InitOcxOw8<c_dim, bias_mode, ow_remain, T, T2>::impl(c, bias_ptr, oc_step);
}
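// Illustrative kernel tail (hypothetical function and names) combining the
// two helper families in this header: initialize the accumulators from bias,
// then store through whatever postprocess functor Op the kernel was
// instantiated with (e.g. the ReluOp used by the dispatch macros in this
// commit); c_dim = 1 and ow_remain = 4 are example values.
template <typename Op>
static inline void example_kernel_tail(
        const float* bias_ptr, float* dst_ptr, const Op& op) {
    GI_FLOAT32_t c[1][8];
    init_ocx_ow8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4>(c, bias_ptr, 0);
    // ... the real kernel accumulates src * weight into c[0][0..3] here ...
    store_ocx_ow8_remain_static<1, 4, Op>(c, op, dst_ptr, 0);
}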

} // namespace
} // namespace megdnn
#undef GI_FORCEINLINE
// vim: syntax=cpp.doxygen

+ 86
- 0
dnn/src/fallback/conv_bias/gi/postprocess_helper.h View File

@@ -0,0 +1,86 @@
/**
* \file dnn/src/fallback/conv_bias/gi/postprocess_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megdnn/basic_types.h"
#include "src/fallback/conv_bias/opr_impl.h"

#include "midout.h"

MIDOUT_DECL(fallback_gi_conv_bias_postprocess_helper)

namespace {

#define GI_DISPATCH_CONV_WINOGRAD_NONLINE( \
_midout_tag, cb, _bias_id, _src_type, _dst_type, _bmode, _nonline_mode, ...) \
switch (_nonline_mode) { \
case param::ConvBias::NonlineMode::IDENTITY: { \
MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \
cb(_bmode, NoneOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \
} \
MIDOUT_END(); \
break; \
} \
case param::ConvBias::NonlineMode::RELU: { \
MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \
cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \
} \
MIDOUT_END(); \
break; \
} \
case param::ConvBias::NonlineMode::SIGMOID: { \
MIDOUT_BEGIN(_midout_tag, _bias_id, 2) { \
cb(_bmode, SigmoidOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \
} \
MIDOUT_END(); \
break; \
} \
case param::ConvBias::NonlineMode::H_SWISH: { \
MIDOUT_BEGIN(_midout_tag, _bias_id, 3) { \
cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \
} \
MIDOUT_END(); \
break; \
} \
default: \
megdnn_assert(0); \
break; \
}

#define GI_DISPATCH_CONV_WINOGRAD_BIAS( \
_midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) \
switch (_bmode) { \
case BiasMode::BIAS: { \
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \
_midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \
_nonline_mode, __VA_ARGS__) \
break; \
} \
case BiasMode::NO_BIAS: { \
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \
_midout_tag, cb, 1, _src_type, _dst_type, BiasMode::NO_BIAS, \
_nonline_mode, __VA_ARGS__) \
break; \
} \
case BiasMode::BROADCAST_CHANNEL_BIAS: { \
GI_DISPATCH_CONV_WINOGRAD_NONLINE( \
_midout_tag, cb, 2, _src_type, _dst_type, \
BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, __VA_ARGS__) \
break; \
} \
default: \
megdnn_assert(0); \
break; \
}
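// Hypothetical caller sketch (the cb body is an assumption; each strategy
// .cpp defines its own): a winograd strategy's output() declares a local cb
// that instantiates its typed output-transform kernel, then hands bias-mode
// and nonlinearity selection to this macro, e.g. as in
// strategy_f63_mk4_nchw44.cpp:
//
//   #define cb(_bmode, _nonline_op, ...) \
//       OutputTransformF63_NCHW44::transform<_bmode, _nonline_op>(__VA_ARGS__);
//   GI_DISPATCH_CONV_WINOGRAD_BIAS(
//           megdnn_fallback_winograd_fp32_F63_mk4, cb, float, float, bmode,
//           nonline_mode);
//   #undef cb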

} // namespace

+ 193
- 0
dnn/src/fallback/conv_bias/gi/utils.h View File

@@ -0,0 +1,193 @@
/**
* \file dnn/src/fallback/conv_bias/gi/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include <cstring>
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"

namespace megdnn {
namespace fallback {

template <typename ctype, size_t len>
struct Vector;

template <>
struct Vector<float, 4> {
GI_FLOAT32_t value;
Vector() {}
Vector(const float v) { value = GiBroadcastFloat32(v); }
Vector(const Vector& lr) { value = lr.value; }
Vector(const Vector&& lr) { value = std::move(lr.value); }
Vector(const GI_FLOAT32_t& v) { value = v; }
static Vector load(const float* addr) {
Vector v;
v.value = GiLoadFloat32(addr);
return v;
}
static void save(float* addr, const Vector& v) { GiStoreFloat32(addr, v.value); }
void save(float* addr) { save(addr, *this); }
Vector operator+(const Vector& lr) {
Vector dst;
dst.value = GiAddFloat32(value, lr.value);
return dst;
}
Vector& operator+=(const Vector& lr) {
value = GiAddFloat32(value, lr.value);
return *this;
}
Vector operator-(const Vector& lr) {
Vector dst;
dst.value = GiSubtractFloat32(value, lr.value);
return dst;
}
Vector& operator-=(const Vector& lr) {
value = GiSubtractFloat32(value, lr.value);
return *this;
}
Vector operator*(float lr) {
Vector dst;
dst.value = GiMultiplyScalerFloat32(value, lr);
return dst;
}
Vector operator*(const Vector& lr) {
Vector dst;
dst.value = GiMultiplyFloat32(value, lr.value);
return dst;
}
Vector& operator*=(const Vector& lr) {
value = GiMultiplyFloat32(value, lr.value);
return *this;
}
Vector& operator=(const Vector& lr) {
value = lr.value;
return *this;
}
Vector& operator=(const Vector&& lr) {
value = std::move(lr.value);
return *this;
}
Vector operator-() {
Vector dst;
dst.value = -value;
return dst;
}
};

template <>
struct Vector<float, 8> {
GI_FLOAT32_V2_t value;
Vector() {}
Vector(const float v) {
value.val[0] = GiBroadcastFloat32(v);
value.val[1] = GiBroadcastFloat32(v);
}
Vector(const Vector& lr) { value = lr.value; }
Vector(const Vector&& lr) { value = std::move(lr.value); }
Vector(const GI_FLOAT32_V2_t& v) { value = v; }
static Vector load(const float* addr) {
Vector v;
#if defined(GI_TEST_NAIVE)
v.value.val[0] = GiLoadFloat32(addr);
v.value.val[1] = GiLoadFloat32(addr + 4);
#elif defined(__arm__) || defined(__aarch64__)
v.value = vld1q_f32_x2(addr);
#else
v.value.val[0] = GiLoadFloat32(addr);
v.value.val[1] = GiLoadFloat32(addr + 4);
#endif
return v;
}
static void save(float* addr, const Vector& v) {
#if defined(GI_TEST_NAIVE)
GiStoreFloat32(addr, v.value.val[0]);
GiStoreFloat32(addr + 4, v.value.val[1]);
#elif defined(__arm__) || defined(__aarch64__)
vst1q_f32_x2(addr, v.value);
#else
GiStoreFloat32(addr, v.value.val[0]);
GiStoreFloat32(addr + 4, v.value.val[1]);
#endif
}

void save(float* addr) { save(addr, *this); }
Vector operator+(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]);
dst.value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]);
return dst;
}
Vector& operator+=(const Vector& lr) {
value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]);
value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]);
return *this;
}
Vector& add(const Vector& lr) {
value.val[0] = GiAddFloat32(value.val[0], lr.value.val[0]);
value.val[1] = GiAddFloat32(value.val[1], lr.value.val[1]);
return *this;
}
Vector operator-(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]);
dst.value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]);
return dst;
}
Vector& operator-=(const Vector& lr) {
value.val[0] = GiSubtractFloat32(value.val[0], lr.value.val[0]);
value.val[1] = GiSubtractFloat32(value.val[1], lr.value.val[1]);
return *this;
}
Vector operator*(float lr) {
Vector dst;
dst.value.val[0] = GiMultiplyScalerFloat32(value.val[0], lr);
dst.value.val[1] = GiMultiplyScalerFloat32(value.val[1], lr);
return dst;
}
//! val + lr * n
Vector& mla(const Vector& lr, float n) {
value.val[0] = GiMultiplyAddScalarFloat32(value.val[0], lr.value.val[0], n);
value.val[1] = GiMultiplyAddScalarFloat32(value.val[1], lr.value.val[1], n);
return *this;
}

Vector operator*(const Vector& lr) {
Vector dst;
dst.value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]);
dst.value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]);
return dst;
}
Vector& operator*=(const Vector& lr) {
value.val[0] = GiMultiplyFloat32(value.val[0], lr.value.val[0]);
value.val[1] = GiMultiplyFloat32(value.val[1], lr.value.val[1]);
return *this;
}
Vector& operator=(const Vector& lr) {
value = lr.value;
return *this;
}
Vector& operator=(const Vector&& lr) {
value = std::move(lr.value);
return *this;
}
Vector operator-() {
Vector dst;
dst.value.val[0] = -value.val[0];
dst.value.val[1] = -value.val[1];
return dst;
}
};
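// Minimal usage sketch (hypothetical function and buffers). The 8-wide
// specialization maps to vld1q_f32_x2 on Arm and falls back to two 4-wide GI
// loads elsewhere, so transform code written against Vector stays free of
// raw Gi* calls.
static inline void example_vector(const float* src, float* dst) {
    auto a = Vector<float, 8>::load(src);
    a.mla(a, 0.5f);  // a = a + a * 0.5
    a.save(dst);
}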

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 92
- 4
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -16,6 +16,7 @@
#include "src/fallback/conv_bias/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
#include "src/fallback/conv_bias/gi/fp32/algos.h"
#include "src/fallback/conv_bias/im2col/algos.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/naive/convolution/algorithms.h"
@@ -34,6 +35,14 @@
using namespace megdnn;
using namespace fallback;

namespace {

//! TODO: implement is_fallback_exclude_gi_or_naive
bool is_naive(const detail::Algorithm* algo) {
return algo->handle_type() == Handle::HandleType::NAIVE;
}
} // anonymous namespace

size_t megdnn::fallback::pack_size(param::ConvBias::Format format) {
switch (format) {
case param::ConvBias::Format::NCHW44:
@@ -73,16 +82,95 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
SmallVector<std::unique_ptr<AlgoBase>> refhold;
SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_gi_winograd_algos;

AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
AlgoF32ChannelWiseNCHW44 f32_channel_wise_nchw44;
AlgoF32DirectNCHW44 f32_direct_nchw44;

AlgoF32Direct f32_direct;
AlgoF32DirectStride2 f32_direct_stride2;
AlgoF32DirectStride1 f32_direct_stride1;

public:
AlgoPack() {
// fallback gi fp32 algo
m_all_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
m_all_algos.emplace_back(&f32_channel_wise_nchw44);
m_all_algos.emplace_back(&f32_direct_nchw44);
m_all_algos.emplace_back(&f32_direct_stride1);
m_all_algos.emplace_back(&f32_direct_stride2);
m_all_algos.emplace_back(&f32_direct);

static CpuOprDelegationStorage<2> storage;
auto matmul_opr = storage.get<MatrixMul, 0>();
using MatmulFormat = param::MatrixMul::Format;
auto&& matmul_algos =
static_cast<fallback::MatrixMulImpl*>(matmul_opr)
->select_algo_type({AlgoDataType::FLOAT32, MatmulFormat::MK4});
for (auto&& algo : matmul_algos) {
if (is_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
//! uncomment this when low precision mode is done
#if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
#endif
}
}

//! TODO: move arm_v7 MatrixMulImpl::AlgoF32 matmul to gi fallback, for nchw
//! prefetch algo; also need to update dnn/test/common/conv_bias.cpp:check_winograd
matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
->select_algo_type(
{AlgoDataType::FLOAT32, MatmulFormat::DEFAULT});
for (auto&& algo : matmul_algos) {
if (is_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF63(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF54(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF45(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
m_gi_winograd_algos.emplace_back(refhold.back().get());
}
}
for (auto&& algo : m_gi_winograd_algos) {
m_all_algos.emplace_back(algo);
}
// end fallback gi fp32 algo

refhold.emplace_back(new AlgoConv1x1Gemv());
m_all_algos.emplace_back(refhold.back().get());

static CpuOprDelegationStorage<> storage;
auto matmul_opr = storage.get<MatrixMul>();
auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
->get_all_packed_algo();
matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
->get_all_packed_algo();
for (auto&& algo : matmul_algos) {
#if MEGDNN_X86
//! As we don't have direct conv for int8x8x16 yet, if we disable gemv here, it may


+ 31
- 14
dnn/src/fallback/conv_bias/opr_impl.h View File

@@ -226,6 +226,20 @@ public:
FB_CONV1x1,
FB_CONV1x1_GEMV,
FB_IM2COL,
GI_COMMON_WINOGRAD_F23_4X4_FP32,
GI_COMMON_WINOGRAD_F63_FP32,
GI_COMMON_WINOGRAD_F63_4X4_FP32,
GI_COMMON_WINOGRAD_F54_FP32,
GI_COMMON_WINOGRAD_F45_FP32,
GI_COMMON_WINOGRAD_F23_4X4_NCHW44_F32,
GI_COMMON_WINOGRAD_F63_4X4_NCHW44_F32,
GI_COMMON_WINOGRAD_F73_4X4_NCHW44_F32,
GI_COMMON_DIRECT_FP32,
GI_COMMON_DIRECT_STRD1_FP32,
GI_COMMON_DIRECT_STRD2_FP32,
GI_COMMON_DIRECT_NCHW44_FP32,
GI_COMMON_DIRECT_NCHW_NCHW44_FP32,
GI_COMMON_CHWNWISE_NCHW44_F32,

#if MEGDNN_X86
X86_DIRECT = 1 << 8,
@@ -248,20 +262,6 @@ public:
ARM_COMMON_DIRECT_STRD1_FP16,
ARM_COMMON_CHWNWISE_NCHW88_F16,
ARM_COMMON_DIRECT_NCHW88_FP16,
ARM_COMMON_WINOGRAD_F23_4X4_FP32,
ARM_COMMON_WINOGRAD_F63_FP32,
ARM_COMMON_WINOGRAD_F63_4X4_FP32,
ARM_COMMON_WINOGRAD_F54_FP32,
ARM_COMMON_WINOGRAD_F45_FP32,
ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32,
ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32,
ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32,
ARM_COMMON_DIRECT_FP32,
ARM_COMMON_DIRECT_STRD1_FP32,
ARM_COMMON_DIRECT_STRD2_FP32,
ARM_COMMON_DIRECT_NCHW44_FP32,
ARM_COMMON_DIRECT_NCHW_NCHW44_FP32,
ARM_COMMON_CHWNWISE_NCHW44_F32,
ARM_COMMON_DIRECT_STRD1_S8,
ARM_COMMON_DIRECT_STRD2_S8,
ARM_COMMON_DIRECT_NCHW44,
@@ -383,6 +383,23 @@ private:
class AlgoWinogradF32_4x4;
class AlgoWinogradQS8;
class AlgoWinogradQS8_8x8;

class AlgoFP32WinogradF23_4x4;
class AlgoFP32WinogradF63;
class AlgoFP32WinogradF63_4x4;
class AlgoFP32WinogradF54;
class AlgoFP32WinogradF45;
class AlgoFP32WinogradF23_4x4_NCHW44;
class AlgoFP32WinogradF63_4x4_NCHW44;
class AlgoFP32WinogradF73_4x4_NCHW44;

class AlgoF32Direct;
class AlgoF32DirectStride1;
class AlgoF32DirectStride2;
class AlgoF32DirectNCHWNCHW44;
class AlgoF32ChannelWiseNCHW44;
class AlgoF32DirectNCHW44;

class AlgoPack;

NCBKernSizeParam m_prev_selected_algo_sizep;


+ 1
- 204
dnn/test/arm_common/conv_bias.cpp View File

@@ -81,23 +81,6 @@ TEST_F(ARM_COMMON, CONV_BIAS_RECORD) {
}
}

TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD_F63_4) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_mk_packed_args();
Checker<ConvBiasForward> checker(handle());

check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
}

TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD_F63_4_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_mk_packed_args();
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());

check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
}

#define CONV_BIAS_MATMUL_QU8_MODE(MODE) \
using namespace conv_bias; \
std::vector<TestArg> args = get_quantized_args_with_nlmode(MODE); \
@@ -1015,14 +998,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) {
#endif
}

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_4x4) {
#if MEGDNN_AARCH64
benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4);
#else
benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4);
#endif
}

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) {
#if MEGDNN_AARCH64
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3);
@@ -1031,14 +1006,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) {
#endif
}

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_4x4) {
#if MEGDNN_AARCH64
benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4);
#else
benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4);
#endif
}

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) {
#if MEGDNN_AARCH64
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:5", handle(), 4);
@@ -1212,30 +1179,10 @@ void benchmark_winograd_nchw_vs_nchw44(
}
}

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_winograd_nchw_vs_nchw44(
"AARCH64_F32_MK4_4x16:4:2", "AARCH64_F32_MK4_4x16:4:2", handle());
#else
benchmark_winograd_nchw_vs_nchw44(
"ARMV7_F32_MK4_4x8:4:2", "ARMV7_F32_MK4_4x8:4:2", handle());
#endif
}

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_winograd_nchw_vs_nchw44(
"AARCH64_F32_MK4_4x16:4:6", "AARCH64_F32_MK4_4x16:4:6", handle());
#else
benchmark_winograd_nchw_vs_nchw44(
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:6", handle());
#endif
}

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F73_MK4_NCHW_VS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_winograd_nchw_vs_nchw44(
"AARCH64_F32_MK4_4x16:4:6", "ARM_COMMON_F32_GEMV_MK4:4:7", handle());
"AARCH64_F32_MK4_4x16:4:6", "FB_GI_F32_GEMV_MK4:4:7", handle());
#else
benchmark_winograd_nchw_vs_nchw44(
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:7", handle());
@@ -1609,156 +1556,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2) {
computations / used0, used1, computations / used1, used1 / used0);
}
}
TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE1_NCHW44) {
// have to remove the preferred restriction in the usable func before running the benchmark
using namespace conv_bias;
param::ConvBias param;
param.stride_h = 1;
param.stride_w = 1;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = NonlineMode::RELU;
param.sparse = param::ConvBias::Sparse::GROUP;

constexpr size_t RUN = 50;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_display(false);
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1"));

auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;

param.format = param::ConvBias::Format::NCHW44;
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_display(false);
benchmark1.set_param(param);
benchmark1.set_times(RUN);
benchmark1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44"));
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) {
TensorLayout dst_layout;
opr->deduce_layout(
{{1, group * 4, h, w}, dtype::Int8()},
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()},
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 /
(1024 * 1024 * 1024) * 1e3;

auto used0 = benchmark0.exec(
{{1, group * 4, h, w},
{group * 4, 1, 1, kernel, kernel},
{1, group * 4, 1, 1},
{},
{}}) /
RUN;
auto used1 = benchmark1.exec(
{{1, group, h, w, 4},
{group, 1, 1, kernel, kernel, 4},
{1, group, 1, 1, 4},
{},
{}}) /
RUN;
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops "
"nchw44: "
"%f ms %f GFlops "
"speedup: %f\n",
group, h, w, kernel, used0, computations / used0, used1,
computations / used1, used0 / used1);
};
for (size_t group : {8, 16, 32, 64}) {
for (size_t kernel : {2, 3, 5}) {
run(group, 112, 112, kernel);
run(group, 56, 56, kernel);
run(group, 48, 48, kernel);
run(group, 28, 28, kernel);
run(group, 14, 14, kernel);
}
}
run(8, 112, 112, 3);
run(32, 56, 56, 3);
run(64, 28, 28, 3);
run(128, 14, 14, 3);
}

TEST_F(ARM_COMMON, BENCHMARK_CHANNEL_WISE_F32_STRIDE2_NCHW44) {
// have to remove the preferred restriction in the usable func before running the benchmark
using namespace conv_bias;
param::ConvBias param;
param.stride_h = 2;
param.stride_w = 2;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = NonlineMode::RELU;
param.sparse = param::ConvBias::Sparse::GROUP;

constexpr size_t RUN = 50;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_display(false);
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2"));

auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;

param.format = param::ConvBias::Format::NCHW44;
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_display(false);
benchmark1.set_param(param);
benchmark1.set_times(RUN);
benchmark1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44"));
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) {
TensorLayout dst_layout;
opr->deduce_layout(
{{1, group * 4, h, w}, dtype::Int8()},
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()},
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 /
(1024 * 1024 * 1024) * 1e3;

auto used0 = benchmark0.exec(
{{1, group * 4, h, w},
{group * 4, 1, 1, kernel, kernel},
{1, group * 4, 1, 1},
{},
{}}) /
RUN;
auto used1 = benchmark1.exec(
{{1, group, h, w, 4},
{group, 1, 1, kernel, kernel, 4},
{1, group, 1, 1, 4},
{},
{}}) /
RUN;
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops "
"nchw44: "
"%f ms %f GFlops "
"speedup: %f\n",
group, h, w, kernel, used0, computations / used0, used1,
computations / used1, used0 / used1);
};
for (size_t group : {8, 16, 32, 64}) {
for (size_t kernel : {2, 3, 5}) {
run(group, 112, 112, kernel);
run(group, 56, 56, kernel);
run(group, 48, 48, kernel);
run(group, 28, 28, kernel);
run(group, 14, 14, kernel);
}
}
run(8, 112, 112, 3);
run(32, 56, 56, 3);
run(64, 28, 28, 3);
run(128, 14, 14, 3);
}

TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {
// have to remove the preferred restriction in the usable func before running the benchmark
using namespace conv_bias;


+ 0
- 138
dnn/test/arm_common/conv_bias_multi_thread.cpp

@@ -303,84 +303,6 @@ void checker_conv_bias_int8x8x32_multi(
}
}

/**********************************F32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) {
check_conv_bias(
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), handle(),
"F32DIRECT");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
//! k=7 s=1
check_conv_bias(
get_nchw44_conv_bias_args({7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1),
handle(), "F32_CONV_NCHW44_DIRECT");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
check_conv_bias(
get_nchw44_conv_bias_args({2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
handle(), "F32_CONV_NCHW44_DIRECT");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
check_conv_bias(
get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1), handle(),
"F32_CONV_NCHW44_DIRECT");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
check_conv_bias(
get_nchw44_conv_bias_args({2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2),
handle(), "F32_CONV_NCHW44_DIRECT");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) {
check_conv_bias(
get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), handle(),
"F32STRD1");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) {
check_conv_bias(
get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), handle(),
"F32STRD2");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
check_conv_bias(
get_nchw44_conv_bias_args(
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false,
true),
handle(), "F32_CONV_NCHW_NCHW44");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) {
check_conv_bias(
get_nchw44_conv_bias_args(
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false,
true),
handle(), "F32_CONV_NCHW_NCHW44");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(),
"F32_CHANNEL_WISE_NCHW44");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
check_conv_bias(
get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(),
"F32_CHANNEL_WISE_NCHW44");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(),
"F32_CHANNEL_WISE_NCHW44");
}

/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
@@ -787,50 +709,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
#endif
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_mk_packed_args();
Checker<ConvBiasForward> checker(handle());

check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) {
using namespace conv_bias;
std::vector<TestArg> args =
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
Checker<ConvBiasForward> checker(handle());
check_winograd(
"4:2:32", checker, args, param::MatrixMul::Format::MK4,
param::ConvBias::Format::NCHW44);
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(3);
Checker<ConvBiasForward> checker(handle());

check_winograd("1:6:32", checker, args);
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_mk_packed_args();
Checker<ConvBiasForward> checker(handle());

check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
using namespace conv_bias;
std::vector<TestArg> args =
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
Checker<ConvBiasForward> checker(handle());
check_winograd(
"4:6:16", checker, args, param::MatrixMul::Format::MK4,
param::ConvBias::Format::NCHW44);
}

//! uncomment it when low precision mode is ok
#if 0
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
@@ -853,22 +731,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCE
}
#endif

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(4);
Checker<ConvBiasForward> checker(handle());

check_winograd("1:5:32", checker, args);
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(5);
Checker<ConvBiasForward> checker(handle());

check_winograd("1:4:32", checker, args);
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
using namespace conv_bias;



+ 0
- 201
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp

@@ -81,207 +81,6 @@ void benchmark_impl(
}
} // namespace

TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32) {
constexpr size_t RUNS = 50;

param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 1;
param.stride_w = 1;
param.sparse = param::ConvBias::Sparse::GROUP;

std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS,
size_t group) {
SmallVector<TensorShape> shapes{
{N, IC, H, W},
{group, OC / group, IC / group, FS, FS},
{1, OC, 1, 1},
{},
{N, OC, H, W}};
TensorShape dst{N, OC, H, W};
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};

bench_case(1, 32, 32, 200, 200, 3, 4);
bench_case(1, 32, 32, 200, 200, 3, 32);
bench_case(1, 32, 32, 128, 128, 3, 4);
bench_case(1, 32, 32, 128, 128, 3, 32);
bench_case(1, 32, 32, 100, 100, 3, 4);
bench_case(1, 32, 32, 100, 100, 3, 32);
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);

std::string algo_name = "F32DIRECT";
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
shapes_and_computation.clear();

algo_name = "F32DIRECT";
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR1) {
constexpr size_t RUNS = 50;
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 1;
param.stride_w = 1;
param.sparse = param::ConvBias::Sparse::GROUP;

std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS,
size_t group) {
SmallVector<TensorShape> shapes{
{N, IC, H, W},
{group, OC / group, IC / group, FS, FS},
{1, OC, 1, 1},
{},
{N, OC, H, W}};
TensorShape dst{N, OC, H, W};
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};

bench_case(1, 32, 32, 200, 200, 3, 4);
bench_case(1, 32, 32, 200, 200, 3, 32);
bench_case(1, 32, 32, 128, 128, 3, 4);
bench_case(1, 32, 32, 128, 128, 3, 32);
bench_case(1, 32, 32, 100, 100, 3, 4);
bench_case(1, 32, 32, 100, 100, 3, 32);
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);

std::string algo_name = "F32STRD1";
printf("Benchmark F32STRD1_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
shapes_and_computation.clear();

algo_name = "F32STRD1";
printf("Benchmark F32STRD1_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF32_STR2) {
constexpr size_t RUNS = 50;

param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 2;
param.stride_w = 2;
param.sparse = param::ConvBias::Sparse::GROUP;

std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS,
size_t group, size_t P, size_t S) {
SmallVector<TensorShape> shapes{
{N, IC, H, W},
{group, OC / group, IC / group, FS, FS},
{1, OC, 1, 1},
{},
{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}};
TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1};
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};

bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2);
bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2);
bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2);
bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2);

std::string algo_name = "F32STRD2";
printf("Benchmark F32STRD2_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
shapes_and_computation.clear();

algo_name = "F32STRD2";
printf("Benchmark F32STRD2_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2);
bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16) {
constexpr size_t RUNS = 50;


+ 0
- 84
dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp

@@ -20,91 +20,7 @@
using namespace megdnn;
using namespace test;
using namespace conv_bias;
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_mk_packed_args();
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args =
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
check_winograd(
"4:2:32", checker, args, param::MatrixMul::Format::MK4,
param::ConvBias::Format::NCHW44);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(3);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
check_winograd("1:6:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_mk_packed_args();
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());

check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args =
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
check_winograd(
"4:6:16", checker, args, param::MatrixMul::Format::MK4,
param::ConvBias::Format::NCHW44);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(4);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
check_winograd("1:5:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(5);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
check_winograd("1:4:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) {
using namespace conv_bias;
std::vector<TestArg> nchw44_args =
get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);

Checker<ConvBiasForward> checker(handle());

auto run = [&checker](
const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
DType C_dtype, DType D_dtype, const float eps) {
for (auto&& arg : args) {
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_dtype(4, D_dtype)
.set_epsilon(eps)
.set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}});
}
};

//! uncomment this when low precision mode is ok
// run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(),
// dtype::Float32(), 1e-2f);

//! remove this when low precision mode is ok
run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(),
dtype::Float32(), 1e-3f);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1_WEIGHT_PREPROCESS) {
using namespace conv_bias;



+ 0
- 24
dnn/test/arm_common/matrix_mul.cpp

@@ -286,30 +286,6 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
run(M, K, N);
}

TEST_F(ARM_COMMON, FP32_GEMV_MK4) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;

checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV_MK4"));

checker.set_epsilon(1e-2);
auto run = [&](size_t M, size_t K) {
Param param;
param.format = param::MatrixMul::Format::MK4;
param.transposeA = false;
param.transposeB = false;
TensorShape A, B;
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
checker.set_param(param).execs({A, B, {}});
};

// N = 1
for (size_t M : {4, 16, 128, 1024})
for (size_t K : {4, 8, 12, 128, 256, 4096})
run(M, K);
}

TEST_F(ARM_COMMON, MATRIX_MUL_RECORD) {
TaskRecordChecker<MatrixMul> checker(0);
checker.set_epsilon(1e-2);


+ 781
- 0
dnn/test/fallback/conv_bias.cpp

@@ -117,6 +117,30 @@ TEST_F(FALLBACK, CONV_BIAS_FORWARD_RECORD) {
}
}

TEST_F(FALLBACK, FP32_GEMV_MK4_GI) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;

checker.set_before_exec_callback(AlgoChecker<MatrixMul>("FB_GI_F32_GEMV_MK4"));

checker.set_epsilon(1e-2);
auto run = [&](size_t M, size_t K) {
Param param;
param.format = param::MatrixMul::Format::MK4;
param.transposeA = false;
param.transposeB = false;
TensorShape A, B;
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
checker.set_param(param).execs({A, B, {}});
};

// N = 1
for (size_t M : {4, 16, 128, 1024})
for (size_t K : {4, 8, 12, 128, 256, 4096})
run(M, K);
}
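
For context, a minimal scalar sketch of what the MK4 GEMV above computes, assuming A packs {M/4, K/4, 4, 4} with the inner 4x4 block indexed as [m_inner][k_inner]; that inner order is an illustration-only assumption, and this sketch is not the FB_GI_F32_GEMV_MK4 kernel under test:

#include <cstddef>

// Reference MK4 GEMV: A {M/4, K/4, 4, 4}, B {K/4, 1, 4} (K floats),
// C {M/4, 1, 4} (M floats). M and K are assumed divisible by 4, as in
// the shapes generated by the test above.
static void gemv_mk4_ref(
        const float* A, const float* B, float* C, size_t M, size_t K) {
    for (size_t i = 0; i < M; ++i)
        C[i] = 0.f;
    for (size_t mb = 0; mb < M / 4; ++mb)
        for (size_t kb = 0; kb < K / 4; ++kb)
            for (size_t m = 0; m < 4; ++m)
                for (size_t k = 0; k < 4; ++k)
                    // inner-block order [m][k] is the assumption noted above
                    C[mb * 4 + m] += A[((mb * (K / 4) + kb) * 4 + m) * 4 + k] *
                                     B[kb * 4 + k];
}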

std::vector<conv_bias::TestArg> get_conv_bias_args(
std::vector<size_t> kernel, std::vector<size_t> padv,
std::vector<param::ConvBias::NonlineMode> nlmodev, std::vector<size_t> stridev,
@@ -257,6 +281,189 @@ TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) {
dtype::Float32{}, dtype::Float32{}, "FALLBACK_NAIVE");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_NCHW_NCHW44_F32_S2) {
check_conv_bias(
conv_bias::get_nchw44_conv_bias_args(
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 2, false,
true),
handle(), "F32_CONV_NCHW_NCHW44");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_NCHW_NCHW44_F32_S1) {
check_conv_bias(
conv_bias::get_nchw44_conv_bias_args(
{2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, ONLY_BR_BIASMODE, 1, false,
true),
handle(), "F32_CONV_NCHW_NCHW44");
}

std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
bool no_full_bias) {
using namespace conv_bias;
using Param = param::ConvBias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args;

auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
size_t stride, NLMode nlmode, bool pad) {
Param param;
param.stride_h = stride;
param.stride_w = stride;
if (pad) {
param.pad_h = kernel / 2;
param.pad_w = kernel / 2;
} else {
param.pad_h = 0;
param.pad_w = 0;
}
param.nonlineMode = nlmode;
param.format = param::ConvBias::Format::NCHW44;
param.sparse = param::ConvBias::Sparse::GROUP;

args.emplace_back(
param, TensorShape{n, group, h, w, 4},
TensorShape{group, 1, 1, kernel, kernel, 4}, TensorShape{});
if (!no_bias) {
args.emplace_back(
param, TensorShape{n, group, h, w, 4},
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{1, group, 1, 1, 4});
}
if (!no_full_bias) {
args.emplace_back(
param, TensorShape{n, group, h, w, 4},
TensorShape{group, 1, 1, kernel, kernel, 4},
TensorShape{
n, group, (h + 2 * param.pad_h - kernel) / stride + 1,
(w + 2 * param.pad_w - kernel) / stride + 1, 4});
}
};

std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
if (!no_nonlinemode) {
nonlinemode.emplace_back(NLMode::RELU);
nonlinemode.emplace_back(NLMode::H_SWISH);
}
for (size_t n : {1, 2}) {
for (auto nlmode : nonlinemode) {
for (bool pad : {true}) {
for (size_t group : {1, 2, 4, 7, 16}) {
for (size_t size : {4, 6, 7, 9, 20}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode, pad);
}
}
}
}
for (bool pad : {false}) {
for (size_t group : {1, 2, 7, 16}) {
for (size_t size : {7, 9, 20}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode, pad);
}
}
}
}
}
}
return args;
}
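
To make these shapes concrete: a hedged scalar reference of the channel-wise NCHW44 convolution the args above describe, with src {N, G, H, W, 4}, filter {G, 1, 1, FH, FW, 4} and broadcast bias {1, G, 1, 1, 4}. This illustrates the layout only; it is not the F32_CHANNEL_WISE_NCHW44 kernel itself:

#include <cstddef>
#include <vector>

// Scalar channel-wise NCHW44 conv: each group owns 4 channels and each
// channel is convolved with its own FHxFW filter; zero padding, with
// OH/OW following the (H + 2*pad - FH) / stride + 1 rule used above.
static void chanwise_nchw44_ref(
        const std::vector<float>& src, const std::vector<float>& flt,
        const std::vector<float>& bias, std::vector<float>& dst, size_t N,
        size_t G, size_t H, size_t W, size_t FH, size_t FW, size_t pad,
        size_t stride) {
    size_t OH = (H + 2 * pad - FH) / stride + 1;
    size_t OW = (W + 2 * pad - FW) / stride + 1;
    dst.assign(N * G * OH * OW * 4, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t g = 0; g < G; ++g)
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow)
                    for (size_t c = 0; c < 4; ++c) {
                        float acc = bias[g * 4 + c];  // bias {1, G, 1, 1, 4}
                        for (size_t fh = 0; fh < FH; ++fh)
                            for (size_t fw = 0; fw < FW; ++fw) {
                                long ih = long(oh * stride + fh) - long(pad);
                                long iw = long(ow * stride + fw) - long(pad);
                                if (ih < 0 || iw < 0 || ih >= long(H) ||
                                    iw >= long(W))
                                    continue;  // zero padding
                                size_t sih = size_t(ih), siw = size_t(iw);
                                acc += src[(((n * G + g) * H + sih) * W + siw) *
                                                   4 +
                                           c] *
                                       flt[((g * FH + fh) * FW + fw) * 4 + c];
                            }
                        dst[(((n * G + g) * OH + oh) * OW + ow) * 4 + c] = acc;
                    }
}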

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3}, 1, false, false, false), handle(),
"F32_CHANNEL_WISE_NCHW44");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
check_conv_bias(
get_nchw44_channel_wise_args({5}, 1, false, false, false), handle(),
"F32_CHANNEL_WISE_NCHW44");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false), handle(),
"F32_CHANNEL_WISE_NCHW44");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K7) {
//! k=7 s=1
check_conv_bias(
conv_bias::get_nchw44_conv_bias_args(
{7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1),
handle(), "F32_CONV_NCHW44_DIRECT");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K2K3) {
check_conv_bias(
conv_bias::get_nchw44_conv_bias_args(
{2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
handle(), "F32_CONV_NCHW44_DIRECT");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S1_K5) {
check_conv_bias(
conv_bias::get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
handle(), "F32_CONV_NCHW44_DIRECT");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_NCHW44_S2) {
check_conv_bias(
conv_bias::get_nchw44_conv_bias_args(
{2, 3, 5, 7}, FULL_NLMODE, ONLY_BR_BIASMODE, 2),
handle(), "F32_CONV_NCHW44_DIRECT");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32) {
check_conv_bias(
conv_bias::get_conv_bias_args(
{1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
handle(), "F32DIRECT");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_STR2) {
check_conv_bias(
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
handle(), "F32STRD2");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_DIRECT_FP32_STR1) {
check_conv_bias(
conv_bias::get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
handle(), "F32STRD1");
}

TEST_F(FALLBACK_MULTI_THREADS, CONVBIAS_GI_WINOGRAD_PREPROCESS_NCHW44) {
using namespace conv_bias;
std::vector<TestArg> nchw44_args = conv_bias::get_nchw44_conv_bias_args(
{3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);

Checker<ConvBiasForward> checker(handle());

auto run = [&checker](
const std::vector<TestArg>& args, DType A_dtype, DType B_dtype,
DType C_dtype, DType D_dtype, const float eps) {
for (auto&& arg : args) {
checker.set_dtype(0, A_dtype)
.set_dtype(1, B_dtype)
.set_dtype(2, C_dtype)
.set_dtype(4, D_dtype)
.set_epsilon(eps)
.set_param(arg.param)
.execs({arg.src, arg.filter, arg.bias, {}, {}});
}
};

//! uncomment this when low precision mode is ok
// run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(),
// dtype::Float32(), 1e-2f);

//! remove this when low precision mode is ok
run(nchw44_args, dtype::Float32(), dtype::Float32(), dtype::Float32(),
dtype::Float32(), 1e-3f);
}
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) {
using namespace conv_bias;
param::ConvBias cur_param;
@@ -273,6 +480,422 @@ TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_QUANTIZED) {
}

#if MEGDNN_WITH_BENCHMARK
namespace {
void benchmark_impl(
const param::ConvBias param,
std::vector<std::pair<SmallVector<TensorShape>, float>>& shapes_and_computation,
const std::string algo_name, size_t RUNS,
TaskExecutorConfig&& multi_thread_config,
TaskExecutorConfig&& single_thread_config, std::vector<DType>& data_type) {
std::vector<float> multi_thread_times, single_thread_times;
{
auto multi_thread_handle = create_cpu_handle(0, true, &multi_thread_config);
auto benchmarker = Benchmarker<ConvBias>(multi_thread_handle.get());
benchmarker.set_times(RUNS)
.set_display(false)
.set_param(param)
.set_dtype(0, data_type[0])
.set_dtype(1, data_type[1])
.set_dtype(2, data_type[2])
.set_dtype(4, data_type[3])
.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str()));
for (auto shape : shapes_and_computation) {
multi_thread_times.push_back(benchmarker.exec(shape.first) / RUNS);
}
}
{
auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config);
auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get());
benchmarker.set_times(RUNS)
.set_display(false)
.set_param(param)
.set_dtype(0, data_type[0])
.set_dtype(1, data_type[1])
.set_dtype(2, data_type[2])
.set_dtype(4, data_type[3])
.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str()));
for (auto shape : shapes_and_computation) {
single_thread_times.push_back(benchmarker.exec(shape.first) / RUNS);
}
}
printf("Benchmark : Multi threads %zu, ", multi_thread_config.nr_thread);
printf("core_ids:");
for (size_t i = 0; i < multi_thread_config.affinity_core_set.size(); i++) {
printf("%zu ", multi_thread_config.affinity_core_set[i]);
}
printf(", Single thread core_id %zu\n", single_thread_config.affinity_core_set[0]);
for (size_t i = 0; i < shapes_and_computation.size(); i++) {
auto shapes = shapes_and_computation[i];
printf("Bench case: ");
for (auto&& shape : shapes.first) {
printf("%s ", shape.to_string().c_str());
}
float computations = shapes.second;
printf("%zu threads gflops: %f,\n single thread gflops: "
"%f. spead up = %f, speedup/cores=%f\n",
multi_thread_config.nr_thread, computations / multi_thread_times[i],
computations / single_thread_times[i],
single_thread_times[i] / multi_thread_times[i],
single_thread_times[i] / multi_thread_times[i] /
multi_thread_config.nr_thread);
}
}
} // namespace
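
The speedup/cores figure printed by benchmark_impl above is parallel efficiency. A minimal sketch of the two derived metrics, using hypothetical timings in place of the measured exec()/RUNS values:

#include <cstdio>

int main() {
    // hypothetical per-run times; benchmark_impl obtains these from exec()/RUNS
    double single_ms = 8.0, multi_ms = 2.5;
    int nr_thread = 4;
    double speedup = single_ms / multi_ms;      // "speed up" in the log
    double efficiency = speedup / nr_thread;    // speedup/cores, 1.0 = ideal
    std::printf("speedup=%.2f efficiency=%.2f\n", speedup, efficiency);
    return 0;
}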

TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32) {
constexpr size_t RUNS = 50;

param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 1;
param.stride_w = 1;
param.sparse = param::ConvBias::Sparse::GROUP;

std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS,
size_t group) {
SmallVector<TensorShape> shapes{
{N, IC, H, W},
{group, OC / group, IC / group, FS, FS},
{1, OC, 1, 1},
{},
{N, OC, H, W}};
TensorShape dst{N, OC, H, W};
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};

bench_case(1, 32, 32, 200, 200, 3, 4);
bench_case(1, 32, 32, 200, 200, 3, 32);
bench_case(1, 32, 32, 128, 128, 3, 4);
bench_case(1, 32, 32, 128, 128, 3, 32);
bench_case(1, 32, 32, 100, 100, 3, 4);
bench_case(1, 32, 32, 100, 100, 3, 32);
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);

std::string algo_name = "F32DIRECT";
printf("Benchmark F32DIRECT_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
shapes_and_computation.clear();

algo_name = "F32DIRECT";
printf("Benchmark F32DIRECT_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
}

TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32_STR1) {
constexpr size_t RUNS = 50;
param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 1;
param.stride_w = 1;
param.sparse = param::ConvBias::Sparse::GROUP;

std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS,
size_t group) {
SmallVector<TensorShape> shapes{
{N, IC, H, W},
{group, OC / group, IC / group, FS, FS},
{1, OC, 1, 1},
{},
{N, OC, H, W}};
TensorShape dst{N, OC, H, W};
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};

bench_case(1, 32, 32, 200, 200, 3, 4);
bench_case(1, 32, 32, 200, 200, 3, 32);
bench_case(1, 32, 32, 128, 128, 3, 4);
bench_case(1, 32, 32, 128, 128, 3, 32);
bench_case(1, 32, 32, 100, 100, 3, 4);
bench_case(1, 32, 32, 100, 100, 3, 32);
bench_case(1, 32, 32, 80, 80, 3, 4);
bench_case(1, 32, 32, 80, 80, 3, 32);

std::string algo_name = "F32STRD1";
printf("Benchmark F32STRD1_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
shapes_and_computation.clear();

algo_name = "F32STRD1";
printf("Benchmark F32STRD1_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1);
bench_case(1, 32, 32, 128, 128, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
}

TEST_F(FALLBACK_MULTI_THREADS, BENCHMARK_GI_CONVBIAS_DIRECTF32_STR2) {
constexpr size_t RUNS = 50;

param::ConvBias param;
param.nonlineMode = param::ConvBias::NonlineMode::RELU;
param.pad_h = 1;
param.pad_w = 1;
param.stride_h = 2;
param.stride_w = 2;
param.sparse = param::ConvBias::Sparse::GROUP;

std::vector<std::pair<SmallVector<TensorShape>, float>> shapes_and_computation;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS,
size_t group, size_t P, size_t S) {
SmallVector<TensorShape> shapes{
{N, IC, H, W},
{group, OC / group, IC / group, FS, FS},
{1, OC, 1, 1},
{},
{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1}};
TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1};
float computations = ((IC / group) * FS * FS * dst.total_nr_elems() * 2 +
dst.total_nr_elems()) *
1e-6;
shapes_and_computation.push_back(std::make_pair(shapes, computations));
};

bench_case(1, 32, 32, 200, 200, 3, 4, 1, 2);
bench_case(1, 32, 32, 200, 200, 3, 32, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 4, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 32, 1, 2);
bench_case(1, 32, 32, 100, 100, 3, 4, 1, 2);
bench_case(1, 32, 32, 100, 100, 3, 32, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 4, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 32, 1, 2);

std::string algo_name = "F32STRD2";
printf("Benchmark F32STRD2_LARGE_GROUP algo\n");
std::vector<DType> data_type = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
shapes_and_computation.clear();

algo_name = "F32STRD2";
printf("Benchmark F32STRD2_SMALL_GROUP algo\n");
bench_case(1, 32, 32, 200, 200, 3, 1, 1, 2);
bench_case(1, 32, 32, 128, 128, 3, 1, 1, 2);
bench_case(1, 32, 32, 100, 100, 3, 1, 1, 2);
bench_case(1, 32, 32, 80, 80, 3, 1, 1, 2);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {4, {4, 5, 6, 7}}, {1, {7}},
data_type);
benchmark_impl(
param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}},
data_type);
}
TEST_F(FALLBACK, BENCHMARK_GI_CHANNEL_WISE_F32_STRIDE1_NCHW44) {
// have to remove the preferred restriction in the usable func before running the benchmark
using namespace conv_bias;
param::ConvBias param;
param.stride_h = 1;
param.stride_w = 1;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = NonlineMode::RELU;
param.sparse = param::ConvBias::Sparse::GROUP;

constexpr size_t RUN = 50;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_display(false);
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD1"));

auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;

param.format = param::ConvBias::Format::NCHW44;
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_display(false);
benchmark1.set_param(param);
benchmark1.set_times(RUN);
benchmark1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44"));
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) {
TensorLayout dst_layout;
opr->deduce_layout(
{{1, group * 4, h, w}, dtype::Int8()},
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()},
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 /
(1024 * 1024 * 1024) * 1e3;

auto used0 = benchmark0.exec(
{{1, group * 4, h, w},
{group * 4, 1, 1, kernel, kernel},
{1, group * 4, 1, 1},
{},
{}}) /
RUN;
auto used1 = benchmark1.exec(
{{1, group, h, w, 4},
{group, 1, 1, kernel, kernel, 4},
{1, group, 1, 1, 4},
{},
{}}) /
RUN;
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops "
"nchw44: "
"%f ms %f GFlops "
"speedup: %f\n",
group, h, w, kernel, used0, computations / used0, used1,
computations / used1, used0 / used1);
};
for (size_t group : {8, 16, 32, 64}) {
for (size_t kernel : {2, 3, 5}) {
run(group, 112, 112, kernel);
run(group, 56, 56, kernel);
run(group, 48, 48, kernel);
run(group, 28, 28, kernel);
run(group, 14, 14, kernel);
}
}
run(8, 112, 112, 3);
run(32, 56, 56, 3);
run(64, 28, 28, 3);
run(128, 14, 14, 3);
}

TEST_F(FALLBACK, BENCHMARK_GI_CHANNEL_WISE_F32_STRIDE2_NCHW44) {
// have to remove the preferred restriction in the usable func before running the benchmark
using namespace conv_bias;
param::ConvBias param;
param.stride_h = 2;
param.stride_w = 2;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = NonlineMode::RELU;
param.sparse = param::ConvBias::Sparse::GROUP;

constexpr size_t RUN = 50;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_display(false);
benchmark0.set_param(param);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32STRD2"));

auto opr = handle()->create_operator<ConvBias>();
opr->param() = param;

param.format = param::ConvBias::Format::NCHW44;
Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_display(false);
benchmark1.set_param(param);
benchmark1.set_times(RUN);
benchmark1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("F32_CHANNEL_WISE_NCHW44"));
auto run = [&](size_t group, size_t w, size_t h, size_t kernel) {
TensorLayout dst_layout;
opr->deduce_layout(
{{1, group * 4, h, w}, dtype::Int8()},
{{group * 4, 1, 1, kernel, kernel}, dtype::Int8()},
{{1, group * 4, 1, 1}, dtype::Int32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * kernel * kernel * 2.0 /
(1024 * 1024 * 1024) * 1e3;

auto used0 = benchmark0.exec(
{{1, group * 4, h, w},
{group * 4, 1, 1, kernel, kernel},
{1, group * 4, 1, 1},
{},
{}}) /
RUN;
auto used1 = benchmark1.exec(
{{1, group, h, w, 4},
{group, 1, 1, kernel, kernel, 4},
{1, group, 1, 1, 4},
{},
{}}) /
RUN;
printf("group/h/w/kernel:%zu,%zu,%zu,%zu: nchw: %f ms %f Gflops "
"nchw44: "
"%f ms %f GFlops "
"speedup: %f\n",
group, h, w, kernel, used0, computations / used0, used1,
computations / used1, used0 / used1);
};
for (size_t group : {8, 16, 32, 64}) {
for (size_t kernel : {2, 3, 5}) {
run(group, 112, 112, kernel);
run(group, 56, 56, kernel);
run(group, 48, 48, kernel);
run(group, 28, 28, kernel);
run(group, 14, 14, kernel);
}
}
run(8, 112, 112, 3);
run(32, 56, 56, 3);
run(64, 28, 28, 3);
run(128, 14, 14, 3);
}
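
A quick unit check on the computations term used throughout these benchmarks (a sketch, not part of the diff): total flops are dst_elems * FH * FW * 2; dividing by 2^30 gives GFLOP, and the extra 1e3 factor makes computations / time_in_ms come out directly in GFLOPS. Taking the stride-1 case for simplicity:

#include <cstdio>

int main() {
    // Example: group=8, NCHW44 dst {1, 8, 112, 112, 4}, kernel 3x3, stride 1.
    double dst_elems = 1.0 * 8 * 112 * 112 * 4;
    double kernel = 3;
    double computations =
            dst_elems * kernel * kernel * 2.0 / (1024 * 1024 * 1024) * 1e3;
    double used_ms = 0.5;  // hypothetical measured time per run
    std::printf("computations=%f, GFLOPS at %.1f ms = %f\n", computations,
                used_ms, computations / used_ms);
    return 0;
}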

TEST_F(FALLBACK, BENCHMARK_CONVBIAS) {
constexpr size_t RUNS = 10;
param::ConvBias param;
@@ -320,6 +943,164 @@ TEST_F(FALLBACK, BENCHMARK_CONVBIAS) {
}
}
}

TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F23_4x4) {
#if MEGDNN_AARCH64
conv_bias::benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:2", handle(), 3, 4);
#elif MEGDNN_ARMV7
conv_bias::benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:2", handle(), 3, 4);
#else
conv_bias::benchmark_winograd("WINOGRAD:FB_GI_F32_MK4_4x8:4:2", handle(), 3, 4);
#endif
}

void benchmark_winograd_nchw_vs_nchw44(
const char* algo_name0, const char* algo_name1, Handle* handle) {
using namespace conv_bias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<conv_bias::TestArg> args_nchw44;
std::vector<conv_bias::TestArg> args_nchw;

auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t group,
NLMode nlmode) {
param::ConvBias param;
param.format = param::ConvBias::Format::NCHW44;
param.stride_h = 1;
param.stride_w = 1;
param.pad_h = 1;
param.pad_w = 1;
param.nonlineMode = nlmode;

if (group == 1) {
param.sparse = param::ConvBias::Sparse::DENSE;
args_nchw44.emplace_back(
param, TensorShape{n, ic / 4, h, w, 4},
TensorShape{oc / 4, ic / 4, 3, 3, 4, 4}, TensorShape{});
param.format = param::ConvBias::Format::NCHW;
args_nchw.emplace_back(
param, TensorShape{n, ic, h, w}, TensorShape{oc, ic, 3, 3},
TensorShape{});
} else {
auto oc_per_group = oc / group;
auto ic_per_group = ic / group;
param.sparse = param::ConvBias::Sparse::GROUP;
args_nchw44.emplace_back(
param, TensorShape{n, ic_per_group / 4, h, w, 4},
TensorShape{group, oc_per_group / 4, ic_per_group / 4, 3, 3, 4, 4},
TensorShape{});
param.format = param::ConvBias::Format::NCHW;
args_nchw.emplace_back(
param, TensorShape{n, ic, h, w},
TensorShape{group, oc_per_group, ic_per_group, 3, 3},
TensorShape{});
}
};

std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
for (auto nlmode : nonlinemode)
for (size_t n : {1})
for (size_t group = 1; group <= 1; ++group) {
pack(n, 512, 512, 15, 15, group, nlmode);
pack(n, 512, 256, 15, 15, group, nlmode);
pack(n, 256, 256, 29, 29, group, nlmode);
pack(n, 256, 128, 29, 29, group, nlmode);
pack(n, 128, 128, 57, 57, group, nlmode);
pack(n, 128, 64, 57, 57, group, nlmode);
pack(n, 24, 24, 224, 224, group, nlmode);
pack(n, 64, 24, 123, 123, group, nlmode);
pack(n, 64, 64, 56, 56, group, nlmode);
pack(n, 128, 128, 28, 28, group, nlmode);
pack(n, 256, 256, 14, 14, group, nlmode);
pack(n, 512, 512, 7, 7, group, nlmode);
}

using namespace conv_bias;
constexpr size_t RUN = 10;
Benchmarker<ConvBias> benchmark_winograd_nchw(handle);
benchmark_winograd_nchw.set_display(false);
benchmark_winograd_nchw.set_times(RUN);

Benchmarker<ConvBias> benchmark_winograd_nchw44(handle);
benchmark_winograd_nchw44.set_display(false);
benchmark_winograd_nchw44.set_times(RUN);

std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name0);
std::string winograd_nchw44_algo_name = ssprintf("WINOGRAD_NCHW44:%s", algo_name1);

for (size_t i = 0; i < args_nchw.size(); ++i) {
auto arg_nchw = args_nchw[i];
auto arg_nchw44 = args_nchw44[i];

TensorLayout dst_layout;
auto opr = handle->create_operator<ConvBias>();
opr->param() = arg_nchw.param;
opr->deduce_layout(
{arg_nchw.src, dtype::Float32()}, {arg_nchw.filter, dtype::Float32()},
{arg_nchw.bias, dtype::Float32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * arg_nchw.filter[1] *
arg_nchw.filter[2] * arg_nchw.filter[3] * 2.0 /
(1024 * 1024 * 1024) * 1e3;

benchmark_winograd_nchw.set_param(arg_nchw.param);
auto nchw_used = algo_benchmark<ConvBias>(
benchmark_winograd_nchw,
{arg_nchw.src, arg_nchw.filter, {}, {}, {}},
winograd_nchw_algo_name.c_str()) /
RUN;

benchmark_winograd_nchw44.set_param(arg_nchw44.param);
auto nchw44_used = algo_benchmark<ConvBias>(
benchmark_winograd_nchw44,
{arg_nchw44.src, arg_nchw44.filter, {}, {}, {}},
winograd_nchw44_algo_name.c_str()) /
RUN;

printf("%s %s: nchw: %f ms %f Gflops nchw44: %f ms %f GFlops "
"speedup: "
"%f\n",
arg_nchw.src.to_string().c_str(), arg_nchw.filter.to_string().c_str(),
nchw_used, computations / nchw_used, nchw44_used,
computations / nchw44_used, nchw_used / nchw44_used);
}
}
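
NCHW44 here is plain NCHW with channels folded into blocks of 4, which is what lets the two argument sets built above describe the same convolution. A hedged sketch of the repacking, assuming C is divisible by 4 (the helper name is illustrative, not a MegDNN API):

#include <cstddef>
#include <vector>

// Repack NCHW into NCHW44 {N, C/4, H, W, 4}; C must be divisible by 4.
static std::vector<float> nchw_to_nchw44(
        const std::vector<float>& src, size_t N, size_t C, size_t H, size_t W) {
    std::vector<float> dst(src.size());
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w) {
                    size_t cb = c / 4, ci = c % 4;
                    dst[((((n * (C / 4) + cb) * H + h) * W + w) * 4) + ci] =
                            src[((n * C + c) * H + h) * W + w];
                }
    return dst;
}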

TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_winograd_nchw_vs_nchw44(
"AARCH64_F32_MK4_4x16:4:2", "AARCH64_F32_MK4_4x16:4:2", handle());
#elif MEGDNN_ARMV7
benchmark_winograd_nchw_vs_nchw44(
"ARMV7_F32_MK4_4x8:4:2", "ARMV7_F32_MK4_4x8:4:2", handle());
#else
benchmark_winograd_nchw_vs_nchw44(
"FB_GI_F32_MK4_4x8:4:2", "FB_GI_F32_MK4_4x8:4:2", handle());
#endif
}

TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_4x4) {
#if MEGDNN_AARCH64
conv_bias::benchmark_winograd("WINOGRAD:AARCH64_F32_MK4_4x16:4:6", handle(), 3, 4);
#elif MEGDNN_ARMV7
conv_bias::benchmark_winograd("WINOGRAD:ARMV7_F32_MK4_4x8:4:6", handle(), 3, 4);
#else
conv_bias::benchmark_winograd("WINOGRAD:FB_GI_F32_MK4_4x8:4:6", handle(), 3, 4);
#endif
}

TEST_F(FALLBACK, BENCHMARK_GI_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_winograd_nchw_vs_nchw44(
"AARCH64_F32_MK4_4x16:4:6", "AARCH64_F32_MK4_4x16:4:6", handle());
#elif MEGDNN_ARMV7
benchmark_winograd_nchw_vs_nchw44(
"ARMV7_F32_MK4_4x8:4:6", "ARMV7_F32_MK4_4x8:4:6", handle());
#else
benchmark_winograd_nchw_vs_nchw44(
"FB_GI_F32_MK4_4x8:4:6", "FB_GI_F32_MK4_4x8:4:6", handle());
#endif
}
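
For orientation on the F23/F63 suffixes above: Winograd F(m, r) produces an m x m output tile from (m + r - 1)^2 element-wise multiplies, versus m * m * r * r for direct convolution. A small sketch of the resulting counts, ignoring the input/output transform overhead:

#include <cstdio>
#include <initializer_list>

int main() {
    const int r = 3;  // 3x3 kernels, as in the tests above
    for (int m : {2, 6}) {
        int wino = (m + r - 1) * (m + r - 1);  // multiplies per output tile
        int direct = m * m * r * r;            // direct conv multiplies
        std::printf("F%d%d: %d vs %d multiplies (%.2fx fewer)\n", m, r, wino,
                    direct, double(direct) / wino);
    }
    return 0;
}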

#endif

} // namespace test

