GitOrigin-RevId: 909503a90d
dev-support-lite-fork-debug-mode
@@ -453,19 +453,21 @@ public: | |||||
}; | }; | ||||
//! param for winograd algos. | //! param for winograd algos. | ||||
struct WinogradParam { | struct WinogradParam { | ||||
uint32_t channel_block_size; | uint32_t channel_block_size; | ||||
uint32_t output_block_size; | uint32_t output_block_size; | ||||
uint32_t tile_size; | uint32_t tile_size; | ||||
uint32_t filter_size; | |||||
bool operator==(const WinogradParam& rhs) const { | bool operator==(const WinogradParam& rhs) const { | ||||
return channel_block_size == rhs.channel_block_size && | return channel_block_size == rhs.channel_block_size && | ||||
output_block_size == rhs.output_block_size && | output_block_size == rhs.output_block_size && | ||||
tile_size == rhs.tile_size; | |||||
tile_size == rhs.tile_size && filter_size == rhs.filter_size; | |||||
} | } | ||||
std::string to_string() const; | std::string to_string() const; | ||||
}; | }; | ||||
static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0}; | |||||
static constexpr WinogradParam INVALID_WINOGRAD_PARAM = {0, 0, 0, 0}; | |||||
struct DirectParam { | struct DirectParam { | ||||
std::string to_string() const { return ""; } | std::string to_string() const { return ""; } | ||||
@@ -14,7 +14,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {1, 2, m_tile_size}); | |||||
m_matmul_algo->name(), {1, 2, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -33,7 +33,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {1, 4, m_tile_size}); | |||||
m_matmul_algo->name(), {1, 4, m_tile_size, 5}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -51,7 +51,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {1, 6, m_tile_size}); | |||||
m_matmul_algo->name(), {1, 6, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -69,7 +69,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {8, 2, m_tile_size}); | |||||
m_matmul_algo->name(), {8, 2, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -221,7 +221,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {8, 2, m_tile_size}); | |||||
m_matmul_algo->name(), {8, 2, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -239,7 +239,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {4, 2, m_tile_size}, | |||||
m_matmul_algo->name(), {4, 2, m_tile_size, 3}, | |||||
param::ConvBias::Format::NCHW44); | param::ConvBias::Format::NCHW44); | ||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
@@ -258,7 +258,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {8, 2, m_tile_size}, | |||||
m_matmul_algo->name(), {8, 2, m_tile_size, 3}, | |||||
param::ConvBias::Format::NCHW44); | param::ConvBias::Format::NCHW44); | ||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
@@ -176,7 +176,9 @@ template <typename T> | |||||
struct NCHW44ParamTrait; | struct NCHW44ParamTrait; | ||||
std::string ConvBias::WinogradParam::to_string() const { | std::string ConvBias::WinogradParam::to_string() const { | ||||
return ssprintf("%u:%u:%u", channel_block_size, output_block_size, tile_size); | |||||
return ssprintf( | |||||
"%u:%u:%u:%u", channel_block_size, output_block_size, tile_size, | |||||
filter_size); | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
@@ -165,6 +165,18 @@ | |||||
cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \ | cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \ | ||||
cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a) | cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a) | ||||
#define UNROLL_RAW_4x2(cb, v0, a...) \ | |||||
cb(0, 0, ##a) cb(0, 1, ##a) cb(1, 0, ##a) cb(1, 1, ##a) \ | |||||
cb(2, 0, ##a) cb(2, 1, ##a) cb(3, 0, ##a) cb(3, 1, ##a) | |||||
#define UNROLL_RAW_5x2(cb, v0, a...) \ | |||||
UNROLL_RAW_4x2(cb, v0, ##a) \ | |||||
cb(4, 0, ##a) cb(4, 1, ##a) | |||||
#define UNROLL_RAW_6x2(cb, v0, a...) \ | |||||
UNROLL_RAW_5x2(cb, v0, ##a) \ | |||||
cb(5, 0, ##a) cb(5, 1, ##a) | |||||
#define UNROLL_CALL0_D2(step, step2, cb, v...) \ | #define UNROLL_CALL0_D2(step, step2, cb, v...) \ | ||||
UNROLL_RAW_##step##x##step2(cb, 0, ##v) | UNROLL_RAW_##step##x##step2(cb, 0, ##v) | ||||
#define UNROLL_CALL1_D2(step, step2, cb, v...) \ | #define UNROLL_CALL1_D2(step, step2, cb, v...) \ | ||||
@@ -42,7 +42,7 @@ public: | |||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()), | ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()), | ||||
{1, 2, UNIT_TILE_SIZE}); | |||||
{1, 2, UNIT_TILE_SIZE, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -74,7 +74,7 @@ public: | |||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()), | ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()), | ||||
{4, 2, UNIT_TILE_SIZE}); | |||||
{4, 2, UNIT_TILE_SIZE, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -106,7 +106,7 @@ public: | |||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()), | ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()), | ||||
{1, 2, UNIT_TILE_SIZE}); | |||||
{1, 2, UNIT_TILE_SIZE, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -138,7 +138,7 @@ public: | |||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()), | ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()), | ||||
{8, 2, UNIT_TILE_SIZE}); | |||||
{8, 2, UNIT_TILE_SIZE, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -84,6 +84,38 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||||
AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f, | AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f, | ||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | ||||
/* ======================= AlgoFP32WinogradF43 ======================== */ | |||||
bool ConvBiasImpl::AlgoFP32WinogradF43::usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||||
MEGDNN_MARK_USED_VAR(param); | |||||
MIDOUT_BEGIN(megdnn_fallback_winograd_fp32, 5, 0) { | |||||
using Strategy = winograd::winograd_4x3_1x1_f; | |||||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||||
auto&& matmul_param = | |||||
megdnn::winograd::ConvBias<Strategy>(strategy, m_tile_size, param) | |||||
.get_matmul_kern_param(param); | |||||
return m_matmul_algo->usable(matmul_param) && | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
!param.filter_meta.should_flip && | |||||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||||
param.filter_meta.spatial[0] == 3) && | |||||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||||
param.filter_meta.stride[0] == 1) && | |||||
(param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && | |||||
param.filter_meta.dilation[0] == 1) && | |||||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||||
param.src_type.enumv() == DTypeEnum::Float32; | |||||
} | |||||
MIDOUT_END(); | |||||
return false; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( | |||||
AlgoFP32WinogradF43, winograd::winograd_4x3_1x1_f, | |||||
megdnn_fallback_winograd_fp32, param::MatrixMul::Format::DEFAULT); | |||||
/* ======================= AlgoFP32WinogradF54 ======================== */ | /* ======================= AlgoFP32WinogradF54 ======================== */ | ||||
bool ConvBiasImpl::AlgoFP32WinogradF54::usable( | bool ConvBiasImpl::AlgoFP32WinogradF54::usable( | ||||
@@ -14,7 +14,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {4, 2, m_tile_size}); | |||||
m_matmul_algo->name(), {4, 2, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -31,7 +31,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {1, 6, m_tile_size}); | |||||
m_matmul_algo->name(), {1, 6, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -42,6 +42,28 @@ public: | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_FP32) | MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F63_FP32) | ||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF43 final : public AlgoBase { | |||||
public: | |||||
AlgoFP32WinogradF43( | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) | |||||
: m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | |||||
const char* name() const override { | |||||
if (m_name.empty()) { | |||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | |||||
m_matmul_algo->name(), {1, 4, m_tile_size, 3}); | |||||
} | |||||
return m_name.c_str(); | |||||
} | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
MEGDNN_DECL_ALGO_TYPE(GI_COMMON_WINOGRAD_F43_FP32); | |||||
}; | |||||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoFP32WinogradF63_4x4( | AlgoFP32WinogradF63_4x4( | ||||
@@ -50,7 +72,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {4, 6, m_tile_size}); | |||||
m_matmul_algo->name(), {4, 6, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -67,7 +89,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {1, 5, m_tile_size}); | |||||
m_matmul_algo->name(), {1, 5, m_tile_size, 4}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -86,7 +108,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {1, 4, m_tile_size}); | |||||
m_matmul_algo->name(), {1, 4, m_tile_size, 5}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -106,7 +128,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {4, 2, m_tile_size}, | |||||
m_matmul_algo->name(), {4, 2, m_tile_size, 3}, | |||||
param::ConvBias::Format::NCHW44); | param::ConvBias::Format::NCHW44); | ||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
@@ -124,7 +146,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {4, 6, m_tile_size}, | |||||
m_matmul_algo->name(), {4, 6, m_tile_size, 3}, | |||||
param::ConvBias::Format::NCHW44); | param::ConvBias::Format::NCHW44); | ||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
@@ -142,7 +164,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {4, 7, m_tile_size}, | |||||
m_matmul_algo->name(), {4, 7, m_tile_size, 3}, | |||||
param::ConvBias::Format::NCHW44); | param::ConvBias::Format::NCHW44); | ||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
@@ -155,6 +155,124 @@ struct FilterTransform6X3 { | |||||
#undef FILTER_TRANSFORM | #undef FILTER_TRANSFORM | ||||
#undef GET_VECTOR_ELEM | #undef GET_VECTOR_ELEM | ||||
template <param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> | |||||
struct FilterTransform4X3 { | |||||
#define FILTER_TRANSFORM(d, wd, ADDC, SUBC, MULC) \ | |||||
do { \ | |||||
wd##0 = MULC(d##0, 0.25f); \ | |||||
auto tmp0 = MULC(ADDC(d##0, d##2), -0.1666667f); \ | |||||
auto tmp1 = MULC(d##1, -0.1666667f); \ | |||||
wd##1 = ADDC(tmp0, tmp1); \ | |||||
wd##2 = SUBC(tmp0, tmp1); \ | |||||
tmp0 = ADDC(MULC(d##0, 0.0416667f), MULC(d##2, 0.1666667f)); \ | |||||
tmp1 = MULC(d##1, 0.0833333f); \ | |||||
wd##3 = ADDC(tmp0, tmp1); \ | |||||
wd##4 = SUBC(tmp0, tmp1); \ | |||||
wd##5 = d##2; \ | |||||
} while (0); | |||||
static void transform( | |||||
const float* filter, float* filter_transform_buf, float* transform_mid_buf, | |||||
size_t OC, size_t IC, size_t oc_start, size_t oc_end) { | |||||
constexpr size_t alpha = 4 + 3 - 1; | |||||
size_t OCB = OC / 4; | |||||
size_t ICB = IC / 4; | |||||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||||
rep(ic, IC) { | |||||
const float* fptr = filter + (oc * IC + ic) * 3 * 3; | |||||
GI_FLOAT32_t g0 = GiLoadFloat32(fptr); | |||||
GI_FLOAT32_t g1 = GiLoadFloat32(fptr + 3); | |||||
GI_FLOAT32_t g2 = GiLoadFloat32(fptr + 6 - 1); | |||||
GI_FLOAT32_t zeros = GiZeroFloat32(); | |||||
g2 = GiExtqFloat32(g2, zeros, 1); | |||||
#define cb(i) GI_FLOAT32_t wd##i = GiZeroFloat32(); | |||||
#if MEGDNN_AARCH64 | |||||
UNROLL_CALL_NOWRAPPER(8, cb); | |||||
#else | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | |||||
#endif | |||||
#undef cb | |||||
FILTER_TRANSFORM(g, wd, ADDF, SUBF, MULSF); | |||||
size_t ocb = oc / 4; | |||||
size_t oc4 = oc % 4; | |||||
size_t icb = ic / 4; | |||||
size_t ic4 = ic % 4; | |||||
#if MEGDNN_AARCH64 | |||||
#define cb(i) GI_FLOAT32_V2_t wdt##i; | |||||
UNROLL_CALL_NOWRAPPER(3, cb); | |||||
#undef cb | |||||
#define cb(i) GI_FLOAT32_V2_t ret##i; | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | |||||
#undef cb | |||||
TRANSPOSE_8x3(wd, wdt); | |||||
FILTER_TRANSFORM(wdt, ret, ADDFV2, SUBFV2, MULSFV2); | |||||
#define cb(i) GiStoreFloat32V2(transform_mid_buf + i * alpha, ret##i); | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | |||||
#undef cb | |||||
rep(i, alpha) rep(j, alpha) { | |||||
if (format == param::MatrixMul::Format::DEFAULT) { | |||||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = | |||||
transform_mid_buf[j * alpha + i]; | |||||
} else { | |||||
filter_transform_buf | |||||
[(i * alpha + j) * OCB * ICB * 4 * 4 + | |||||
ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + oc4] = | |||||
transform_mid_buf[j * alpha + i]; | |||||
} | |||||
} | |||||
#else | |||||
#define cb(i) \ | |||||
do { \ | |||||
mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0) * 0.25f; \ | |||||
auto tmp0 = \ | |||||
(GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * -0.1666667f; \ | |||||
auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.1666667f; \ | |||||
mid_buf1[1] = tmp0 + tmp1; \ | |||||
mid_buf1[2] = tmp0 - tmp1; \ | |||||
tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0416667f + \ | |||||
GET_VECTOR_ELEM(wd, i, 2) * 0.1666667f; \ | |||||
tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0833333f; \ | |||||
mid_buf1[3] = tmp0 + tmp1; \ | |||||
mid_buf1[4] = tmp0 - tmp1; \ | |||||
mid_buf1[5] = GET_VECTOR_ELEM(wd, i, 2); \ | |||||
mid_buf1 += 6; \ | |||||
} while (0); | |||||
#define GET_VECTOR_ELEM(s, i, idx) GiExtractLane##idx##Float32(CONCAT(s, i)) | |||||
float* mid_buf1 = transform_mid_buf; | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | |||||
mid_buf1 = transform_mid_buf; | |||||
#undef cb | |||||
rep(i, alpha) rep(j, alpha) { | |||||
if (format == param::MatrixMul::Format::DEFAULT) { | |||||
filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = | |||||
transform_mid_buf[i * alpha + j]; | |||||
} else { | |||||
filter_transform_buf | |||||
[(i * alpha + j) * OCB * ICB * 4 * 4 + | |||||
ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + oc4] = | |||||
transform_mid_buf[i * alpha + j]; | |||||
} | |||||
} | |||||
#endif | |||||
} | |||||
} | |||||
} | |||||
}; | |||||
#undef FILTER_TRANSFORM | |||||
#undef GET_VECTOR_ELEM | |||||
} // namespace fallback | } // namespace fallback | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -116,6 +116,46 @@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { | |||||
GiReinterpretqFloat32ToS64(b7.val[1]))); \ | GiReinterpretqFloat32ToS64(b7.val[1]))); \ | ||||
} while (0); | } while (0); | ||||
#define TRANSPOSE_6x6(a, ret) \ | |||||
do { \ | |||||
auto b0 = GiZipqFloat32(CONCAT(a, 00), CONCAT(a, 10)); \ | |||||
auto b1 = GiZipqFloat32(CONCAT(a, 01), CONCAT(a, 11)); \ | |||||
auto b2 = GiZipqFloat32(CONCAT(a, 20), CONCAT(a, 30)); \ | |||||
auto b3 = GiZipqFloat32(CONCAT(a, 21), CONCAT(a, 31)); \ | |||||
auto b4 = GiZipqFloat32(CONCAT(a, 40), CONCAT(a, 50)); \ | |||||
auto b5 = GiZipqFloat32(CONCAT(a, 41), CONCAT(a, 51)); \ | |||||
CONCAT(ret, 00) = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||||
CONCAT(ret, 01) = b4.val[0]; \ | |||||
CONCAT(ret, 10) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[0]))); \ | |||||
CONCAT(ret, 11) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b5.val[0]))); \ | |||||
CONCAT(ret, 20) = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||||
CONCAT(ret, 21) = b4.val[1]; \ | |||||
CONCAT(ret, 30) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b0.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b2.val[1]))); \ | |||||
CONCAT(ret, 31) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b4.val[1]), \ | |||||
GiReinterpretqFloat32ToS64(b5.val[1]))); \ | |||||
CONCAT(ret, 40) = GiReinterpretqS64ToFloat32(GiZip1qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 41) = b5.val[0]; \ | |||||
CONCAT(ret, 50) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b1.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b3.val[0]))); \ | |||||
CONCAT(ret, 51) = GiReinterpretqS64ToFloat32(GiZip2qS64( \ | |||||
GiReinterpretqFloat32ToS64(b5.val[0]), \ | |||||
GiReinterpretqFloat32ToS64(b4.val[0]))); \ | |||||
} while (0); | |||||
#define TRANSPOSE_8x3(a, ret) \ | #define TRANSPOSE_8x3(a, ret) \ | ||||
auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ | auto b0 = GiZipqFloat32(CONCAT(a, 0), CONCAT(a, 1)); \ | ||||
auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ | auto b1 = GiZipqFloat32(CONCAT(a, 2), CONCAT(a, 3)); \ | ||||
@@ -12,6 +12,8 @@ MEGDNN_REG_WINOGRAD_STRATEGY( | |||||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f) | MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f) | ||||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 3, 1, 1, winograd_4x3_1x1_f) | |||||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, winograd_6x3_4x4_f) | MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, winograd_6x3_4x4_f) | ||||
MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, winograd_5x4_1x1_f) | MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, winograd_5x4_1x1_f) | ||||
@@ -0,0 +1,372 @@ | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/filter_transform.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/helper.h" | |||||
#include "src/fallback/conv_bias/gi/fp32/strategy.h" | |||||
#include "src/fallback/conv_bias/winograd/winograd.h" | |||||
#include "src/fallback/elemwise_helper/op_unary.h" | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_fallback_winograd_fp32_F43) | |||||
using namespace megdnn; | |||||
using namespace fallback; | |||||
namespace { | |||||
/** | |||||
* input transform | |||||
* | |||||
* wd0 = 4 * (d0 - d2) - (d2 - d4) | |||||
* wd1 = -4 * (d1 + d2) + (d3 + d4) | |||||
* wd2 = 4 * (d1 - d2) + (d4 - d3) | |||||
* wd3 = 2 * (d3 - d1) - (d2 - d4) | |||||
* wd4 = -2 * (d3 - d1) - (d2 - d4) | |||||
* wd5 = -4 * (d3 - d1) + (d5 - d3) | |||||
*/ | |||||
#define INPUT_TRANSFORM(d, wd, i) \ | |||||
do { \ | |||||
auto tmp0 = SUBF(d##2##i, d##4##i); \ | |||||
auto tmp1 = SUBF(d##3##i, d##1##i); \ | |||||
wd##0##i = SUBF(MULSF(SUBF(d##0##i, d##2##i), 4.0f), tmp0); \ | |||||
wd##1##i = SUBF(ADDF(d##3##i, d##4##i), MULSF(ADDF(d##1##i, d##2##i), 4.0f)); \ | |||||
wd##2##i = ADDF(MULSF(SUBF(d##1##i, d##2##i), 4.0f), SUBF(d##4##i, d##3##i)); \ | |||||
wd##3##i = SUBF(MULSF(tmp1, 2.0f), tmp0); \ | |||||
wd##4##i = SUBF(MULSF(tmp1, -2.0f), tmp0); \ | |||||
wd##5##i = SUBF(SUBF(d##5##i, d##3##i), MULSF(tmp1, 4.0f)); \ | |||||
} while (0); | |||||
#define INPUT_TRANSFORM_V2(d, wd) \ | |||||
INPUT_TRANSFORM(d, wd, 0); \ | |||||
INPUT_TRANSFORM(d, wd, 1); | |||||
#define GET_VECTOR_HIGH_ELEM(s, i, idx) GiExtractLane##idx##Float32(s##i##1) | |||||
#define GET_VECTOR_LOW_ELEM(s, i, idx) GiExtractLane##idx##Float32(s##i##0) | |||||
struct InputTransform4X3 { | |||||
template <bool inner> | |||||
static void transform( | |||||
const float* input, float* input_transform_buf, float* transform_mid_buf, | |||||
int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC, | |||||
size_t unit_idx, size_t nr_units_in_tile) { | |||||
constexpr size_t alpha = 4 + 3 - 1; | |||||
if (!inner) { | |||||
memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); | |||||
} | |||||
#define cb(i, j) GI_FLOAT32_t d##i##j; | |||||
UNROLL_CALL_NOWRAPPER_D2(6, 2, cb); | |||||
#undef cb | |||||
if (inner) { | |||||
const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; | |||||
#define cb(i, j) d##i##j = GiLoadFloat32(input_ptr + IW * i + 4 * j); | |||||
UNROLL_CALL_NOWRAPPER_D2(5, 2, cb); | |||||
#undef cb | |||||
d50 = GiLoadFloat32(input_ptr + IW * 5); | |||||
d51 = GiLoadFloat32LowHalf(input_ptr + IW * 5 + 4); | |||||
} else { | |||||
int ih0_act = std::max<int>(ih_start, 0), | |||||
ih1_act = std::min<int>(ih_start + alpha, IH), | |||||
iw0_act = std::max<int>(iw_start, 0), | |||||
iw1_act = std::min<int>(iw_start + alpha, IW); | |||||
for (int ih = ih0_act; ih < ih1_act; ++ih) { | |||||
for (int iw = iw0_act; iw < iw1_act; ++iw) { | |||||
size_t iho = ih - ih_start, iwo = iw - iw_start; | |||||
transform_mid_buf[iho * alpha + iwo] = | |||||
input[ic * IH * IW + ih * IW + iw]; | |||||
} | |||||
} | |||||
#define cb(i, j) d##i##j = GiLoadFloat32(transform_mid_buf + alpha * i + 4 * j); | |||||
UNROLL_CALL_NOWRAPPER_D2(5, 2, cb); | |||||
#undef cb | |||||
d50 = GiLoadFloat32(transform_mid_buf + alpha * 5); | |||||
d51 = GiLoadFloat32LowHalf(transform_mid_buf + alpha * 5 + 4); | |||||
} | |||||
#define cb(i, j) GI_FLOAT32_t wd##i##j; | |||||
UNROLL_CALL_NOWRAPPER_D2(6, 2, cb); | |||||
#undef cb | |||||
INPUT_TRANSFORM_V2(d, wd); | |||||
#if MEGDNN_AARCH64 | |||||
#define cb(i, j) GI_FLOAT32_t ret##i##j; | |||||
UNROLL_CALL_NOWRAPPER_D2(6, 2, cb); | |||||
#undef cb | |||||
TRANSPOSE_6x6(wd, d); | |||||
INPUT_TRANSFORM_V2(d, ret); | |||||
#define cb(i, j) GiStoreFloat32(transform_mid_buf + i * alpha + j * 4, ret##i##j); | |||||
UNROLL_CALL_NOWRAPPER_D2(5, 2, cb); | |||||
#undef cb | |||||
GiStoreFloat32(transform_mid_buf + 5 * alpha, ret50); | |||||
float tmp[4]; | |||||
GiStoreFloat32(tmp, ret51); | |||||
memcpy(transform_mid_buf + 5 * alpha + 4, tmp, sizeof(float) * 2); | |||||
rep(i, alpha) rep(j, alpha) { | |||||
input_transform_buf | |||||
[(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = | |||||
transform_mid_buf[j * alpha + i]; | |||||
} | |||||
#else | |||||
//! 4 0 0 0 0 0 | |||||
//! 0 -4 4 -2 2 4 | |||||
//! -5 -4 -4 -1 -1 0 | |||||
//! 0 1 -1 2 -2 -5 | |||||
//! 1 1 1 1 1 0 | |||||
//! 0 0 0 0 0 1 | |||||
#define cb(i) \ | |||||
do { \ | |||||
auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) - GET_VECTOR_HIGH_ELEM(wd, i, 0); \ | |||||
auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 3) - GET_VECTOR_LOW_ELEM(wd, i, 1); \ | |||||
mid_buf1[0] = \ | |||||
(GET_VECTOR_LOW_ELEM(wd, i, 0) - GET_VECTOR_LOW_ELEM(wd, i, 2)) * \ | |||||
4.0f - \ | |||||
tmp0; \ | |||||
mid_buf1[1] = \ | |||||
(GET_VECTOR_LOW_ELEM(wd, i, 1) + GET_VECTOR_LOW_ELEM(wd, i, 2)) * \ | |||||
-4.0f + \ | |||||
(GET_VECTOR_LOW_ELEM(wd, i, 3) + GET_VECTOR_HIGH_ELEM(wd, i, 0)); \ | |||||
mid_buf1[2] = \ | |||||
(GET_VECTOR_LOW_ELEM(wd, i, 1) - GET_VECTOR_LOW_ELEM(wd, i, 2)) * \ | |||||
4.0f + \ | |||||
(GET_VECTOR_HIGH_ELEM(wd, i, 0) - GET_VECTOR_LOW_ELEM(wd, i, 3)); \ | |||||
mid_buf1[3] = 2.0f * tmp1 - tmp0; \ | |||||
mid_buf1[4] = -2.0f * tmp1 - tmp0; \ | |||||
mid_buf1[5] = -4.0f * tmp1 + (GET_VECTOR_HIGH_ELEM(wd, i, 1) - \ | |||||
GET_VECTOR_LOW_ELEM(wd, i, 3)); \ | |||||
mid_buf1 += 6; \ | |||||
} while (0); | |||||
float* mid_buf1 = transform_mid_buf; | |||||
UNROLL_CALL_NOWRAPPER(6, cb); | |||||
mid_buf1 = transform_mid_buf; | |||||
#undef cb | |||||
rep(i, alpha) rep(j, alpha) { | |||||
input_transform_buf | |||||
[(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = | |||||
transform_mid_buf[i * alpha + j]; | |||||
} | |||||
#endif | |||||
} | |||||
}; | |||||
#undef INPUT_TRANSFORM_V2 | |||||
#undef INPUT_TRANSFORM | |||||
/** | |||||
* Output Transform: use fma | |||||
* | |||||
* s0 = m0 + (m1 + m2) + (m3 + m4) | |||||
* s1 = (m1 - m2) + 2 * (m3 - m4) | |||||
* s2 = (m1 + m2) + 4 * (m3 + m4) | |||||
* s3 = (m1 - m2) + 8 * (m3 - m4) + m5 | |||||
*/ | |||||
#define OUTPUT_TRANSFORM(m, s, i) \ | |||||
do { \ | |||||
auto m1addm2 = ADDF(m##1##i, m##2##i); \ | |||||
auto m1subm2 = SUBF(m##1##i, m##2##i); \ | |||||
auto m3addm4 = ADDF(m##3##i, m##4##i); \ | |||||
auto m3subm4 = SUBF(m##3##i, m##4##i); \ | |||||
s##0##i = m##0##i; \ | |||||
s##0##i = ADDF(s##0##i, m1addm2); \ | |||||
s##0##i = ADDF(s##0##i, m3addm4); \ | |||||
s##1##i = m1subm2; \ | |||||
s##1##i = GiMultiplyAddScalarFloat32(s##1##i, m3subm4, 2.0f); \ | |||||
s##2##i = m1addm2; \ | |||||
s##2##i = GiMultiplyAddScalarFloat32(s##2##i, m3addm4, 4.0f); \ | |||||
s##3##i = m1subm2; \ | |||||
s##3##i = GiMultiplyAddScalarFloat32(s##3##i, m3subm4, 8.0f); \ | |||||
s##3##i = ADDF(s##3##i, m##5##i); \ | |||||
} while (0); | |||||
#define OUTPUT_TRANSFORM_V2(m, s) \ | |||||
OUTPUT_TRANSFORM(m, s, 0); \ | |||||
OUTPUT_TRANSFORM(m, s, 1); | |||||
//! Winograd F(4x3) output transform: converts one unit's 6x6 transform-domain
//! accumulator m back to a 4x4 spatial output block via AT * m * A, then adds
//! bias (per bmode) and applies the activation Op before storing.
template <BiasMode bmode, typename Op>
struct OutputTransform4X3 {
    //! \param output_transform_buf source, laid out (alpha*alpha,
    //!        nr_units_in_tile, OC) as the gather below shows
    //! \param transform_mid_buf scratch: first holds the gathered 6x6 m,
    //!        then is reused for the four 4-element result rows
    //! \param oc_index output-channel index relative to oc_start
    static void transform(
            const float* output_transform_buf, const float* bias, float* output,
            float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH,
            size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx,
            size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) {
        constexpr size_t alpha = 4 + 3 - 1;  //! 6: tile side in transform domain
        Op op(src_dtype, dst_dtype);
        float* mid_buf1 = transform_mid_buf;
        //! AT * m * A
        size_t OC = oc_end - oc_start;
        size_t oc = oc_start + oc_index;
//! gather this (unit, channel)'s 6x6 matrix m into contiguous transform_mid_buf
#define cb(m, n)                                                \
    transform_mid_buf[m * alpha + n] = output_transform_buf     \
            [(m * alpha + n) * nr_units_in_tile * OC + unit_idx * OC + oc_index];
        UNROLL_CALL_NOWRAPPER_D2(6, 6, cb);
#undef cb
//! rows 0-4: load each row as two vectors, m##i##0 = columns 0-3, m##i##1 =
//! columns 4-5 in its low lanes. NOTE(review): the high lanes of m##i##1
//! overlap row i+1; they appear to be ignored (only low lanes are consumed
//! below) -- confirm against GET_VECTOR_LOW/HIGH_ELEM definitions.
#define cb(i, j) auto m##i##j = GiLoadFloat32(transform_mid_buf + alpha * i + 4 * j);
        UNROLL_CALL_NOWRAPPER_D2(5, 2, cb);
#undef cb
        GI_FLOAT32_t m50, m51;
        m50 = GiLoadFloat32(transform_mid_buf + alpha * 5);
        //! last row, columns 4-5: a half load avoids reading past the end of
        //! the 6x6 (36-float) scratch buffer
        m51 = GiLoadFloat32LowHalf(transform_mid_buf + alpha * 5 + 4);
#define cb(i, j) GI_FLOAT32_t s##i##j;
        UNROLL_CALL_NOWRAPPER_D2(4, 2, cb);
#undef cb
        //! first stage: s = AT * m (4 rows x 6 columns), vectorized per lane-group
        OUTPUT_TRANSFORM_V2(m, s);
        /**
         * Output transform: s * A
         *
         * 1    0    0      0
         * 1    1    1      1
         * 1    -1   1      -1
         * 1    2    4      8
         * 1    -2   4      -8
         * 0    0    0      1
         */
//! second stage: multiply each row of s by A (scalar), producing 4 outputs per
//! row written sequentially through mid_buf1
#define cb(i)                                                                          \
    do {                                                                               \
        auto m1addm2 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2);    \
        auto m1subm2 = GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2);    \
        auto m3addm4 = GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0);   \
        auto m3subm4 = GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0);   \
        mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4;                \
        mid_buf1[1] = m1subm2 + 2.f * m3subm4;                                         \
        mid_buf1[2] = m1addm2 + 4.f * m3addm4;                                         \
        mid_buf1[3] = m1subm2 + 8.f * m3subm4 + GET_VECTOR_HIGH_ELEM(s, i, 1);         \
        mid_buf1 += 4;                                                                 \
    } while (0);
        mid_buf1 = transform_mid_buf;
        UNROLL_CALL_NOWRAPPER(4, cb);
        mid_buf1 = transform_mid_buf;  //! rewind: 4x4 result now occupies rows 0-3
#undef cb
        if (oh_start + 4 <= OH && ow_start + 4 <= OW) {
            //! fast path: the whole 4x4 block lies inside the output tensor
            GI_FLOAT32_t bias0;
            if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                bias0 = GiBroadcastFloat32(bias[oc]);
            }
            rep(i, 4) {
                size_t oh = oh_start + i;
                GI_FLOAT32_t item0 = GiLoadFloat32(mid_buf1);
                if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                    item0 = GiAddFloat32(item0, bias0);
                } else if (bmode == BiasMode::BIAS) {
                    //! per-element bias indexed with the same layout as output
                    bias0 = GiLoadFloat32(bias + oc * OH * OW + oh * OW + ow_start);
                    item0 = GiAddFloat32(item0, bias0);
                }
                item0 = op(item0);
                GiStoreFloat32(output + oc * OH * OW + oh * OW + ow_start, item0);
                mid_buf1 += 4;
            }
        } else {
            //! border path: clip the 4x4 block against OH/OW, scalar per element
            for (size_t oho = 0; oho < 4 && oh_start + oho < OH; ++oho) {
                for (size_t owo = 0; owo < 4 && ow_start + owo < OW; ++owo) {
                    size_t oh = oh_start + oho;
                    size_t ow = ow_start + owo;
                    float res = mid_buf1[oho * 4 + owo];
                    if (bmode == BiasMode::BIAS) {
                        res += bias[oc * OH * OW + oh * OW + ow];
                    } else if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
                        res += bias[oc];
                    }
                    res = op(res);
                    output[oc * OH * OW + oh * OW + ow] = res;
                }
            }
        }
    }
};
#undef GET_VECTOR_HIGH_ELEM | |||||
#undef GET_VECTOR_LOW_ELEM | |||||
#undef OUTPUT_TRANSFORM_V2 | |||||
#undef OUTPUT_TRANSFORM | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace fallback { | |||||
namespace winograd { | |||||
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x3_1x1_f)
//! Winograd F(4x3) filter transform: forwards to FilterTransform4X3 in the
//! DEFAULT matmul format for output channels [oc_start, oc_end).
void winograd_4x3_1x1_f::filter(
        const float* filter, float* filter_transform_buf, float* transform_mid_buf,
        size_t OC, size_t IC, size_t oc_start, size_t oc_end) {
    FilterTransform4X3<param::MatrixMul::Format::DEFAULT>::transform(
            filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end);
}
void winograd_4x3_1x1_f::input( | |||||
const float* input, float* input_transform_buf, float* transform_mid_buf, | |||||
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
constexpr int alpha = 3 + 4 - 1; | |||||
// OW = IW + 2 * PW - KERNEL_SIZE + 1 | |||||
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); | |||||
rep(ic, IC) { | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
size_t nh = index / units_w; | |||||
size_t nw = index % units_w; | |||||
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; | |||||
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; | |||||
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) && | |||||
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) { | |||||
InputTransform4X3::transform<true>( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} else { | |||||
InputTransform4X3::transform<false>( | |||||
input, input_transform_buf, transform_mid_buf, ih_start, | |||||
iw_start, ic, IH, IW, IC, unit_idx, nr_units_in_tile); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
void winograd_4x3_1x1_f::output( | |||||
const float* output_transform_buf, const float* bias, float* output, | |||||
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, | |||||
size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, | |||||
size_t nr_units_in_tile) { | |||||
#define cb(_bmode, _nonline_op, ...) \ | |||||
OutputTransform4X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); | |||||
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE); | |||||
for (size_t oc = oc_start; oc < oc_end; oc++) { | |||||
size_t oc_index = oc - oc_start; | |||||
rep(unit_idx, nr_units_in_tile) { | |||||
size_t index = unit_start_idx + unit_idx; | |||||
auto nh = index / units_w; | |||||
auto nw = index % units_w; | |||||
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; | |||||
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; | |||||
GI_DISPATCH_CONV_WINOGRAD_BIAS( | |||||
megdnn_fallback_winograd_fp32_F43, cb, float, float, bmode, | |||||
nonline_mode, output_transform_buf, bias, output, transform_mid_buf, | |||||
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, | |||||
nr_units_in_tile, src_dtype, dst_dtype); | |||||
} | |||||
} | |||||
#undef cb | |||||
} | |||||
} // namespace winograd | |||||
} // namespace fallback | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -93,7 +93,6 @@ struct InputTransform6X3 { | |||||
#undef cb | #undef cb | ||||
INPUT_TRANSFORM(d, wd); | INPUT_TRANSFORM(d, wd); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
#define cb(i) GI_FLOAT32_V2_t ret##i; | #define cb(i) GI_FLOAT32_V2_t ret##i; | ||||
UNROLL_CALL_NOWRAPPER(8, cb); | UNROLL_CALL_NOWRAPPER(8, cb); | ||||
@@ -159,6 +159,10 @@ public: | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | m_gi_winograd_algos.emplace_back(refhold.back().get()); | ||||
refhold.emplace_back(new AlgoFP32WinogradF43( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
m_gi_winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoFP32WinogradF54( | refhold.emplace_back(new AlgoFP32WinogradF54( | ||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
@@ -217,6 +217,7 @@ public: | |||||
FB_IM2COL, | FB_IM2COL, | ||||
GI_COMMON_WINOGRAD_F23_4X4_FP32, | GI_COMMON_WINOGRAD_F23_4X4_FP32, | ||||
GI_COMMON_WINOGRAD_F63_FP32, | GI_COMMON_WINOGRAD_F63_FP32, | ||||
GI_COMMON_WINOGRAD_F43_FP32, | |||||
GI_COMMON_WINOGRAD_F63_4X4_FP32, | GI_COMMON_WINOGRAD_F63_4X4_FP32, | ||||
GI_COMMON_WINOGRAD_F54_FP32, | GI_COMMON_WINOGRAD_F54_FP32, | ||||
GI_COMMON_WINOGRAD_F45_FP32, | GI_COMMON_WINOGRAD_F45_FP32, | ||||
@@ -379,6 +380,7 @@ private: | |||||
class AlgoFP32WinogradF23_4x4; | class AlgoFP32WinogradF23_4x4; | ||||
class AlgoFP32WinogradF63; | class AlgoFP32WinogradF63; | ||||
class AlgoFP32WinogradF43; | |||||
class AlgoFP32WinogradF63_4x4; | class AlgoFP32WinogradF63_4x4; | ||||
class AlgoFP32WinogradF54; | class AlgoFP32WinogradF54; | ||||
class AlgoFP32WinogradF45; | class AlgoFP32WinogradF45; | ||||
@@ -83,7 +83,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {8, 6, m_tile_size}); | |||||
m_matmul_algo->name(), {8, 6, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -100,7 +100,7 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
m_matmul_algo->name(), {8, 2, m_tile_size}); | |||||
m_matmul_algo->name(), {8, 2, m_tile_size, 3}); | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
@@ -47,6 +47,25 @@ TEST_F(ARM_COMMON, CONV_BIAS_MATMUL) { | |||||
} | } | ||||
} | } | ||||
TEST_F(ARM_COMMON, CONV_BIAS_WINOGRAD) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_quantized_args(); | |||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||||
handle()); | |||||
checker.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBias>("WINOGRAD:.*:1:4:.*:3")); | |||||
ConvBiasForward::Param param; | |||||
param.pad_h = 1; | |||||
param.pad_w = 1; | |||||
checker.set_param(param); | |||||
checker.execs( | |||||
{{1, 3, 351, 257}, | |||||
{5, 3, 3, 3}, | |||||
{}, | |||||
{}, | |||||
{}}); // Input, weight, bias, ..., Output | |||||
} | |||||
TEST_F(ARM_COMMON, CONV_BIAS_RECORD) { | TEST_F(ARM_COMMON, CONV_BIAS_RECORD) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args = get_quantized_args(); | std::vector<TestArg> args = get_quantized_args(); | ||||
@@ -987,6 +1006,13 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23) { | |||||
#endif | #endif | ||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F43_F63) { | |||||
#if MEGDNN_AARCH64 | |||||
benchmark_winograd_compare( | |||||
"WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:3", "WINOGRAD:AARCH64_F32K8X12X1:1:6", | |||||
handle(), 3); | |||||
#endif | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63) { | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); | benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:6", handle(), 3); | ||||
@@ -1005,9 +1031,9 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F54) { | |||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F45) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F45) { | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:4", handle(), 5); | |||||
benchmark_winograd("WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:5", handle(), 5); | |||||
#else | #else | ||||
benchmark_winograd("WINOGRAD:ARMV7_F32:1:4", handle(), 5); | |||||
benchmark_winograd("WINOGRAD:ARMV7_F32:1:4:.*:5", handle(), 5); | |||||
#endif | #endif | ||||
} | } | ||||
@@ -1026,11 +1052,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F23) { | |||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F45) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F45) { | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
benchmark_winograd_fp16( | benchmark_winograd_fp16( | ||||
"WINOGRAD:AARCH64_F32K8X12X1:1:4", "WINOGRAD:AARCH64_F16_K8X24X1:1:4", | |||||
handle(), 5); | |||||
"WINOGRAD:AARCH64_F32K8X12X1:1:4:.*:5", | |||||
"WINOGRAD:AARCH64_F16_K8X24X1:1:4:.*:5", handle(), 5); | |||||
#else | #else | ||||
benchmark_winograd_fp16( | benchmark_winograd_fp16( | ||||
"WINOGRAD:ARMV7_F32:1:4", "WINOGRAD:AARCH32_F16_K4X16X1:1:4", handle(), 5); | |||||
"WINOGRAD:ARMV7_F32:1:4:.*:5", "WINOGRAD:AARCH32_F16_K4X16X1:1:4:.*:5", | |||||
handle(), 5); | |||||
#endif | #endif | ||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F63) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F63) { | ||||
@@ -800,6 +800,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) { | |||||
#endif | #endif | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F32_F43) { | |||||
using namespace conv_bias; | |||||
std::vector<TestArg> args = get_winograd_args(3); | |||||
Checker<ConvBiasForward> checker(handle()); | |||||
check_winograd("1:4:32", checker, args); | |||||
} | |||||
//! uncomment it when low precision mode is ok | //! uncomment it when low precision mode is ok | ||||
#if 0 | #if 0 | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { | ||||
@@ -921,6 +921,7 @@ std::vector<conv_bias::TestArg> get_winograd_benchmark_args( | |||||
TensorShape{oc, ic, kernel, kernel}, | TensorShape{oc, ic, kernel, kernel}, | ||||
{1, oc, 1, 1}}); | {1, oc, 1, 1}}); | ||||
}; | }; | ||||
for (size_t ic : {8, 16, 32, 64}) { | for (size_t ic : {8, 16, 32, 64}) { | ||||
for (size_t oc : {8, 16, 32, 64}) { | for (size_t oc : {8, 16, 32, 64}) { | ||||
pack(oc, ic, 56, 56, kernel, kernel / 2); | pack(oc, ic, 56, 56, kernel, kernel / 2); | ||||
@@ -1041,6 +1042,60 @@ void benchmark_winograd_weight_preprocess( | |||||
computations / used_winograd); | computations / used_winograd); | ||||
} | } | ||||
} | } | ||||
void benchmark_winograd_compare( | |||||
const char* algoA_name, const char* algoB_name, megdnn::Handle* handle, | |||||
size_t kernel, size_t pack_size) { | |||||
auto&& args = get_winograd_benchmark_args(kernel, pack_size); | |||||
using namespace conv_bias; | |||||
constexpr size_t RUN = 10; | |||||
Benchmarker<ConvBias, Timer, OprWeightPreprocessBenchmarkProxy<ConvBias>> | |||||
benchmark_winograd(handle); | |||||
benchmark_winograd.set_display(false); | |||||
benchmark_winograd.set_times(RUN); | |||||
for (auto&& arg : args) { | |||||
TensorLayout dst_layout; | |||||
auto opr = handle->create_operator<ConvBias>(); | |||||
opr->param() = arg.param; | |||||
opr->deduce_layout( | |||||
{arg.src, dtype::Float32()}, {arg.filter, dtype::Float32()}, | |||||
{arg.bias, dtype::Float32()}, {}, dst_layout); | |||||
//! dst.nr_elems * IC * FH * FW * 2 | |||||
float computations = dst_layout.total_nr_elems() * arg.filter[1] * | |||||
arg.filter[2] * arg.filter[3] * 2.0 / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
param::Convolution conv_param; | |||||
conv_param.pad_h = arg.param.pad_h; | |||||
conv_param.pad_w = arg.param.pad_w; | |||||
conv_param.stride_h = arg.param.stride_h; | |||||
conv_param.stride_w = arg.param.stride_w; | |||||
benchmark_winograd.set_param(arg.param); | |||||
auto used_winograd1 = | |||||
algo_benchmark< | |||||
ConvBias, OprWeightPreprocessBenchmarkProxy<ConvBias>, Timer>( | |||||
benchmark_winograd, {arg.src, arg.filter, {}, {}, {}}, | |||||
algoA_name) / | |||||
RUN; | |||||
auto used_winograd2 = | |||||
algo_benchmark< | |||||
ConvBias, OprWeightPreprocessBenchmarkProxy<ConvBias>, Timer>( | |||||
benchmark_winograd, {arg.src, arg.filter, {}, {}, {}}, | |||||
algoB_name) / | |||||
RUN; | |||||
printf("%s %s: %s: %f ms %f Gflops %s: %f ms %f GFlops " | |||||
"speedup: " | |||||
"%f\n", | |||||
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), algoA_name, | |||||
used_winograd1, computations / used_winograd1, algoB_name, | |||||
used_winograd2, computations / used_winograd2, | |||||
used_winograd2 / used_winograd1); | |||||
} | |||||
} | |||||
#endif // MEGDNN_WITH_BENCHMARK | #endif // MEGDNN_WITH_BENCHMARK | ||||
template <class Checker> | template <class Checker> | ||||
@@ -69,6 +69,9 @@ void benchmark_winograd( | |||||
void benchmark_winograd_weight_preprocess( | void benchmark_winograd_weight_preprocess( | ||||
const char* algo_name, megdnn::Handle* handle, size_t kernel, | const char* algo_name, megdnn::Handle* handle, size_t kernel, | ||||
size_t pack_size = 1); | size_t pack_size = 1); | ||||
void benchmark_winograd_compare( | |||||
const char* algoA_name, const char* algoB_name, megdnn::Handle* handle, | |||||
size_t kernel, size_t pack_size = 1); | |||||
#endif // MEGDNN_WITH_BENCHMARK | #endif // MEGDNN_WITH_BENCHMARK | ||||
template <class Checker> | template <class Checker> | ||||
void check_winograd( | void check_winograd( | ||||