|
|
@@ -31,6 +31,8 @@ namespace { |
|
|
|
|
|
|
|
constexpr size_t alpha = 6 + 3 - 1; |
|
|
|
constexpr size_t pack_size = 4; |
|
|
|
constexpr float input_parameters[12] = {5.25f, 4.25f, 0.5f, 0.25f, 2.5f, 1.25f, |
|
|
|
2.0f, 4.0f, 5.0f, 0.0f, 0.0f, 0.0f}; |
|
|
|
|
|
|
|
struct InputTransformF63_NCHW44 { |
|
|
|
template <bool inner> |
|
|
@@ -80,12 +82,14 @@ struct InputTransformF63_NCHW44 { |
|
|
|
size_t unit_idx, size_t nr_units_in_tile, size_t ic, |
|
|
|
size_t IC) { |
|
|
|
// BT * d * B |
|
|
|
#define cb(m, n) \ |
|
|
|
Vector<float, 4> d##m##n = Vector<float, 4>::load( \ |
|
|
|
patchT + m * alpha * pack_size + n * pack_size); |
|
|
|
|
|
|
|
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); |
|
|
|
#undef cb |
|
|
|
size_t ICB = IC / pack_size; |
|
|
|
size_t icb = ic / pack_size; |
|
|
|
|
|
|
|
float32x4_t d0, d1, d2, d3, d4, d5, d6, d7; |
|
|
|
float32x4_t v0 = vld1q_f32(input_parameters + 0); |
|
|
|
float32x4_t v1 = vld1q_f32(input_parameters + 4); |
|
|
|
float32x4_t v2 = vld1q_f32(input_parameters + 8); |
|
|
|
|
|
|
|
//! B |
|
|
|
//! 1 0 0 0 0 0 0 0 |
|
|
@@ -96,49 +100,147 @@ struct InputTransformF63_NCHW44 { |
|
|
|
//! 0 1 -1 2 -2 0.5 -0.5 -5.25 |
|
|
|
//! -1 1 1 1 1 1 1 0 |
|
|
|
//! 0 0 0 0 0 0 0 1 |
|
|
|
#define cb(m) \ |
|
|
|
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ |
|
|
|
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ |
|
|
|
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ |
|
|
|
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ |
|
|
|
d5##m * 2.f + d6##m; \ |
|
|
|
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \ |
|
|
|
d4##m * 1.25f - d5##m * 2.f + d6##m; \ |
|
|
|
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ |
|
|
|
d5##m * 0.5f + d6##m; \ |
|
|
|
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ |
|
|
|
d5##m * 0.5f + d6##m; \ |
|
|
|
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; |
|
|
|
|
|
|
|
UNROLL_CALL_NOWRAPPER(8, cb); |
|
|
|
#undef cb |
|
|
|
|
|
|
|
#define cb(m) \ |
|
|
|
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ |
|
|
|
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \ |
|
|
|
(t##m##3 + t##m##4) * 4.25f; \ |
|
|
|
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \ |
|
|
|
(t##m##3 - t##m##4) * 4.25f; \ |
|
|
|
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \ |
|
|
|
t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \ |
|
|
|
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \ |
|
|
|
t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \ |
|
|
|
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ |
|
|
|
t##m##5 * 0.5f + t##m##6; \ |
|
|
|
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \ |
|
|
|
t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \ |
|
|
|
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; |
|
|
|
|
|
|
|
UNROLL_CALL_NOWRAPPER(8, cb); |
|
|
|
#define cb(i) \ |
|
|
|
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ |
|
|
|
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ |
|
|
|
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ |
|
|
|
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ |
|
|
|
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ |
|
|
|
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ |
|
|
|
auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ |
|
|
|
auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ |
|
|
|
auto t##i##1 = d6; \ |
|
|
|
auto t##i##2 = d6; \ |
|
|
|
auto t##i##3 = d6; \ |
|
|
|
auto t##i##4 = d6; \ |
|
|
|
auto t##i##5 = d6; \ |
|
|
|
auto t##i##6 = d6; \ |
|
|
|
t##i##0 = t##i##0 - d6; \ |
|
|
|
t##i##1 = t##i##1 + d1; \ |
|
|
|
t##i##2 = t##i##2 - d1; \ |
|
|
|
t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \ |
|
|
|
t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \ |
|
|
|
t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \ |
|
|
|
t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \ |
|
|
|
t##i##7 = t##i##7 - d1; \ |
|
|
|
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \ |
|
|
|
t##i##1 = t##i##1 + d2; \ |
|
|
|
t##i##2 = t##i##2 + d2; \ |
|
|
|
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \ |
|
|
|
t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \ |
|
|
|
t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \ |
|
|
|
t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \ |
|
|
|
t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \ |
|
|
|
t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \ |
|
|
|
t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \ |
|
|
|
t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \ |
|
|
|
t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \ |
|
|
|
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \ |
|
|
|
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \ |
|
|
|
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \ |
|
|
|
t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \ |
|
|
|
t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \ |
|
|
|
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \ |
|
|
|
t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \ |
|
|
|
t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \ |
|
|
|
t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \ |
|
|
|
t##i##1 = t##i##1 + d5; \ |
|
|
|
t##i##2 = t##i##2 - d5; \ |
|
|
|
t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \ |
|
|
|
t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \ |
|
|
|
t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \ |
|
|
|
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \ |
|
|
|
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0); |
|
|
|
UNROLL_CALL_RAW(8, cb); |
|
|
|
#undef cb |
|
|
|
|
|
|
|
size_t ICB = IC / pack_size; |
|
|
|
size_t icb = ic / pack_size; |
|
|
|
#define cb(m, n) \ |
|
|
|
d##m##n.save(input_transform_buf + \ |
|
|
|
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + unit_idx * pack_size); |
|
|
|
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) |
|
|
|
#define cb(i) \ |
|
|
|
d0 = t0##i; \ |
|
|
|
d1 = t6##i; \ |
|
|
|
d2 = t6##i; \ |
|
|
|
d3 = t6##i; \ |
|
|
|
d4 = t6##i; \ |
|
|
|
d5 = t6##i; \ |
|
|
|
d6 = t6##i; \ |
|
|
|
d7 = t7##i; \ |
|
|
|
d0 = d0 - t6##i; \ |
|
|
|
d1 = d1 + t1##i; \ |
|
|
|
d2 = d2 - t1##i; \ |
|
|
|
d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \ |
|
|
|
d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \ |
|
|
|
d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \ |
|
|
|
d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \ |
|
|
|
d7 = d7 - t1##i; \ |
|
|
|
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \ |
|
|
|
d1 = d1 + t2##i; \ |
|
|
|
d2 = d2 + t2##i; \ |
|
|
|
d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \ |
|
|
|
d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \ |
|
|
|
d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \ |
|
|
|
d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \ |
|
|
|
d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \ |
|
|
|
d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \ |
|
|
|
d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \ |
|
|
|
d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \ |
|
|
|
d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \ |
|
|
|
d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \ |
|
|
|
d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \ |
|
|
|
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \ |
|
|
|
d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \ |
|
|
|
d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \ |
|
|
|
d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \ |
|
|
|
d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \ |
|
|
|
d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \ |
|
|
|
d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \ |
|
|
|
d1 = d1 + t5##i; \ |
|
|
|
d2 = d2 - t5##i; \ |
|
|
|
d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \ |
|
|
|
d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \ |
|
|
|
d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \ |
|
|
|
d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \ |
|
|
|
d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \ |
|
|
|
vst1q_f32(input_transform_buf + \ |
|
|
|
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + \ |
|
|
|
unit_idx * pack_size, \ |
|
|
|
d0); \ |
|
|
|
vst1q_f32(input_transform_buf + \ |
|
|
|
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + \ |
|
|
|
unit_idx * pack_size, \ |
|
|
|
d1); \ |
|
|
|
vst1q_f32(input_transform_buf + \ |
|
|
|
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + \ |
|
|
|
unit_idx * pack_size, \ |
|
|
|
d2); \ |
|
|
|
vst1q_f32(input_transform_buf + \ |
|
|
|
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + \ |
|
|
|
unit_idx * pack_size, \ |
|
|
|
d3); \ |
|
|
|
vst1q_f32(input_transform_buf + \ |
|
|
|
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + \ |
|
|
|
unit_idx * pack_size, \ |
|
|
|
d4); \ |
|
|
|
vst1q_f32(input_transform_buf + \ |
|
|
|
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + \ |
|
|
|
unit_idx * pack_size, \ |
|
|
|
d5); \ |
|
|
|
vst1q_f32(input_transform_buf + \ |
|
|
|
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + \ |
|
|
|
unit_idx * pack_size, \ |
|
|
|
d6); \ |
|
|
|
vst1q_f32(input_transform_buf + \ |
|
|
|
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ |
|
|
|
icb * nr_units_in_tile * pack_size + \ |
|
|
|
unit_idx * pack_size, \ |
|
|
|
d7); |
|
|
|
UNROLL_CALL_RAW(8, cb); |
|
|
|
#undef cb |
|
|
|
} |
|
|
|
}; |
|
|
@@ -178,7 +280,7 @@ struct OutputTransformF63_NCHW44 { |
|
|
|
* 1 -2 4 -8 16 -32 |
|
|
|
* 1 0.5 0.25 0.125 0.0625 0.03125 |
|
|
|
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 |
|
|
|
* 0 0.0 0 0 0 1 |
|
|
|
* 0 0 0 0 0 1 |
|
|
|
*/ |
|
|
|
|
|
|
|
Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; |
|
|
@@ -378,28 +480,33 @@ void winograd_F63_mk4_f_nchw44::output(const float* output_transform_buf, |
|
|
|
size_t OW, size_t oc_start, |
|
|
|
size_t oc_end, size_t unit_start_idx, |
|
|
|
size_t nr_units_in_tile) { |
|
|
|
constexpr size_t pack_size = 4; |
|
|
|
#define cb(_bmode, _nonline_op, ...) \ |
|
|
|
OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ |
|
|
|
__VA_ARGS__); |
|
|
|
#define cb(_bmode, _nonline_op, ...) \ |
|
|
|
for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \ |
|
|
|
size_t oc_index = oc - oc_start; \ |
|
|
|
rep(unit_idx, nr_units_in_tile) { \ |
|
|
|
size_t index = unit_start_idx + unit_idx; \ |
|
|
|
auto nh = index / units_w; \ |
|
|
|
auto nw = index % units_w; \ |
|
|
|
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ |
|
|
|
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ |
|
|
|
OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>:: \ |
|
|
|
transform(output_transform_buf, bias, output, \ |
|
|
|
transform_mid_buf, oh_start, ow_start, OH, OW, \ |
|
|
|
oc_start, oc_end, oc_index, unit_idx, \ |
|
|
|
nr_units_in_tile, src_dtype, dst_dtype); \ |
|
|
|
} \ |
|
|
|
} |
|
|
|
|
|
|
|
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); |
|
|
|
for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { |
|
|
|
size_t oc_index = oc - oc_start; |
|
|
|
rep(unit_idx, nr_units_in_tile) { |
|
|
|
size_t index = unit_start_idx + unit_idx; |
|
|
|
auto nh = index / units_w; |
|
|
|
auto nw = index % units_w; |
|
|
|
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; |
|
|
|
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; |
|
|
|
DISPATCH_CONV_WINOGRAD_BIAS( |
|
|
|
megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, |
|
|
|
bmode, nonline_mode, output_transform_buf, bias, output, |
|
|
|
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, |
|
|
|
oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, |
|
|
|
dst_dtype); |
|
|
|
} |
|
|
|
} |
|
|
|
constexpr size_t pack_size = 4; |
|
|
|
|
|
|
|
size_t OC = oc_end - oc_start; |
|
|
|
megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 && |
|
|
|
oc_end % pack_size == 0, |
|
|
|
"NCHW44 Winograd filter transform requires OC is times of 4"); |
|
|
|
|
|
|
|
DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_fp32_F63_mk4, cb, |
|
|
|
float, float, bmode, nonline_mode); |
|
|
|
#undef cb |
|
|
|
} |
|
|
|
|
|
|
|