GitOrigin-RevId: 6d4b225ea5
tags/v0.3.2
@@ -235,7 +235,7 @@ void StrategyHelper< | |||||
input_filter_compute_type* input_transform_buf, | input_filter_compute_type* input_transform_buf, | ||||
input_filter_compute_type* transform_mid_buf, | input_filter_compute_type* transform_mid_buf, | ||||
int ih_start, int iw_start, size_t IH, size_t IW, | int ih_start, int iw_start, size_t IH, size_t IW, | ||||
size_t IC, size_t unit_idx, size_t nr_units_in_tile, | |||||
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||||
size_t m, size_t r, | size_t m, size_t r, | ||||
const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
float rescale) { | float rescale) { | ||||
@@ -284,7 +284,7 @@ void StrategyHelper< | |||||
const output_compute_type* bias, dst_type* output, | const output_compute_type* bias, dst_type* output, | ||||
output_compute_type* transform_mid_buf, BiasMode bmode, | output_compute_type* transform_mid_buf, BiasMode bmode, | ||||
NonlineMode nonline_mode, size_t oh_start, | NonlineMode nonline_mode, size_t oh_start, | ||||
size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||||
size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start, | |||||
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | ||||
size_t m, size_t r, | size_t m, size_t r, | ||||
const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
@@ -296,7 +296,7 @@ void StrategyHelper< | |||||
output_compute_type* mid_buf1 = transform_mid_buf; | output_compute_type* mid_buf1 = transform_mid_buf; | ||||
output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; | output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; | ||||
OutputGetter<output_compute_type, dst_type> getter(dtype); | OutputGetter<output_compute_type, dst_type> getter(dtype); | ||||
OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | |||||
OutputVisitor<layout, format> output_visitor(OC); | |||||
size_t oc = oc_start + oc_index; | size_t oc = oc_start + oc_index; | ||||
@@ -6,8 +6,7 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
@@ -44,8 +43,8 @@ public: | |||||
input_filter_compute_type* input_transform_buf, | input_filter_compute_type* input_transform_buf, | ||||
input_filter_compute_type* transform_mid_buf, | input_filter_compute_type* transform_mid_buf, | ||||
int ih_start, int iw_start, size_t IH, size_t IW, | int ih_start, int iw_start, size_t IH, size_t IW, | ||||
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||||
size_t m, size_t r, | |||||
size_t IC, size_t ic, size_t unit_idx, | |||||
size_t nr_units_in_tile, size_t m, size_t r, | |||||
const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
float rescale = 1.0f); | float rescale = 1.0f); | ||||
@@ -54,7 +53,7 @@ public: | |||||
const output_compute_type* bias, dst_type* output, | const output_compute_type* bias, dst_type* output, | ||||
output_compute_type* transform_mid_buf, BiasMode bmode, | output_compute_type* transform_mid_buf, BiasMode bmode, | ||||
NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | ||||
size_t OH, size_t OW, size_t oc_start, size_t oc_index, | |||||
size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index, | |||||
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | ||||
const std::vector<float>& interp_points, DType dtype, | const std::vector<float>& interp_points, DType dtype, | ||||
float input_filter_scale = 1.0f, // input_scale * filter_scale | float input_filter_scale = 1.0f, // input_scale * filter_scale | ||||
@@ -45,7 +45,6 @@ public: | |||||
static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | ||||
for (auto&& algo : matmul_algos) { | for (auto&& algo : matmul_algos) { | ||||
if (algo->algoset() == | if (algo->algoset() == | ||||
//! TODO: threre should filter MK matmul | |||||
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | ||||
continue; | continue; | ||||
} | } | ||||
@@ -536,7 +536,6 @@ public: | |||||
NonlineMode nonline_mode, size_t OH, size_t OW, \ | NonlineMode nonline_mode, size_t OH, size_t OW, \ | ||||
size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | ||||
size_t nr_tiles_in_unit); \ | size_t nr_tiles_in_unit); \ | ||||
}; | }; | ||||
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | #define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | ||||
@@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 { | |||||
float* output, float* transform_mid_buf, | float* output, float* transform_mid_buf, | ||||
size_t oh_start, size_t ow_start, size_t OH, | size_t oh_start, size_t ow_start, size_t OH, | ||||
size_t OW, size_t oc_start, size_t oc_end, | size_t OW, size_t oc_start, size_t oc_end, | ||||
size_t unit_idx, size_t nr_units_in_tile, | |||||
const DType& src_dtype, const DType& dst_dtype) { | |||||
size_t oc_index, size_t unit_idx, | |||||
size_t nr_units_in_tile, const DType& src_dtype, | |||||
const DType& dst_dtype) { | |||||
MEGDNN_MARK_USED_VAR(transform_mid_buf); | MEGDNN_MARK_USED_VAR(transform_mid_buf); | ||||
megdnn_assert( | |||||
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && | |||||
oc_end % 8 == 0, | |||||
"Winograd output transform input param is not times of 8!"); | |||||
Op op(src_dtype, dst_dtype); | Op op(src_dtype, dst_dtype); | ||||
//! AT * m * A | //! AT * m * A | ||||
size_t OCB = (oc_end - oc_start) / 8; | size_t OCB = (oc_end - oc_start) / 8; | ||||
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||||
size_t ocb = (oc - oc_start) / 8; | |||||
size_t oc = oc_start + oc_index; | |||||
size_t ocb = oc_index / 8; | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
auto v##m##n = Vector<float, 8>::load( \ | auto v##m##n = Vector<float, 8>::load( \ | ||||
output_transform_buf + \ | output_transform_buf + \ | ||||
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | ||||
ocb * nr_units_in_tile * 8 + unit_idx * 8); | ocb * nr_units_in_tile * 8 + unit_idx * 8); | ||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||||
#undef cb | #undef cb | ||||
//! 1 1 1 0 v00 v01 v02 v03 1 0 | |||||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||||
//! v20 v21 v22 v23 1 -1 | |||||
//! v30 v31 v32 v33 0 1 | |||||
//! 1 1 1 0 v00 v01 v02 v03 1 0 | |||||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||||
//! v20 v21 v22 v23 1 -1 | |||||
//! v30 v31 v32 v33 0 1 | |||||
#define cb(m) \ | #define cb(m) \ | ||||
auto t0##m = v0##m + v1##m + v2##m; \ | auto t0##m = v0##m + v1##m + v2##m; \ | ||||
auto t1##m = v1##m - v2##m + v3##m; | auto t1##m = v1##m - v2##m + v3##m; | ||||
UNROLL_CALL_NOWRAPPER(4, cb); | |||||
UNROLL_CALL_NOWRAPPER(4, cb); | |||||
#undef cb | #undef cb | ||||
#define cb(m) \ | #define cb(m) \ | ||||
v##m##0 = t##m##0 + t##m##1 + t##m##2; \ | v##m##0 = t##m##0 + t##m##1 + t##m##2; \ | ||||
v##m##1 = t##m##1 - t##m##2 + t##m##3; | v##m##1 = t##m##1 - t##m##2 + t##m##3; | ||||
UNROLL_CALL_NOWRAPPER(2, cb); | |||||
UNROLL_CALL_NOWRAPPER(2, cb); | |||||
#undef cb | #undef cb | ||||
Vector<float, 8> vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
vbias = Vector<float, 8>::load(bias + oc); | |||||
Vector<float, 8> vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
vbias = Vector<float, 8>::load(bias + oc); | |||||
#define cb(m, n) v##m##n += vbias; | #define cb(m, n) v##m##n += vbias; | ||||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||||
#undef cb | #undef cb | ||||
} | |||||
if (bmode != BiasMode::BIAS) { | |||||
} | |||||
if (bmode != BiasMode::BIAS) { | |||||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | ||||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||||
#undef cb | #undef cb | ||||
} | |||||
} | |||||
#define out_save(oho, owo) \ | #define out_save(oho, owo) \ | ||||
do { \ | do { \ | ||||
size_t oh = oh_start + oho; \ | size_t oh = oh_start + oho; \ | ||||
@@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 { | |||||
ow * 8); \ | ow * 8); \ | ||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
UNROLL_CALL_RAW_D2(2, 2, out_save); | |||||
} | |||||
UNROLL_CALL_RAW_D2(2, 2, out_save); | |||||
} | } | ||||
}; | }; | ||||
#undef CONCAT | #undef CONCAT | ||||
@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input, | |||||
} | } | ||||
} | } | ||||
void winograd_nchw88_2x3_8x8_f::output( | |||||
const float* output_transform_buf, const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, | |||||
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||||
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { | |||||
void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf, | |||||
const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, | |||||
NonlineMode nonline_mode, size_t OH, | |||||
size_t OW, size_t oc_start, | |||||
size_t oc_end, size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
#define cb(_bmode, _nonline_op, ...) \ | #define cb(_bmode, _nonline_op, ...) \ | ||||
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | ||||
__VA_ARGS__); | __VA_ARGS__); | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float, | |||||
float, bmode, nonline_mode, output_transform_buf, bias, output, | |||||
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||||
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
size_t OC = oc_end - oc_start; | |||||
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, | |||||
"Winograd output transform input param is not times of 8!"); | |||||
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, | |||||
float, float, bmode, nonline_mode, output_transform_buf, | |||||
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||||
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, | |||||
dst_dtype); | |||||
} | |||||
} | |||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
@@ -19,10 +20,10 @@ | |||||
#include <x86intrin.h> | #include <x86intrin.h> | ||||
#ifdef WIN32CMAKE | #ifdef WIN32CMAKE | ||||
#include <avxintrin.h> | |||||
#include <smmintrin.h> | |||||
#include <avx2intrin.h> | #include <avx2intrin.h> | ||||
#include <avxintrin.h> | |||||
#include <fmaintrin.h> | #include <fmaintrin.h> | ||||
#include <smmintrin.h> | |||||
#endif | #endif | ||||
#include "midout.h" | #include "midout.h" | ||||
@@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 { | |||||
int ih_start, int iw_start, size_t IH, size_t IW, | int ih_start, int iw_start, size_t IH, size_t IW, | ||||
size_t ic, size_t IC) { | size_t ic, size_t IC) { | ||||
MEGDNN_MARK_USED_VAR(patch); | MEGDNN_MARK_USED_VAR(patch); | ||||
size_t IW8 = IW * 8; //! For nchw88 mode | |||||
size_t IW8 = IW * 8; //! For nchw88 mode | |||||
size_t iw8_start = iw_start * 8; //! For nchw88 mode | size_t iw8_start = iw_start * 8; //! For nchw88 mode | ||||
size_t icb = ic / 8; | size_t icb = ic / 8; | ||||
if (!(inner && ic + 8 < IC)) { | if (!(inner && ic + 8 < IC)) { | ||||
@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 { | |||||
for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { | for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { | ||||
for (size_t icb = 0; icb < ICB; icb++) { | for (size_t icb = 0; icb < ICB; icb++) { | ||||
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){ | |||||
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) { | |||||
const float* fptr = filter + | const float* fptr = filter + | ||||
(ocb * ICB + icb) * 3 * 3 * 8 * 8 + | (ocb * ICB + icb) * 3 * 3 * 8 * 8 + | ||||
ic_inner * 8; | ic_inner * 8; | ||||
@@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 { | |||||
float* output, float* transform_mid_buf, | float* output, float* transform_mid_buf, | ||||
size_t oh_start, size_t ow_start, size_t OH, | size_t oh_start, size_t ow_start, size_t OH, | ||||
size_t OW, size_t oc_start, size_t oc_end, | size_t OW, size_t oc_start, size_t oc_end, | ||||
size_t unit_idx, size_t nr_units_in_tile, | |||||
const DType& src_dtype, const DType& dst_dtype) { | |||||
size_t oc_index, size_t unit_idx, | |||||
size_t nr_units_in_tile, const DType& src_dtype, | |||||
const DType& dst_dtype) { | |||||
MEGDNN_MARK_USED_VAR(transform_mid_buf); | MEGDNN_MARK_USED_VAR(transform_mid_buf); | ||||
megdnn_assert( | |||||
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && | |||||
oc_end % 8 == 0, | |||||
"Winograd output transform input param is not times of 8!"); | |||||
Op op(src_dtype, dst_dtype); | Op op(src_dtype, dst_dtype); | ||||
//! AT * m * A | //! AT * m * A | ||||
size_t OCB = (oc_end - oc_start) / 8; | size_t OCB = (oc_end - oc_start) / 8; | ||||
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||||
size_t ocb = (oc - oc_start) / 8; | |||||
size_t oc = oc_start + oc_index; | |||||
size_t ocb = oc_index / 8; | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
auto v##m##n = Vector<float, 8>::load( \ | auto v##m##n = Vector<float, 8>::load( \ | ||||
output_transform_buf + \ | output_transform_buf + \ | ||||
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | ||||
ocb * nr_units_in_tile * 8 + unit_idx * 8); | ocb * nr_units_in_tile * 8 + unit_idx * 8); | ||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||||
#undef cb | #undef cb | ||||
/** | |||||
* A | |||||
* | |||||
* 1 0 0 0 0 0 | |||||
* 1 1 1 1 1 1 | |||||
* 1 -1 1 -1 1 -1 | |||||
* 1 2 4 8 16 32 | |||||
* 1 -2 4 -8 16 -32 | |||||
* 1 0.5 0.25 0.125 0.0625 0.03125 | |||||
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||||
* 0 0.0 0 0 0 1 | |||||
*/ | |||||
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, | |||||
v5subv6; | |||||
/** | |||||
* A | |||||
* | |||||
* 1 0 0 0 0 0 | |||||
* 1 1 1 1 1 1 | |||||
* 1 -1 1 -1 1 -1 | |||||
* 1 2 4 8 16 32 | |||||
* 1 -2 4 -8 16 -32 | |||||
* 1 0.5 0.25 0.125 0.0625 0.03125 | |||||
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||||
* 0 0.0 0 0 0 1 | |||||
*/ | |||||
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||||
#define cb(m) \ | #define cb(m) \ | ||||
v1addv2 = v1##m + v2##m; \ | v1addv2 = v1##m + v2##m; \ | ||||
v1subv2 = v1##m - v2##m; \ | v1subv2 = v1##m - v2##m; \ | ||||
@@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 { | |||||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | ||||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | ||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#undef cb | #undef cb | ||||
#define cb(m) \ | #define cb(m) \ | ||||
@@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 { | |||||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | ||||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | ||||
UNROLL_CALL_NOWRAPPER(6, cb); | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | |||||
#undef cb | #undef cb | ||||
Vector<float, 8> vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
vbias = Vector<float, 8>::load(bias + oc); | |||||
Vector<float, 8> vbias; | |||||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
vbias = Vector<float, 8>::load(bias + oc); | |||||
#define cb(m, n) v##m##n += vbias; | #define cb(m, n) v##m##n += vbias; | ||||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||||
#undef cb | #undef cb | ||||
} | |||||
if (bmode != BiasMode::BIAS) { | |||||
} | |||||
if (bmode != BiasMode::BIAS) { | |||||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | #define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | ||||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||||
#undef cb | #undef cb | ||||
} | |||||
} | |||||
#define out_save(oho, owo) \ | #define out_save(oho, owo) \ | ||||
do { \ | do { \ | ||||
size_t oh = oh_start + oho; \ | size_t oh = oh_start + oho; \ | ||||
@@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 { | |||||
ow * 8); \ | ow * 8); \ | ||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
UNROLL_CALL_RAW_D2(6, 6, out_save); | |||||
} | |||||
UNROLL_CALL_RAW_D2(6, 6, out_save); | |||||
} | } | ||||
}; | }; | ||||
#undef CONCAT | #undef CONCAT | ||||
@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||||
megdnn_assert(IC % 8 == 0); | megdnn_assert(IC % 8 == 0); | ||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | // OW = IW + 2 * PW - KERNEL_SIZE + 1 | ||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
auto units_w = | |||||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
float* patch = transform_mid_buf; | float* patch = transform_mid_buf; | ||||
float* patchT = transform_mid_buf + 8 * alpha * alpha; | float* patchT = transform_mid_buf + 8 * alpha * alpha; | ||||
@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||||
} | } | ||||
} | } | ||||
void winograd_nchw88_6x3_8x8_f::output( | |||||
const float* output_transform_buf, const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, | |||||
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||||
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { | |||||
void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf, | |||||
const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, | |||||
NonlineMode nonline_mode, size_t OH, | |||||
size_t OW, size_t oc_start, | |||||
size_t oc_end, size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
#define cb(_bmode, _nonline_op, ...) \ | #define cb(_bmode, _nonline_op, ...) \ | ||||
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | ||||
__VA_ARGS__); | __VA_ARGS__); | ||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float, | |||||
float, bmode, nonline_mode, output_transform_buf, bias, output, | |||||
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||||
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
size_t OC = oc_end - oc_start; | |||||
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, | |||||
"Winograd output transform input param is not times of 8!"); | |||||
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, | |||||
float, float, bmode, nonline_mode, output_transform_buf, | |||||
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||||
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, | |||||
src_dtype, dst_dtype); | |||||
} | |||||
} | |||||
#undef cb | #undef cb | ||||
} | } | ||||
} // namespace winograd | } // namespace winograd | ||||
} // namespace arm_common | |||||
} // namespace x86 | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |