GitOrigin-RevId: 8569c9dfc6
release-0.6
@@ -32,29 +32,37 @@ constexpr size_t pack_size = 4; | |||||
struct InputTransformF23_NCHW44 { | struct InputTransformF23_NCHW44 { | ||||
template <bool inner> | template <bool inner> | ||||
static void prepare(const float* input, float* patch, float* patchT, | |||||
int ih_start, int iw_start, size_t IH, size_t IW, | |||||
size_t ic, size_t IC) { | |||||
MEGDNN_MARK_USED_VAR(patch); | |||||
static void transform(float* patchT, const float* input, | |||||
float* input_transform_buf, size_t ih_start, | |||||
size_t iw_start, size_t IH, size_t IW, | |||||
size_t unit_idx, size_t nr_units_in_tile, size_t ic, | |||||
size_t IC) { | |||||
size_t IW4 = IW * pack_size; | size_t IW4 = IW * pack_size; | ||||
size_t iw4_start = iw_start * pack_size; | |||||
size_t icb = ic / pack_size; | size_t icb = ic / pack_size; | ||||
size_t iw4_start = iw_start * pack_size; | |||||
size_t ICB = IC / pack_size; | |||||
#define cb(m, n) Vector<float, 4> d##m##n; | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||||
#undef cb | |||||
if (!(inner && ic + pack_size < IC)) { | if (!(inner && ic + pack_size < IC)) { | ||||
memset(patchT, 0, sizeof(float) * pack_size * alpha * alpha); | memset(patchT, 0, sizeof(float) * pack_size * alpha * alpha); | ||||
} | } | ||||
if (inner) { | if (inner) { | ||||
MEGDNN_MARK_USED_VAR(patchT); | |||||
const float* input_ptr = | const float* input_ptr = | ||||
input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | input + icb * IH * IW4 + ih_start * IW4 + iw4_start; | ||||
for (size_t ih = 0; ih < alpha; ih++) { | |||||
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | |||||
#undef cb | |||||
#define cb(i) vst1q_f32(patchT + ih * alpha * pack_size + i * pack_size, v##i); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | |||||
#define cb(n, m) d##m##n = Vector<float, 4>::load(input_ptr + pack_size * n); | |||||
UNROLL_CALL_RAW(4, cb, 0); | |||||
input_ptr += IW4; | |||||
UNROLL_CALL_RAW(4, cb, 1); | |||||
input_ptr += IW4; | |||||
UNROLL_CALL_RAW(4, cb, 2); | |||||
input_ptr += IW4; | |||||
UNROLL_CALL_RAW(4, cb, 3); | |||||
#undef cb | #undef cb | ||||
input_ptr += IW4; | |||||
} | |||||
} else { | } else { | ||||
int ih0_act = std::max<int>(ih_start, 0), | int ih0_act = std::max<int>(ih_start, 0), | ||||
ih1_act = std::min<int>(ih_start + alpha, IH), | ih1_act = std::min<int>(ih_start + alpha, IH), | ||||
@@ -71,19 +79,12 @@ struct InputTransformF23_NCHW44 { | |||||
src); | src); | ||||
} | } | ||||
} | } | ||||
} | |||||
} | |||||
static void transform(const float* patchT, float* input_transform_buf, | |||||
size_t unit_idx, size_t nr_units_in_tile, size_t ic, | |||||
size_t IC) { | |||||
// BT * d * B | |||||
#define cb(m, n) \ | |||||
Vector<float, 4> d##m##n = Vector<float, 4>::load( \ | |||||
patchT + m * alpha * pack_size + n * pack_size); | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||||
#define cb(m, n) \ | |||||
d##m##n = Vector<float, 4>::load(patchT + m * alpha * pack_size + \ | |||||
n * pack_size); | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||||
#undef cb | #undef cb | ||||
} | |||||
//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 | //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 | ||||
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 | ||||
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 | ||||
@@ -106,8 +107,6 @@ struct InputTransformF23_NCHW44 { | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | UNROLL_CALL_NOWRAPPER(4, cb); | ||||
#undef cb | #undef cb | ||||
size_t ICB = IC / 4; | |||||
size_t icb = ic / 4; | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
d##m##n.save(input_transform_buf + \ | d##m##n.save(input_transform_buf + \ | ||||
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ | (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ | ||||
@@ -273,7 +272,6 @@ void winograd_F23_mk4_f_nchw44::input(const float* input, | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | // OW = IW + 2 * PW - KERNEL_SIZE + 1 | ||||
auto units_w = | auto units_w = | ||||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | ||||
float* patch = transform_mid_buf; | |||||
float* patchT = transform_mid_buf + 4 * alpha * alpha; | float* patchT = transform_mid_buf + 4 * alpha * alpha; | ||||
for (size_t ic = 0; ic < IC; ic += 4) { | for (size_t ic = 0; ic < IC; ic += 4) { | ||||
@@ -285,20 +283,13 @@ void winograd_F23_mk4_f_nchw44::input(const float* input, | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | ||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | ||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | ||||
InputTransformF23_NCHW44::prepare<true>(input, patch, patchT, | |||||
ih_start, iw_start, IH, | |||||
IW, ic, IC); | |||||
InputTransformF23_NCHW44::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, | |||||
ic, IC); | |||||
InputTransformF23_NCHW44::transform<true>( | |||||
patchT, input, input_transform_buf, ih_start, iw_start, | |||||
IH, IW, unit_idx, nr_units_in_tile, ic, IC); | |||||
} else { | } else { | ||||
InputTransformF23_NCHW44::prepare<false>(input, patch, patchT, | |||||
ih_start, iw_start, IH, | |||||
IW, ic, IC); | |||||
InputTransformF23_NCHW44::transform(patchT, input_transform_buf, | |||||
unit_idx, nr_units_in_tile, | |||||
ic, IC); | |||||
InputTransformF23_NCHW44::transform<false>( | |||||
patchT, input, input_transform_buf, ih_start, iw_start, | |||||
IH, IW, unit_idx, nr_units_in_tile, ic, IC); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -311,9 +302,21 @@ void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf, | |||||
size_t OW, size_t oc_start, | size_t OW, size_t oc_start, | ||||
size_t oc_end, size_t unit_start_idx, | size_t oc_end, size_t unit_start_idx, | ||||
size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
OutputTransformF23_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | |||||
__VA_ARGS__); | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
for (size_t oc = oc_start; oc < oc_end; oc += 4) { \ | |||||
size_t oc_index = oc - oc_start; \ | |||||
rep(unit_idx, nr_units_in_tile) { \ | |||||
size_t index = unit_start_idx + unit_idx; \ | |||||
auto nh = index / units_w; \ | |||||
auto nw = index % units_w; \ | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ | |||||
OutputTransformF23_NCHW44<_bmode, _nonline_op>::transform( \ | |||||
output_transform_buf, bias, output, transform_mid_buf, \ | |||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, \ | |||||
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); \ | |||||
} \ | |||||
} | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | ||||
constexpr size_t pack_size = 4; | constexpr size_t pack_size = 4; | ||||
@@ -323,22 +326,8 @@ void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf, | |||||
oc_end % pack_size == 0, | oc_end % pack_size == 0, | ||||
"NCHW44 Winograd filter transform requires OC is times of 4"); | "NCHW44 Winograd filter transform requires OC is times of 4"); | ||||
for (size_t oc = oc_start; oc < oc_end; oc += 4) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float, | |||||
float, bmode, nonline_mode, output_transform_buf, bias, | |||||
output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||||
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, | |||||
src_dtype, dst_dtype); | |||||
} | |||||
} | |||||
DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, | |||||
cb, float, float, bmode, nonline_mode); | |||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -31,6 +31,8 @@ namespace { | |||||
constexpr size_t alpha = 6 + 3 - 1; | constexpr size_t alpha = 6 + 3 - 1; | ||||
constexpr size_t pack_size = 4; | constexpr size_t pack_size = 4; | ||||
constexpr float input_parameters[12] = {5.25f, 4.25f, 0.5f, 0.25f, 2.5f, 1.25f, | |||||
2.0f, 4.0f, 5.0f, 0.0f, 0.0f, 0.0f}; | |||||
struct InputTransformF63_NCHW44 { | struct InputTransformF63_NCHW44 { | ||||
template <bool inner> | template <bool inner> | ||||
@@ -80,12 +82,14 @@ struct InputTransformF63_NCHW44 { | |||||
size_t unit_idx, size_t nr_units_in_tile, size_t ic, | size_t unit_idx, size_t nr_units_in_tile, size_t ic, | ||||
size_t IC) { | size_t IC) { | ||||
// BT * d * B | // BT * d * B | ||||
#define cb(m, n) \ | |||||
Vector<float, 4> d##m##n = Vector<float, 4>::load( \ | |||||
patchT + m * alpha * pack_size + n * pack_size); | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||||
#undef cb | |||||
size_t ICB = IC / pack_size; | |||||
size_t icb = ic / pack_size; | |||||
float32x4_t d0, d1, d2, d3, d4, d5, d6, d7; | |||||
float32x4_t v0 = vld1q_f32(input_parameters + 0); | |||||
float32x4_t v1 = vld1q_f32(input_parameters + 4); | |||||
float32x4_t v2 = vld1q_f32(input_parameters + 8); | |||||
//! B | //! B | ||||
//! 1 0 0 0 0 0 0 0 | //! 1 0 0 0 0 0 0 0 | ||||
@@ -96,49 +100,147 @@ struct InputTransformF63_NCHW44 { | |||||
//! 0 1 -1 2 -2 0.5 -0.5 -5.25 | //! 0 1 -1 2 -2 0.5 -0.5 -5.25 | ||||
//! -1 1 1 1 1 1 1 0 | //! -1 1 1 1 1 1 1 0 | ||||
//! 0 0 0 0 0 0 0 1 | //! 0 0 0 0 0 0 0 1 | ||||
#define cb(m) \ | |||||
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ | |||||
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ | |||||
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ | |||||
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ | |||||
d5##m * 2.f + d6##m; \ | |||||
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \ | |||||
d4##m * 1.25f - d5##m * 2.f + d6##m; \ | |||||
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ | |||||
d5##m * 0.5f + d6##m; \ | |||||
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ | |||||
d5##m * 0.5f + d6##m; \ | |||||
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#undef cb | |||||
#define cb(m) \ | |||||
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ | |||||
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \ | |||||
(t##m##3 + t##m##4) * 4.25f; \ | |||||
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \ | |||||
(t##m##3 - t##m##4) * 4.25f; \ | |||||
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \ | |||||
t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \ | |||||
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \ | |||||
t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \ | |||||
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ | |||||
t##m##5 * 0.5f + t##m##6; \ | |||||
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \ | |||||
t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \ | |||||
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#define cb(i) \ | |||||
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ | |||||
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ | |||||
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ | |||||
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ | |||||
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ | |||||
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ | |||||
auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ | |||||
auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ | |||||
auto t##i##1 = d6; \ | |||||
auto t##i##2 = d6; \ | |||||
auto t##i##3 = d6; \ | |||||
auto t##i##4 = d6; \ | |||||
auto t##i##5 = d6; \ | |||||
auto t##i##6 = d6; \ | |||||
t##i##0 = t##i##0 - d6; \ | |||||
t##i##1 = t##i##1 + d1; \ | |||||
t##i##2 = t##i##2 - d1; \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \ | |||||
t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \ | |||||
t##i##7 = t##i##7 - d1; \ | |||||
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \ | |||||
t##i##1 = t##i##1 + d2; \ | |||||
t##i##2 = t##i##2 + d2; \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \ | |||||
t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \ | |||||
t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \ | |||||
t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \ | |||||
t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \ | |||||
t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \ | |||||
t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \ | |||||
t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \ | |||||
t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \ | |||||
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \ | |||||
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \ | |||||
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \ | |||||
t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \ | |||||
t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \ | |||||
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \ | |||||
t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \ | |||||
t##i##1 = t##i##1 + d5; \ | |||||
t##i##2 = t##i##2 - d5; \ | |||||
t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \ | |||||
t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \ | |||||
t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \ | |||||
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \ | |||||
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0); | |||||
UNROLL_CALL_RAW(8, cb); | |||||
#undef cb | #undef cb | ||||
size_t ICB = IC / pack_size; | |||||
size_t icb = ic / pack_size; | |||||
#define cb(m, n) \ | |||||
d##m##n.save(input_transform_buf + \ | |||||
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + unit_idx * pack_size); | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) | |||||
#define cb(i) \ | |||||
d0 = t0##i; \ | |||||
d1 = t6##i; \ | |||||
d2 = t6##i; \ | |||||
d3 = t6##i; \ | |||||
d4 = t6##i; \ | |||||
d5 = t6##i; \ | |||||
d6 = t6##i; \ | |||||
d7 = t7##i; \ | |||||
d0 = d0 - t6##i; \ | |||||
d1 = d1 + t1##i; \ | |||||
d2 = d2 - t1##i; \ | |||||
d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \ | |||||
d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \ | |||||
d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \ | |||||
d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \ | |||||
d7 = d7 - t1##i; \ | |||||
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \ | |||||
d1 = d1 + t2##i; \ | |||||
d2 = d2 + t2##i; \ | |||||
d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \ | |||||
d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \ | |||||
d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \ | |||||
d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \ | |||||
d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \ | |||||
d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \ | |||||
d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \ | |||||
d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \ | |||||
d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \ | |||||
d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \ | |||||
d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \ | |||||
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \ | |||||
d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \ | |||||
d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \ | |||||
d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \ | |||||
d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \ | |||||
d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \ | |||||
d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \ | |||||
d1 = d1 + t5##i; \ | |||||
d2 = d2 - t5##i; \ | |||||
d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \ | |||||
d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \ | |||||
d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \ | |||||
d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \ | |||||
d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \ | |||||
vst1q_f32(input_transform_buf + \ | |||||
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + \ | |||||
unit_idx * pack_size, \ | |||||
d0); \ | |||||
vst1q_f32(input_transform_buf + \ | |||||
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + \ | |||||
unit_idx * pack_size, \ | |||||
d1); \ | |||||
vst1q_f32(input_transform_buf + \ | |||||
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + \ | |||||
unit_idx * pack_size, \ | |||||
d2); \ | |||||
vst1q_f32(input_transform_buf + \ | |||||
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + \ | |||||
unit_idx * pack_size, \ | |||||
d3); \ | |||||
vst1q_f32(input_transform_buf + \ | |||||
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + \ | |||||
unit_idx * pack_size, \ | |||||
d4); \ | |||||
vst1q_f32(input_transform_buf + \ | |||||
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + \ | |||||
unit_idx * pack_size, \ | |||||
d5); \ | |||||
vst1q_f32(input_transform_buf + \ | |||||
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + \ | |||||
unit_idx * pack_size, \ | |||||
d6); \ | |||||
vst1q_f32(input_transform_buf + \ | |||||
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ | |||||
icb * nr_units_in_tile * pack_size + \ | |||||
unit_idx * pack_size, \ | |||||
d7); | |||||
UNROLL_CALL_RAW(8, cb); | |||||
#undef cb | #undef cb | ||||
} | } | ||||
}; | }; | ||||
@@ -178,7 +280,7 @@ struct OutputTransformF63_NCHW44 { | |||||
* 1 -2 4 -8 16 -32 | * 1 -2 4 -8 16 -32 | ||||
* 1 0.5 0.25 0.125 0.0625 0.03125 | * 1 0.5 0.25 0.125 0.0625 0.03125 | ||||
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 | * 1 -0.5 0.25 -0.125 0.0625 -0.03125 | ||||
* 0 0.0 0 0 0 1 | |||||
* 0 0 0 0 0 1 | |||||
*/ | */ | ||||
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | ||||
@@ -378,28 +480,33 @@ void winograd_F63_mk4_f_nchw44::output(const float* output_transform_buf, | |||||
size_t OW, size_t oc_start, | size_t OW, size_t oc_start, | ||||
size_t oc_end, size_t unit_start_idx, | size_t oc_end, size_t unit_start_idx, | ||||
size_t nr_units_in_tile) { | size_t nr_units_in_tile) { | ||||
constexpr size_t pack_size = 4; | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | |||||
__VA_ARGS__); | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \ | |||||
size_t oc_index = oc - oc_start; \ | |||||
rep(unit_idx, nr_units_in_tile) { \ | |||||
size_t index = unit_start_idx + unit_idx; \ | |||||
auto nh = index / units_w; \ | |||||
auto nw = index % units_w; \ | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ | |||||
OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>:: \ | |||||
transform(output_transform_buf, bias, output, \ | |||||
transform_mid_buf, oh_start, ow_start, OH, OW, \ | |||||
oc_start, oc_end, oc_index, unit_idx, \ | |||||
nr_units_in_tile, src_dtype, dst_dtype); \ | |||||
} \ | |||||
} | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | ||||
for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, | |||||
bmode, nonline_mode, output_transform_buf, bias, output, | |||||
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, | |||||
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, | |||||
dst_dtype); | |||||
} | |||||
} | |||||
constexpr size_t pack_size = 4; | |||||
size_t OC = oc_end - oc_start; | |||||
megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 && | |||||
oc_end % pack_size == 0, | |||||
"NCHW44 Winograd filter transform requires OC is times of 4"); | |||||
DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_fp32_F63_mk4, cb, | |||||
float, float, bmode, nonline_mode); | |||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -538,10 +538,43 @@ struct Vfmaq_laneq_f32_armv7<3> { | |||||
return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); | return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); | ||||
} | } | ||||
}; | }; | ||||
template <int lane> | |||||
struct Vfmsq_laneq_f32_armv7 { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); | |||||
}; | |||||
template <> | |||||
struct Vfmsq_laneq_f32_armv7<0> { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
return vmlsq_lane_f32(a, b, vget_low_f32(v), 0); | |||||
} | |||||
}; | |||||
template <> | |||||
struct Vfmsq_laneq_f32_armv7<1> { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
return vmlsq_lane_f32(a, b, vget_low_f32(v), 1); | |||||
} | |||||
}; | |||||
template <> | |||||
struct Vfmsq_laneq_f32_armv7<2> { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
return vmlsq_lane_f32(a, b, vget_high_f32(v), 0); | |||||
} | |||||
}; | |||||
template <> | |||||
struct Vfmsq_laneq_f32_armv7<3> { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
return vmlsq_lane_f32(a, b, vget_high_f32(v), 1); | |||||
} | |||||
}; | |||||
} // namespace | } // namespace | ||||
#define vfmaq_laneq_f32(a, b, v, lane) \ | #define vfmaq_laneq_f32(a, b, v, lane) \ | ||||
Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v) | Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v) | ||||
#define vfmsq_laneq_f32(a, b, v, lane) \ | |||||
Vfmsq_laneq_f32_armv7<lane>::impl(a, b, v) | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
namespace { | namespace { | ||||
template <int lane> | template <int lane> | ||||
@@ -582,7 +615,6 @@ struct Vdotq_laneq_s32_armv7<3> { | |||||
//! GCC split fmla with lane to dup+fmla when version < 9 | //! GCC split fmla with lane to dup+fmla when version < 9 | ||||
//! https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 | //! https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 | ||||
#if !defined(__clang__) && __GNUC__ < 9 | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
namespace { | namespace { | ||||
@@ -630,13 +662,59 @@ struct Vfmaq_laneq_f32_armv8<3> { | |||||
return a; | return a; | ||||
} | } | ||||
}; | }; | ||||
template <int lane> | |||||
struct Vfmsq_laneq_f32_armv8 { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); | |||||
}; | |||||
template <> | |||||
struct Vfmsq_laneq_f32_armv8<0> { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
asm volatile("fmls %0.4s, %1.4s, %2.s[0]\n" | |||||
: "+w"(a) | |||||
: "w"(b), "w"(v) | |||||
:); | |||||
return a; | |||||
} | |||||
}; | |||||
template <> | |||||
struct Vfmsq_laneq_f32_armv8<1> { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
asm volatile("fmls %0.4s, %1.4s, %2.s[1]\n" | |||||
: "+w"(a) | |||||
: "w"(b), "w"(v) | |||||
:); | |||||
return a; | |||||
} | |||||
}; | |||||
template <> | |||||
struct Vfmsq_laneq_f32_armv8<2> { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
asm volatile("fmls %0.4s, %1.4s, %2.s[2]\n" | |||||
: "+w"(a) | |||||
: "w"(b), "w"(v) | |||||
:); | |||||
return a; | |||||
} | |||||
}; | |||||
template <> | |||||
struct Vfmsq_laneq_f32_armv8<3> { | |||||
__ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
asm volatile("fmls %0.4s, %1.4s, %2.s[3]\n" | |||||
: "+w"(a) | |||||
: "w"(b), "w"(v) | |||||
:); | |||||
return a; | |||||
} | |||||
}; | |||||
} // namespace | } // namespace | ||||
#undef vfmaq_laneq_f32 | #undef vfmaq_laneq_f32 | ||||
#define vfmaq_laneq_f32(a, b, v, lane) \ | #define vfmaq_laneq_f32(a, b, v, lane) \ | ||||
Vfmaq_laneq_f32_armv8<lane>::impl(a, b, v) | Vfmaq_laneq_f32_armv8<lane>::impl(a, b, v) | ||||
#endif | |||||
#undef vfmsq_laneq_f32 | |||||
#define vfmsq_laneq_f32(a, b, v, lane) \ | |||||
Vfmsq_laneq_f32_armv8<lane>::impl(a, b, v) | |||||
#endif | #endif | ||||
__ai int8x16_t vld_dup_tbl_s32(const int8_t* ptr, uint8x16_t& idx) { | __ai int8x16_t vld_dup_tbl_s32(const int8_t* ptr, uint8x16_t& idx) { | ||||
@@ -678,6 +756,16 @@ __ai int16x8_t vld1_dup_s8_s16(const int8_t* ptr) { | |||||
return vmovl_s8(vld1_dup_s8(ptr)); | return vmovl_s8(vld1_dup_s8(ptr)); | ||||
} | } | ||||
//! we add this because we found that cpu=aarch64_android cann't compile fmsq into fmls. | |||||
//! it use dup+fmla instead | |||||
__ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) { | |||||
asm volatile("fmls %0.4s, %1.4s, %2.4s\n" | |||||
: "+w"(a) | |||||
: "w"(b), "w"(v) | |||||
:); | |||||
return a; | |||||
} | |||||
#undef __ai | #undef __ai | ||||
#pragma GCC diagnostic pop | #pragma GCC diagnostic pop | ||||
@@ -791,8 +791,8 @@ void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) { | |||||
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY}; | std::vector<NLMode> nonlinemode = {NLMode::IDENTITY}; | ||||
for (auto nlmode : nonlinemode) | for (auto nlmode : nonlinemode) | ||||
for (size_t n : {1, 2}) | |||||
for (size_t group = 1; group <= 2; ++group) { | |||||
for (size_t n : {1}) | |||||
for (size_t group = 1; group <= 1; ++group) { | |||||
pack(n, 512, 512, 15, 15, group, nlmode); | pack(n, 512, 512, 15, 15, group, nlmode); | ||||
pack(n, 512, 256, 15, 15, group, nlmode); | pack(n, 512, 256, 15, 15, group, nlmode); | ||||
pack(n, 256, 256, 29, 29, group, nlmode); | pack(n, 256, 256, 29, 29, group, nlmode); | ||||