GitOrigin-RevId: 6d4b225ea5
tags/v0.3.2
@@ -235,7 +235,7 @@ void StrategyHelper< | |||
input_filter_compute_type* input_transform_buf, | |||
input_filter_compute_type* transform_mid_buf, | |||
int ih_start, int iw_start, size_t IH, size_t IW, | |||
size_t IC, size_t unit_idx, size_t nr_units_in_tile, | |||
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||
size_t m, size_t r, | |||
const std::vector<float>& interp_points, DType dtype, | |||
float rescale) { | |||
@@ -284,7 +284,7 @@ void StrategyHelper< | |||
const output_compute_type* bias, dst_type* output, | |||
output_compute_type* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t oh_start, | |||
size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||
size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start, | |||
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, | |||
size_t m, size_t r, | |||
const std::vector<float>& interp_points, DType dtype, | |||
@@ -296,7 +296,7 @@ void StrategyHelper< | |||
output_compute_type* mid_buf1 = transform_mid_buf; | |||
output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; | |||
OutputGetter<output_compute_type, dst_type> getter(dtype); | |||
OutputVisitor<layout, format> output_visitor(oc_end - oc_start); | |||
OutputVisitor<layout, format> output_visitor(OC); | |||
size_t oc = oc_start + oc_index; | |||
@@ -6,8 +6,7 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
@@ -44,8 +43,8 @@ public: | |||
input_filter_compute_type* input_transform_buf, | |||
input_filter_compute_type* transform_mid_buf, | |||
int ih_start, int iw_start, size_t IH, size_t IW, | |||
size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, | |||
size_t m, size_t r, | |||
size_t IC, size_t ic, size_t unit_idx, | |||
size_t nr_units_in_tile, size_t m, size_t r, | |||
const std::vector<float>& interp_points, DType dtype, | |||
float rescale = 1.0f); | |||
@@ -54,7 +53,7 @@ public: | |||
const output_compute_type* bias, dst_type* output, | |||
output_compute_type* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t oh_start, size_t ow_start, | |||
size_t OH, size_t OW, size_t oc_start, size_t oc_index, | |||
size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index, | |||
size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, | |||
const std::vector<float>& interp_points, DType dtype, | |||
float input_filter_scale = 1.0f, // input_scale * filter_scale | |||
@@ -45,7 +45,6 @@ public: | |||
static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack(); | |||
for (auto&& algo : matmul_algos) { | |||
if (algo->algoset() == | |||
//! TODO: threre should filter MK matmul | |||
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | |||
continue; | |||
} | |||
@@ -536,7 +536,6 @@ public: | |||
NonlineMode nonline_mode, size_t OH, size_t OW, \ | |||
size_t oc_start, size_t oc_end, size_t unit_start_idx, \ | |||
size_t nr_tiles_in_unit); \ | |||
}; | |||
#define MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(_strategy_cls_name) \ | |||
@@ -186,58 +186,56 @@ struct OutputTransform2X3_NCHW88 { | |||
float* output, float* transform_mid_buf, | |||
size_t oh_start, size_t ow_start, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t unit_idx, size_t nr_units_in_tile, | |||
const DType& src_dtype, const DType& dst_dtype) { | |||
size_t oc_index, size_t unit_idx, | |||
size_t nr_units_in_tile, const DType& src_dtype, | |||
const DType& dst_dtype) { | |||
MEGDNN_MARK_USED_VAR(transform_mid_buf); | |||
megdnn_assert( | |||
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && | |||
oc_end % 8 == 0, | |||
"Winograd output transform input param is not times of 8!"); | |||
Op op(src_dtype, dst_dtype); | |||
//! AT * m * A | |||
size_t OCB = (oc_end - oc_start) / 8; | |||
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||
size_t ocb = (oc - oc_start) / 8; | |||
size_t oc = oc_start + oc_index; | |||
size_t ocb = oc_index / 8; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 8>::load( \ | |||
output_transform_buf + \ | |||
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | |||
ocb * nr_units_in_tile * 8 + unit_idx * 8); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); | |||
#undef cb | |||
//! 1 1 1 0 v00 v01 v02 v03 1 0 | |||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||
//! v20 v21 v22 v23 1 -1 | |||
//! v30 v31 v32 v33 0 1 | |||
//! 1 1 1 0 v00 v01 v02 v03 1 0 | |||
//! 0 1 -1 1 v10 v11 v12 v13 1 1 | |||
//! v20 v21 v22 v23 1 -1 | |||
//! v30 v31 v32 v33 0 1 | |||
#define cb(m) \ | |||
auto t0##m = v0##m + v1##m + v2##m; \ | |||
auto t1##m = v1##m - v2##m + v3##m; | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
UNROLL_CALL_NOWRAPPER(4, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
v##m##0 = t##m##0 + t##m##1 + t##m##2; \ | |||
v##m##1 = t##m##1 - t##m##2 + t##m##3; | |||
UNROLL_CALL_NOWRAPPER(2, cb); | |||
UNROLL_CALL_NOWRAPPER(2, cb); | |||
#undef cb | |||
Vector<float, 8> vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 8>::load(bias + oc); | |||
Vector<float, 8> vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 8>::load(bias + oc); | |||
#define cb(m, n) v##m##n += vbias; | |||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||
#undef cb | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | |||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||
UNROLL_CALL_RAW_D2(2, 2, cb); | |||
#undef cb | |||
} | |||
} | |||
#define out_save(oho, owo) \ | |||
do { \ | |||
size_t oh = oh_start + oho; \ | |||
@@ -252,8 +250,7 @@ struct OutputTransform2X3_NCHW88 { | |||
ow * 8); \ | |||
} \ | |||
} while (0); | |||
UNROLL_CALL_RAW_D2(2, 2, out_save); | |||
} | |||
UNROLL_CALL_RAW_D2(2, 2, out_save); | |||
} | |||
}; | |||
#undef CONCAT | |||
@@ -315,20 +312,40 @@ void winograd_nchw88_2x3_8x8_f::input(const float* input, | |||
} | |||
} | |||
void winograd_nchw88_2x3_8x8_f::output( | |||
const float* output_transform_buf, const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, | |||
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { | |||
void winograd_nchw88_2x3_8x8_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, | |||
size_t OW, size_t oc_start, | |||
size_t oc_end, size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | |||
__VA_ARGS__); | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float, | |||
float, bmode, nonline_mode, output_transform_buf, bias, output, | |||
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
size_t OC = oc_end - oc_start; | |||
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, | |||
"Winograd output transform input param is not times of 8!"); | |||
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, | |||
float, float, bmode, nonline_mode, output_transform_buf, | |||
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, | |||
dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/common/unroll_macro.h" | |||
@@ -19,10 +20,10 @@ | |||
#include <x86intrin.h> | |||
#ifdef WIN32CMAKE | |||
#include <avxintrin.h> | |||
#include <smmintrin.h> | |||
#include <avx2intrin.h> | |||
#include <avxintrin.h> | |||
#include <fmaintrin.h> | |||
#include <smmintrin.h> | |||
#endif | |||
#include "midout.h" | |||
@@ -40,7 +41,7 @@ struct InputTransform6X3_NCHW88 { | |||
int ih_start, int iw_start, size_t IH, size_t IW, | |||
size_t ic, size_t IC) { | |||
MEGDNN_MARK_USED_VAR(patch); | |||
size_t IW8 = IW * 8; //! For nchw88 mode | |||
size_t IW8 = IW * 8; //! For nchw88 mode | |||
size_t iw8_start = iw_start * 8; //! For nchw88 mode | |||
size_t icb = ic / 8; | |||
if (!(inner && ic + 8 < IC)) { | |||
@@ -171,7 +172,7 @@ struct FilterTransform6X3_MCHW88 { | |||
for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) { | |||
for (size_t icb = 0; icb < ICB; icb++) { | |||
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++){ | |||
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) { | |||
const float* fptr = filter + | |||
(ocb * ICB + icb) * 3 * 3 * 8 * 8 + | |||
ic_inner * 8; | |||
@@ -220,41 +221,39 @@ struct OutputTransform6X3_NCHW88 { | |||
float* output, float* transform_mid_buf, | |||
size_t oh_start, size_t ow_start, size_t OH, | |||
size_t OW, size_t oc_start, size_t oc_end, | |||
size_t unit_idx, size_t nr_units_in_tile, | |||
const DType& src_dtype, const DType& dst_dtype) { | |||
size_t oc_index, size_t unit_idx, | |||
size_t nr_units_in_tile, const DType& src_dtype, | |||
const DType& dst_dtype) { | |||
MEGDNN_MARK_USED_VAR(transform_mid_buf); | |||
megdnn_assert( | |||
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && | |||
oc_end % 8 == 0, | |||
"Winograd output transform input param is not times of 8!"); | |||
Op op(src_dtype, dst_dtype); | |||
//! AT * m * A | |||
size_t OCB = (oc_end - oc_start) / 8; | |||
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||
size_t ocb = (oc - oc_start) / 8; | |||
size_t oc = oc_start + oc_index; | |||
size_t ocb = oc_index / 8; | |||
#define cb(m, n) \ | |||
auto v##m##n = Vector<float, 8>::load( \ | |||
output_transform_buf + \ | |||
(m * alpha + n) * OCB * nr_units_in_tile * 8 + \ | |||
ocb * nr_units_in_tile * 8 + unit_idx * 8); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); | |||
#undef cb | |||
/** | |||
* A | |||
* | |||
* 1 0 0 0 0 0 | |||
* 1 1 1 1 1 1 | |||
* 1 -1 1 -1 1 -1 | |||
* 1 2 4 8 16 32 | |||
* 1 -2 4 -8 16 -32 | |||
* 1 0.5 0.25 0.125 0.0625 0.03125 | |||
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||
* 0 0.0 0 0 0 1 | |||
*/ | |||
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, | |||
v5subv6; | |||
/** | |||
* A | |||
* | |||
* 1 0 0 0 0 0 | |||
* 1 1 1 1 1 1 | |||
* 1 -1 1 -1 1 -1 | |||
* 1 2 4 8 16 32 | |||
* 1 -2 4 -8 16 -32 | |||
* 1 0.5 0.25 0.125 0.0625 0.03125 | |||
* 1 -0.5 0.25 -0.125 0.0625 -0.03125 | |||
* 0 0.0 0 0 0 1 | |||
*/ | |||
Vector<float, 8> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; | |||
#define cb(m) \ | |||
v1addv2 = v1##m + v2##m; \ | |||
v1subv2 = v1##m - v2##m; \ | |||
@@ -269,7 +268,7 @@ struct OutputTransform6X3_NCHW88 { | |||
auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
UNROLL_CALL_NOWRAPPER(8, cb); | |||
#undef cb | |||
#define cb(m) \ | |||
@@ -286,22 +285,22 @@ struct OutputTransform6X3_NCHW88 { | |||
v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ | |||
v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
UNROLL_CALL_NOWRAPPER(6, cb); | |||
#undef cb | |||
Vector<float, 8> vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 8>::load(bias + oc); | |||
Vector<float, 8> vbias; | |||
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
vbias = Vector<float, 8>::load(bias + oc); | |||
#define cb(m, n) v##m##n += vbias; | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
} | |||
if (bmode != BiasMode::BIAS) { | |||
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
UNROLL_CALL_RAW_D2(6, 6, cb); | |||
#undef cb | |||
} | |||
} | |||
#define out_save(oho, owo) \ | |||
do { \ | |||
size_t oh = oh_start + oho; \ | |||
@@ -316,8 +315,7 @@ struct OutputTransform6X3_NCHW88 { | |||
ow * 8); \ | |||
} \ | |||
} while (0); | |||
UNROLL_CALL_RAW_D2(6, 6, out_save); | |||
} | |||
UNROLL_CALL_RAW_D2(6, 6, out_save); | |||
} | |||
}; | |||
#undef CONCAT | |||
@@ -348,7 +346,8 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||
megdnn_assert(IC % 8 == 0); | |||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
auto units_w = | |||
div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||
float* patch = transform_mid_buf; | |||
float* patchT = transform_mid_buf + 8 * alpha * alpha; | |||
@@ -379,25 +378,45 @@ void winograd_nchw88_6x3_8x8_f::input(const float* input, | |||
} | |||
} | |||
void winograd_nchw88_6x3_8x8_f::output( | |||
const float* output_transform_buf, const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, | |||
size_t oh_start, size_t ow_start, size_t OH, size_t OW, size_t oc_start, | |||
size_t oc_end, size_t unit_idx, size_t nr_units_in_tile) { | |||
void winograd_nchw88_6x3_8x8_f::output(const float* output_transform_buf, | |||
const float* bias, float* output, | |||
float* transform_mid_buf, BiasMode bmode, | |||
NonlineMode nonline_mode, size_t OH, | |||
size_t OW, size_t oc_start, | |||
size_t oc_end, size_t unit_start_idx, | |||
size_t nr_units_in_tile) { | |||
#define cb(_bmode, _nonline_op, ...) \ | |||
OutputTransform6X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform( \ | |||
__VA_ARGS__); | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, float, | |||
float, bmode, nonline_mode, output_transform_buf, bias, output, | |||
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, | |||
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); | |||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||
size_t OC = oc_end - oc_start; | |||
megdnn_assert(OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0, | |||
"Winograd output transform input param is not times of 8!"); | |||
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) { | |||
size_t oc_index = oc - oc_start; | |||
rep(unit_idx, nr_units_in_tile) { | |||
size_t index = unit_start_idx + unit_idx; | |||
auto nh = index / units_w; | |||
auto nw = index % units_w; | |||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||
DISPATCH_CONV_WINOGRAD_BIAS( | |||
megdnn_x86_winograd_nchw88_fp32_F63_8x8, cb, SIMDType::AVX2, | |||
float, float, bmode, nonline_mode, output_transform_buf, | |||
bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, | |||
oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, | |||
src_dtype, dst_dtype); | |||
} | |||
} | |||
#undef cb | |||
} | |||
} // namespace winograd | |||
} // namespace arm_common | |||
} // namespace x86 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |