@@ -318,6 +318,79 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
 }
 } // namespace
+
+void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2(
+        const MatrixMulImpl::KernParam& kern_param) {
+    MEGDNN_MARK_USED_VAR(kern_param);
+    MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_4x16x2, midout_iv(1)) {
+        constexpr int cacheline = 64;
+        const size_t m = kern_param.M;
+        const size_t n = kern_param.N;
+        const size_t k = kern_param.K;
+        const bool trans_a = kern_param.trA;
+        const bool trans_b = kern_param.trB;
+        const size_t lda = kern_param.LDA;
+        const size_t ldb = kern_param.LDB;
+        const size_t ldc = kern_param.LDC;
+        auto a_type = kern_param.A_type;
+        auto b_type = kern_param.B_type;
+        auto c_type = kern_param.C_type;
+        const auto a_ptr = kern_param.A<dt_int8>();
+        const auto b_ptr = kern_param.B<dt_int8>();
+        auto c_ptr = kern_param.C<dt_int16>();
+        x86::matmul::gemm_avx2_s8s8s16_4x16x2 strategy(m, n, k, a_type, b_type,
+                                                       c_type);
+        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_avx2_s8s8s16_4x16x2>(
+                m, n, k, trans_a, trans_b, strategy, cacheline)
+                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
+                         kern_param.workspace_ptr);
+    }
+    MIDOUT_END();
+}
+
+MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_kern(
+        const KernSizeParam&) const {
+    return gemm_s8s8s16_avx2_4x16x2;
+}
+
+bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable(
+        const KernSizeParam& kern_size_param) const {
+    bool is_ab_same =
+            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv();
+    bool is_type_ok =
+            ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
+              kern_size_param.C_type.enumv() == DTypeEnum::Int16) ||
+             (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
+              kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
+    bool is_mode_ok =
+            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
+            is_supported(SIMDType::AVX2);
+    bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
+    return is_param_ok;
+}
+
+bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const {
+    return true;
+}
+
+size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
+        const KernSizeParam& kern_param) const {
+    constexpr int cacheline = 64;
+    const size_t m = kern_param.M;
+    const size_t n = kern_param.N;
+    const size_t k = kern_param.K;
+    const bool trans_a = kern_param.trA;
+    const bool trans_b = kern_param.trB;
+    auto a_type = kern_param.A_type;
+    auto b_type = kern_param.B_type;
+    auto c_type = kern_param.C_type;
+    x86::matmul::gemm_avx2_s8s8s16_4x16x2 strategy(m, n, k, a_type, b_type,
+                                                   c_type);
+    return megdnn::matmul::GemmInterleaved<
+                   x86::matmul::gemm_avx2_s8s8s16_4x16x2>(
+                   m, n, k, trans_a, trans_b, strategy, cacheline)
+            .get_workspace_size();
+}
+
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
+        AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, 8,
+        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);
+
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
         const KernSizeParam&) const {
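Note (annotation, not part of the patch): `GemmInterleaved<Strategy>` owns the whole pack-then-compute flow; the kernel above only supplies the strategy, the leading dimensions and the workspace. A rough sketch of the flow it drives, with illustrative names, an assumed workspace layout, and none of the cache blocking the real template performs:

    #include <cstddef>
    #include <cstdint>

    // Illustrative only: the real GemmInterleaved tiles M/N/K for cache reuse
    // and handles the transpose flags; this shows just the call order.
    template <typename Strategy>
    void gemm_interleaved_sketch(const int8_t* A, size_t lda, const int8_t* B,
                                 size_t ldb, int16_t* C, size_t ldc, size_t M,
                                 size_t N, size_t K, const Strategy& strategy,
                                 void* workspace) {
        // Pack-A type for this strategy is dt_int16 (int8 widened on packing).
        int16_t* packed_a = static_cast<int16_t*>(workspace);
        int8_t* packed_b =
                reinterpret_cast<int8_t*>(packed_a + M * K);  // assumed layout
        strategy.pack_A(packed_a, A, lda, 0, M, 0, K, /*transpose=*/false);
        strategy.pack_B(packed_b, B, ldb, 0, N, 0, K, /*transpose=*/false);
        strategy.kern(packed_a, packed_b, M, N, K, C, ldc, /*is_first_k=*/true,
                      nullptr, nullptr);
    }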
@@ -6,13 +6,14 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
-#include "src/x86/matrix_mul/opr_impl.h"
 #include "src/fallback/matrix_mul/gemm_common.h"
+#include "src/x86/matrix_mul/opr_impl.h"
 
 namespace megdnn {
 namespace x86 {
@@ -71,6 +72,23 @@ public:
     MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
 };
 
+class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase {
+private:
+    static void gemm_s8s8s16_avx2_4x16x2(
+            const MatrixMulImpl::KernParam& kern_param);
+    static MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 m_algo;
+
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override { return "X86_INT8X8X16_AVX2"; }
+    bool usable(const KernSizeParam&) const override;
+    size_t get_workspace(const KernSizeParam&) const override;
+    kern_t get_kern(const KernSizeParam&) const override;
+    void* type() const override { return sm_x86_algo_type; }
+    bool preferred(const KernSizeParam&) const override;
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
+};
+
 class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase {
 public:
     bool is_reproducible() const override { return true; }
@@ -6,16 +6,17 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include <x86intrin.h>
 #ifdef WIN32
-#include <avxintrin.h>
-#include <smmintrin.h>
 #include <avx2intrin.h>
+#include <avxintrin.h>
 #include <fmaintrin.h>
+#include <smmintrin.h>
 #endif
 #include <cmath>
 #include <cstdint>
@@ -787,13 +788,19 @@ static inline void transpose_4x8_k2_int8_to_int16(const int8_t* inptr0,
 MEGDNN_ATTRIBUTE_TARGET("avx2")
 static inline __v8si _m256_continue_mask_v8si(const int& x) {
+    // clang-format off
     static __v8si map[9] = {
-            {0, 0, 0, 0, 0, 0, 0, 0},       {-1, 0, 0, 0, 0, 0, 0, 0},
-            {-1, -1, 0, 0, 0, 0, 0, 0},     {-1, -1, -1, 0, 0, 0, 0, 0},
-            {-1, -1, -1, -1, 0, 0, 0, 0},   {-1, -1, -1, -1, -1, 0, 0, 0},
-            {-1, -1, -1, -1, -1, -1, 0, 0}, {-1, -1, -1, -1, -1, -1, -1, 0},
+            {00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, 00},
             {-1, -1, -1, -1, -1, -1, -1, -1}};
     return map[x];
+    // clang-format on
 }
 
 MEGDNN_ATTRIBUTE_TARGET("avx2")
 static inline __m256i _m256_continue_mask(const int& x) {
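Note (annotation, not part of the patch): `_m256_continue_mask(x)` materializes a lane mask whose first x 32-bit lanes are all-ones, which is exactly what `_mm256_maskstore_epi32` keys on. A minimal usage sketch, assuming AVX2 and the helpers above:

    // Store only the first `remain` (0..8) int32 lanes of `v` to `dst`;
    // memory past `remain` lanes is left untouched. This mirrors how the
    // remain-N kernels in this header handle partial output tiles.
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static inline void store_tail_s32(int32_t* dst, __m256i v, int remain) {
        __m256i mask = _m256_continue_mask(remain);
        _mm256_maskstore_epi32(dst, mask, v);
    }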
@@ -801,6 +808,30 @@ static inline __m256i _m256_continue_mask(const int& x) {
 }
 
 MEGDNN_ATTRIBUTE_TARGET("sse2")
+static inline __m128i _mm_continue_mask(const int& x) {
+    static __v16qi map[17] = {
+            {00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 00},
+            {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
+    };
+    return (__m128i)map[x];
+}
+
+MEGDNN_ATTRIBUTE_TARGET("sse2")
 static inline void transpose_4xk_int8_to_int16_pad(const int8_t* inptr0,
                                                    const int8_t* inptr1,
                                                    const int8_t* inptr2,
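Note (annotation, not part of the patch): the new SSE `_mm_continue_mask` is the byte-granularity analogue of the AVX2 lane mask; the int16 tail store later in this patch builds its mask from `remain * sizeof(int16_t)` bytes. A sketch of the intended pairing with `_mm_maskmoveu_si128`, which writes only the bytes whose mask high bit is set:

    // Store only the first `remain` (0..8) int16 lanes of `v` to `dst`.
    MEGDNN_ATTRIBUTE_TARGET("sse2")
    static inline void store_tail_s16(int16_t* dst, __m128i v, int remain) {
        __m128i mask = _mm_continue_mask(remain * sizeof(int16_t));
        _mm_maskmoveu_si128(v, mask, reinterpret_cast<char*>(dst));
    }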
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "src/common/utils.h"
@@ -18,10 +19,9 @@ using namespace megdnn;
 using namespace x86;
 using namespace x86::matmul;
 
-MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_avx2_s8s8s32_4x16x2);
-void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
-                                      int ldin, int y0, int ymax, int k0,
-                                      int kmax, bool transpose) const {
+static inline void gemm_packa(dt_int16* out, const dt_int8* in, int ldin,
+                              int y0, int ymax, int k0, int kmax,
+                              bool transpose) {
     if (transpose) {
         matmul_avx2_4x16x2::gemm_s8s8s32_avx2_4x16x2_pack_at(out, in, ldin, y0,
                                                              ymax, k0, kmax);
@@ -30,10 +30,8 @@ void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
                                                              ymax, k0, kmax);
     }
 }
-void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
-                                      int x0, int xmax, int k0, int kmax,
-                                      bool transpose) const {
+static inline void gemm_packb(dt_int8* out, const dt_int8* in, int ldin, int x0,
+                              int xmax, int k0, int kmax, bool transpose) {
     if (transpose) {
         matmul_avx2_4x16x2::gemm_s8s8s32_avx2_4x16x2_pack_bt(out, in, ldin, x0,
                                                              xmax, k0, kmax);
@@ -42,20 +40,11 @@ void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                                              xmax, k0, kmax);
     }
 }
-void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr,
-                                    const dt_int8* pack_b_ptr, size_t m,
-                                    size_t n, size_t k, dt_int32* c_ptr,
-                                    size_t ldc, bool is_first_k,
-                                    const dt_int32*, dt_int32*) const {
-    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
-                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
-                            C_dtype.enumv() == DTypeEnum::Int32) ||
-                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
-                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
-                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
-                  C_dtype.name());
-    megdnn_assert(is_first_k == true);
+template <typename CType>
+static inline void gemm_kern(const dt_int16* pack_a_ptr,
+                             const dt_int8* pack_b_ptr, size_t m, size_t n,
+                             size_t k, CType* c_ptr, size_t ldc,
+                             bool is_first_k) {
     constexpr size_t m_tile = 4;
     constexpr size_t n_tile = 16;
     constexpr size_t k_tile = 2;
@@ -109,4 +98,62 @@ void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr,
         }
     }
 }
+
+MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_avx2_s8s8s32_4x16x2);
+void gemm_avx2_s8s8s32_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
+                                      int ldin, int y0, int ymax, int k0,
+                                      int kmax, bool transpose) const {
+    gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
+}
+
+void gemm_avx2_s8s8s32_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
+                                      int x0, int xmax, int k0, int kmax,
+                                      bool transpose) const {
+    gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
+}
+
+void gemm_avx2_s8s8s32_4x16x2::kern(const dt_int16* pack_a_ptr,
+                                    const dt_int8* pack_b_ptr, size_t m,
+                                    size_t n, size_t k, dt_int32* c_ptr,
+                                    size_t ldc, bool is_first_k,
+                                    const dt_int32*, dt_int32*) const {
+    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
+                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
+                            C_dtype.enumv() == DTypeEnum::Int32) ||
+                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
+                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
+                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
+                  C_dtype.name());
+    megdnn_assert(is_first_k == true);
+    gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
+}
+
+MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_avx2_s8s8s16_4x16x2);
+void gemm_avx2_s8s8s16_4x16x2::pack_A(dt_int16* out, const dt_int8* in,
+                                      int ldin, int y0, int ymax, int k0,
+                                      int kmax, bool transpose) const {
+    gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
+}
+
+void gemm_avx2_s8s8s16_4x16x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
+                                      int x0, int xmax, int k0, int kmax,
+                                      bool transpose) const {
+    gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
+}
+
+void gemm_avx2_s8s8s16_4x16x2::kern(const dt_int16* pack_a_ptr,
+                                    const dt_int8* pack_b_ptr, size_t m,
+                                    size_t n, size_t k, dt_int16* c_ptr,
+                                    size_t ldc, bool is_first_k,
+                                    const dt_int32*, dt_int32*) const {
+    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
+                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
+                            C_dtype.enumv() == DTypeEnum::Int16) ||
+                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
+                            C_dtype.enumv() == DTypeEnum::QuantizedS16)),
+                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
+                  C_dtype.name());
+    megdnn_assert(is_first_k == true);
+    gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
+}
+
 // vim: syntax=cpp.doxygen
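Note on semantics (annotation, not part of the patch): the s8s8s16 strategy still accumulates int8×int8 products in 32-bit lanes (presumably via `_mm256_madd_epi16` on the int16-widened packed A; the loop body is outside this hunk); only the final store narrows each lane to its low 16 bits, via the `store_overflow<int16_t>` helper added in the kernel header below. Sums outside int16 range therefore wrap modulo 2^16 with no saturation. A scalar model:

    #include <cstdint>
    #include <cstdio>

    // Accumulate in 32 bits, keep only the low 16 bits on store (wraps).
    int main() {
        int32_t acc = 0;
        for (int i = 0; i < 4; ++i)
            acc += 127 * 127;                        // acc = 64516 > INT16_MAX
        int16_t stored = static_cast<int16_t>(acc);  // wraps to -1020
        std::printf("acc=%d stored=%d\n", acc, stored);
        return 0;
    }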
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include <immintrin.h>
@@ -20,11 +21,47 @@ namespace megdnn {
 namespace x86 {
 namespace matmul_avx2_4x16x2 {
 
+template <typename CType>
+MEGDNN_ATTRIBUTE_TARGET("avx2")
+void store_overflow(void* ptr, __m256i a);
+
+template <>
+MEGDNN_ATTRIBUTE_TARGET("avx2")
+void store_overflow<int16_t>(void* ptr, __m256i a) {
+    static __m256i idx = _mm256_setr_epi32(0, 2, 4, 6, 0, 0, 0, 0);
+    a = _mm256_shufflelo_epi16(a, 0x08);
+    a = _mm256_shufflehi_epi16(a, 0x08);
+    a = _mm256_permutevar8x32_epi32(a, idx);
+    _mm_storeu_si128((__m128i*)ptr, _mm256_extractf128_si256(a, 0));
+}
+
+template <>
+MEGDNN_ATTRIBUTE_TARGET("avx2")
+void store_overflow<int32_t>(void* ptr, __m256i a) {
+    _mm256_storeu_si256((__m256i*)(ptr), a);
+}
+
+template <typename CType>
+MEGDNN_ATTRIBUTE_TARGET("avx2")
+void store_overflow(void* ptr, __m256i a, int remain);
+
+template <>
+MEGDNN_ATTRIBUTE_TARGET("avx2")
+void store_overflow<int16_t>(void* ptr, __m256i a, int remain) {
+    __m128i mask = _mm_continue_mask(remain * sizeof(int16_t));
+    static __m256i idx = _mm256_setr_epi32(0, 2, 4, 6, 0, 0, 0, 0);
+    a = _mm256_shufflelo_epi16(a, 0x08);
+    a = _mm256_shufflehi_epi16(a, 0x08);
+    a = _mm256_permutevar8x32_epi32(a, idx);
+    _mm_maskmoveu_si128(_mm256_extractf128_si256(a, 0), mask,
+                        reinterpret_cast<char*>(ptr));
+}
+
+template <>
+MEGDNN_ATTRIBUTE_TARGET("avx2")
+void store_overflow<int32_t>(void* ptr, __m256i a, int remain) {
+    __m256i mask = _m256_continue_mask(remain);
+    _mm256_maskstore_epi32(reinterpret_cast<int32_t*>(ptr), mask, a);
+}
+
+template <typename CType>
 MEGDNN_ATTRIBUTE_TARGET("avx2")
 static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr,
                                                  const int8_t* pack_b_ptr,
-                                                 int32_t* c_ptr,
+                                                 CType* c_ptr,
                                                  const uint32_t ldc,
                                                  const uint32_t k) {
     constexpr uint32_t k_step = 2;
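For reference (annotation, not part of the patch): the shuffle/permute sequence in `store_overflow<int16_t>` — selector 0x08 keeps the low word of each dword in the low and high halves, and the permute compacts dwords 0/2/4/6 into the low 128 bits — is a vectorized version of this scalar loop:

    #include <cstdint>
    #include <cstring>

    // Scalar equivalent of store_overflow<int16_t>(ptr, a): narrow the eight
    // int32 lanes of `a` to their low 16 bits and store 16 contiguous bytes.
    static inline void store_overflow_s16_scalar(void* ptr,
                                                 const int32_t lanes[8]) {
        int16_t out[8];
        for (int i = 0; i < 8; ++i)
            out[i] = static_cast<int16_t>(lanes[i]);
        std::memcpy(ptr, out, sizeof(out));
    }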
@@ -104,19 +141,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2(const int16_t* pack_a_ptr,
         pack_b_ptr += 32;
     }
 
-    _mm256_storeu_si256((__m256i*)(c_ptr), c_vec[0]);
-    _mm256_storeu_si256((__m256i*)(c_ptr + 8), c_vec[1]);
-    _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]);
-    _mm256_storeu_si256((__m256i*)(c_ptr + ldc + 8), c_vec[3]);
-    _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]);
-    _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc + 8), c_vec[5]);
-    _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc), c_vec[6]);
-    _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc + 8), c_vec[7]);
+    store_overflow<CType>(c_ptr, c_vec[0]);
+    store_overflow<CType>(c_ptr + 8, c_vec[1]);
+    store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+    store_overflow<CType>(c_ptr + ldc + 8, c_vec[3]);
+    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+    store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5]);
+    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
+    store_overflow<CType>(c_ptr + 3 * ldc + 8, c_vec[7]);
 }
 
+template <typename CType>
 MEGDNN_ATTRIBUTE_TARGET("avx2")
 static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
-        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
+        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
         const uint32_t ldc, const uint32_t k, const uint32_t remain_n) {
     constexpr uint32_t k_step = 2;
@@ -173,15 +210,15 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_n(
         pack_b_ptr += 32;
     }
 
-    __m256i mask = _m256_continue_mask(remain_n);
-    _mm256_maskstore_epi32((c_ptr), mask, c_vec[0]);
-    _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]);
-    _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]);
-    _mm256_maskstore_epi32((c_ptr + 3 * ldc), mask, c_vec[6]);
+    store_overflow<CType>(c_ptr, c_vec[0], remain_n);
+    store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
+    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
+    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
 }
 
+template <typename CType>
 MEGDNN_ATTRIBUTE_TARGET("avx2")
 static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
-        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
+        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
         const uint32_t ldc, const uint32_t k, const uint32_t remain_m,
         uint32_t remain_n) {
     constexpr uint32_t k_step = 2;
@@ -239,29 +276,29 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_n8_remain_m_n(
         pack_b_ptr += 32;
     }
 
-    __m256i mask = _m256_continue_mask(remain_n);
-    _mm256_maskstore_epi32((c_ptr), mask, c_vec[0]);
+    store_overflow<CType>(c_ptr, c_vec[0], remain_n);
     switch (remain_m) {
         case 2:
-            _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
            break;
         case 3:
-            _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]);
-            _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
+            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
             break;
         case 4:
-            _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]);
-            _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]);
-            _mm256_maskstore_epi32((c_ptr + 3 * ldc), mask, c_vec[6]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
+            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
+            store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
             break;
         default:
            break;
     }
 }
 
+template <typename CType>
 MEGDNN_ATTRIBUTE_TARGET("avx2")
 static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
-        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
+        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
         const uint32_t ldc, const uint32_t k, const uint32_t remain_m) {
     constexpr uint32_t k_step = 2;
@@ -339,34 +376,36 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m(
         pack_a_ptr += 8;
         pack_b_ptr += 32;
     }
 
-    _mm256_storeu_si256((__m256i*)(c_ptr), c_vec[0]);
-    _mm256_storeu_si256((__m256i*)(c_ptr + 8), c_vec[1]);
+    store_overflow<CType>(c_ptr, c_vec[0]);
+    store_overflow<CType>(c_ptr + 8, c_vec[1]);
     switch (remain_m) {
         case 2:
-            _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + ldc + 8), c_vec[3]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+            store_overflow<CType>(c_ptr + ldc + 8, c_vec[3]);
             break;
         case 3:
-            _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + ldc + 8), c_vec[3]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc + 8), c_vec[5]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+            store_overflow<CType>(c_ptr + ldc + 8, c_vec[3]);
+            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+            store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5]);
             break;
         case 4:
-            _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + ldc + 8), c_vec[3]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc + 8), c_vec[5]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc), c_vec[6]);
-            _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc + 8), c_vec[7]);
+            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+            store_overflow<CType>(c_ptr + ldc + 8, c_vec[3]);
+            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+            store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5]);
+            store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
+            store_overflow<CType>(c_ptr + 3 * ldc + 8, c_vec[7]);
         default:
             break;
     }
 }
 
+template <typename CType>
 MEGDNN_ATTRIBUTE_TARGET("avx2")
 static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n(
-        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
+        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
         const uint32_t ldc, const uint32_t k, uint32_t remain_n) {
     constexpr uint32_t k_step = 2;
@@ -446,29 +485,28 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_n(
     }
 
     if (remain_n >= 8) {
-        _mm256_storeu_si256((__m256i*)(c_ptr), c_vec[0]);
-        _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]);
-        _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]);
-        _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc), c_vec[6]);
+        store_overflow<CType>(c_ptr, c_vec[0]);
+        store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+        store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+        store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
         remain_n -= 8;
         if (remain_n > 0) {
-            __m256i mask = _m256_continue_mask(remain_n);
-            _mm256_maskstore_epi32((c_ptr + 8), mask, c_vec[1]);
-            _mm256_maskstore_epi32((c_ptr + ldc + 8), mask, c_vec[3]);
-            _mm256_maskstore_epi32((c_ptr + 2 * ldc + 8), mask, c_vec[5]);
-            _mm256_maskstore_epi32((c_ptr + 3 * ldc + 8), mask, c_vec[7]);
+            store_overflow<CType>(c_ptr + 8, c_vec[1], remain_n);
+            store_overflow<CType>(c_ptr + ldc + 8, c_vec[3], remain_n);
+            store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5], remain_n);
+            store_overflow<CType>(c_ptr + 3 * ldc + 8, c_vec[7], remain_n);
         }
     } else {
-        __m256i mask = _m256_continue_mask(remain_n);
-        _mm256_maskstore_epi32((c_ptr), mask, c_vec[0]);
-        _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]);
-        _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]);
-        _mm256_maskstore_epi32((c_ptr + 3 * ldc), mask, c_vec[6]);
+        store_overflow<CType>(c_ptr, c_vec[0], remain_n);
+        store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
+        store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
+        store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
     }
 }
 
+template <typename CType>
 MEGDNN_ATTRIBUTE_TARGET("avx2")
 static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
-        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
+        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
         const uint32_t ldc, const uint32_t k, const uint32_t remain_m,
         uint32_t remain_n) {
     constexpr uint32_t k_step = 2;
@@ -549,19 +587,19 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
     }
 
     if (remain_n >= 8) {
-        _mm256_storeu_si256((__m256i*)(c_ptr), c_vec[0]);
+        store_overflow<CType>(c_ptr, c_vec[0]);
         switch (remain_m) {
             case 2:
-                _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]);
+                store_overflow<CType>(c_ptr + ldc, c_vec[2]);
                 break;
             case 3:
-                _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]);
-                _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]);
+                store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+                store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
                 break;
             case 4:
-                _mm256_storeu_si256((__m256i*)(c_ptr + ldc), c_vec[2]);
-                _mm256_storeu_si256((__m256i*)(c_ptr + 2 * ldc), c_vec[4]);
-                _mm256_storeu_si256((__m256i*)(c_ptr + 3 * ldc), c_vec[6]);
+                store_overflow<CType>(c_ptr + ldc, c_vec[2]);
+                store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
+                store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
                 break;
             default:
                 break;
@@ -569,43 +607,41 @@ static inline void kern_gemm_s8s8s32_avx2_4x16x2_remain_m_n(
         remain_n -= 8;
         if (remain_n > 0) {
-            __m256i mask = _m256_continue_mask(remain_n);
-            _mm256_maskstore_epi32((c_ptr + 8), mask, c_vec[1]);
+            store_overflow<CType>(c_ptr + 8, c_vec[1], remain_n);
             switch (remain_m) {
                 case 2:
-                    _mm256_maskstore_epi32((c_ptr + ldc + 8), mask, c_vec[3]);
+                    store_overflow<CType>(c_ptr + ldc + 8, c_vec[3], remain_n);
                     break;
                 case 3:
-                    _mm256_maskstore_epi32((c_ptr + ldc + 8), mask, c_vec[3]);
-                    _mm256_maskstore_epi32((c_ptr + 2 * ldc + 8), mask,
-                                           c_vec[5]);
+                    store_overflow<CType>(c_ptr + ldc + 8, c_vec[3], remain_n);
+                    store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5],
+                                          remain_n);
                     break;
                 case 4:
-                    _mm256_maskstore_epi32((c_ptr + ldc + 8), mask, c_vec[3]);
-                    _mm256_maskstore_epi32((c_ptr + 2 * ldc + 8), mask,
-                                           c_vec[5]);
-                    _mm256_maskstore_epi32((c_ptr + 3 * ldc + 8), mask,
-                                           c_vec[7]);
+                    store_overflow<CType>(c_ptr + ldc + 8, c_vec[3], remain_n);
+                    store_overflow<CType>(c_ptr + 2 * ldc + 8, c_vec[5],
+                                          remain_n);
+                    store_overflow<CType>(c_ptr + 3 * ldc + 8, c_vec[7],
+                                          remain_n);
                     break;
                 default:
                     break;
             }
         }
     } else {
-        __m256i mask = _m256_continue_mask(remain_n);
-        _mm256_maskstore_epi32((c_ptr), mask, c_vec[0]);
+        store_overflow<CType>(c_ptr, c_vec[0], remain_n);
         switch (remain_m) {
             case 2:
-                _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]);
+                store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
                 break;
            case 3:
-                _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]);
-                _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]);
+                store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
+                store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
                 break;
            case 4:
-                _mm256_maskstore_epi32((c_ptr + ldc), mask, c_vec[2]);
-                _mm256_maskstore_epi32((c_ptr + 2 * ldc), mask, c_vec[4]);
-                _mm256_maskstore_epi32((c_ptr + 3 * ldc), mask, c_vec[6]);
+                store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
+                store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
+                store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
                 break;
            default:
                 break;
@@ -833,4 +869,5 @@ static inline void gemm_s8s8s32_avx2_4x16x2_pack_at(dt_int16* out,
 } // namespace x86
 } // namespace megdnn
 
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "src/fallback/matrix_mul/gemm_common.h"
@@ -29,6 +30,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
                                           4, 16, 2, false, false,
                                           gemm_avx2_s8s8s32_4x16x2);
 
+MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int32,
+                                          4, 16, 2, false, false,
+                                          gemm_avx2_s8s8s16_4x16x2);
+
 MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
                                           4, 8, 2, false, false,
                                           gemm_sse_s8s8s32_4x8x2);
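Note (annotation, not part of the patch): in `MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE` the type arguments appear to read as <A type, packed-A type, C type, compute type>, followed by the m/n/k block sizes, so the new registration declares a strategy that widens int8 A to int16 while packing, computes in int32, and emits int16, on a 4x16 output tile stepping 2 along K. Judging from the implementations in strategy.cpp above, the interface the macro declares looks roughly like this sketch (not the real macro expansion):

    class gemm_avx2_s8s8s16_4x16x2_sketch {
    public:
        void pack_A(dt_int16* out, const dt_int8* in, int ldin, int y0,
                    int ymax, int k0, int kmax, bool transpose) const;
        void pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
                    int xmax, int k0, int kmax, bool transpose) const;
        void kern(const dt_int16* pack_a, const dt_int8* pack_b, size_t m,
                  size_t n, size_t k, dt_int16* c, size_t ldc, bool is_first_k,
                  const dt_int32*, dt_int32*) const;
    };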
@@ -37,6 +37,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoInt8x8x32AVX2M4N16K2 algoint8x8x32avx2_m4n16k2;
     AlgoInt8x8x32AVX2M2N4K16 algoint8x8x32avx2_m2n4k16;
     AlgoInt8x8x32SSEM4N8K2 algoint8x8x32sse_m4n8k2;
+    AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
     AlgoF32MK8_8x8 algof32mk8_8x8;
 
 public:
@@ -47,6 +48,7 @@ public:
 #endif
         }
         all_algos.emplace_back(&algoint8x8x32avx2_m4n16k2);
+        all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2);
        all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16);
         all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
         all_algos.emplace_back(&algof32mk8_8x8);
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
@@ -54,6 +55,7 @@ protected:
     class AlgoInt8x8x32AVX2M2N4K16;
     class AlgoInt8x8x32AVX2M4N16K2;
     class AlgoInt8x8x32SSEM4N8K2;
+    class AlgoInt8x8x16AVX2;
     class AlgoPack;
     class AlgoF32MK8_8x8;
 };
@@ -752,7 +752,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_DIRECT_STRIDE2) {
     }
 }
 
-TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
+TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X) {
     using namespace conv_bias;
     std::vector<TestArg> args;
 
@@ -807,6 +807,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
                 .set_param(arg.param)                                          \
                 .execs({arg.src, arg.filter, {}, {}, {}});                     \
     }
+#define cb2(algo_name)                                                         \
+    checker.set_before_exec_callback(                                          \
+            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));              \
+    checker.set_dtype(0, dtype::Int8());                                       \
+    checker.set_dtype(1, dtype::Int8());                                       \
+    checker.set_dtype(2, dtype::Int16());                                      \
+    checker.set_dtype(4, dtype::Int16());                                      \
+    for (auto&& arg : args) {                                                  \
+        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); \
+    }
 
 #if MEGDNN_X86_WITH_MKL_DNN
     if (megdnn::x86::is_supported(x86::SIMDType::VNNI)) {
@@ -821,12 +831,14 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
     if (megdnn::x86::is_supported(x86::SIMDType::AVX2)) {
         cb("IM2COLMATMUL:X86_INT8X8X32_AVX2_2X4X16");
         cb("IM2COLMATMUL:X86_INT8X8X32_AVX2_4X16X2");
+        cb2("IM2COLMATMUL:X86_INT8X8X16_AVX2");
     }
     if (::megdnn::x86::is_supported(::megdnn::x86::SIMDType::SSE4_2)) {
         cb("IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2");
     }
 
 #undef cb
+#undef cb2
 }
 
 TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) {
@@ -1964,6 +1976,39 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) {
     shapes_and_computation.clear();
 }
 
+TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_8816) {
+    constexpr size_t RUNS = 30;
+    param::ConvBias param;
+    param.stride_h = 1;
+    param.stride_w = 1;
+    param.sparse = param::ConvBias::Sparse::DENSE;
+    std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(),
+                                    dtype::Int16(), dtype::Int16()};
+    std::vector<std::pair<SmallVector<TensorShape>, float>>
+            shapes_and_computation;
+    auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
+                          size_t FS) {
+        param.pad_h = FS / 2;
+        param.pad_w = FS / 2;
+        SmallVector<TensorShape> shapes{
+                {N, IC, H, W}, {OC, IC, FS, FS}, {}, {}, {}};
+        TensorShape dst{N, OC, (H + 2 * param.pad_h - FS) / param.stride_h + 1,
+                        (W + 2 * param.pad_w - FS) / param.stride_w + 1};
+        float computations = (IC * FS * FS * dst.total_nr_elems() * 2) * 1e-6;
+        shapes_and_computation.push_back(std::make_pair(shapes, computations));
+    };
+    bench_case(1, 48, 192, 15, 15, 1);
+
+    std::string algo_name = "IM2COLMATMUL:X86_INT8X8X16_AVX2";
+    benchmark_impl(param, shapes_and_computation, algo_name, RUNS,
+                   {4, {4, 5, 6, 7}}, {1, {4}}, data_type);
+    shapes_and_computation.clear();
+}
+
 TEST_F(X86_BENCHMARK_MULTI_THREADS,
        BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8_STRIDE2) {
     constexpr size_t RUNS = 50;
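Sanity check (annotation, not part of the patch) on the `computations` value fed to `benchmark_impl`, worked through for the single `bench_case(1, 48, 192, 15, 15, 1)` above; the 1e-6 factor expresses the count in units of 10^6 operations:

    #include <cstdio>

    int main() {
        // FS = 1 so pad = 0 and the output keeps the 15x15 spatial size:
        // dst = {1, 192, 15, 15} -> 43200 elements.
        const double dst_elems = 1.0 * 192 * 15 * 15;
        // IC * FS * FS * dst_elems * 2, scaled by 1e-6 as in the test.
        const double computations = 48 * 1 * 1 * dst_elems * 2 * 1e-6;
        std::printf("computations = %.4f\n", computations);  // 4.1472
        return 0;
    }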
@@ -1985,7 +2030,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS,
         SmallVector<TensorShape> shapes{
                 {N, IC, H, W}, {OC, IC, FS, FS}, {}, {}, {}};
         TensorShape dst{N, OC, (H + 2 * param.pad_h - FS) / param.stride_h + 1,
-                        (W + 2 * param.pad_w - FS) / param.pad_w + 1};
+                        (W + 2 * param.pad_w - FS) / param.stride_w + 1};
         float computations = (IC * FS * FS * dst.total_nr_elems() * 2) * 1e-6;
         shapes_and_computation.push_back(std::make_pair(shapes, computations));
     };
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "test/x86/fixture.h"
@@ -369,6 +370,63 @@ TEST_F(X86, CONVOLUTION_DIRECT_MKLDNN_C8) {
 #endif
 
 #if MEGDNN_WITH_BENCHMARK
+TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x16) {
+    using namespace convolution;
+    using Param = param::Convolution;
+    std::vector<TestArg> args;
+    auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
+                   size_t stride, size_t group = 1) {
+        Param param;
+        param.stride_h = stride;
+        param.stride_w = stride;
+        param.pad_h = kernel / 2;
+        param.pad_w = kernel / 2;
+        if (group > 1) {
+            param.sparse = param::Convolution::Sparse::GROUP;
+            args.emplace_back(
+                    param, TensorShape{1, ic, h, w},
+                    TensorShape{group, oc / group, ic / group, kernel, kernel});
+        } else {
+            param.sparse = param::Convolution::Sparse::DENSE;
+            args.emplace_back(param, TensorShape{1, ic, h, w},
+                              TensorShape{oc, ic, kernel, kernel});
+        }
+    };
+
+    run(48, 96, 15, 15, 1, 1);
+    run(64, 64, 60, 60, 3, 1);
+    run(64, 64, 60, 60, 3, 1, 64);
+
+    constexpr size_t RUN = 30;
+    Benchmarker<Convolution> benchmark(handle());
+    benchmark.set_dtype(0, dtype::Int8())
+            .set_dtype(1, dtype::Int8())
+            .set_dtype(2, dtype::Int16());
+    benchmark.set_display(false);
+    benchmark.set_times(RUN);
+
+    for (auto&& arg : args) {
+        TensorLayout dst_layout;
+        auto opr = handle()->create_operator<Convolution>();
+        opr->param() = arg.param;
+        opr->deduce_layout({arg.src, dtype::Float32()},
+                           {arg.filter, dtype::Float32()}, dst_layout);
+        //! dst.nr_elems * IC * FH * FW * 2
+        float icpg = arg.filter.ndim == 4 ? arg.filter[1] : arg.filter[2];
+        float filter = arg.filter.ndim == 4 ? arg.filter[2] : arg.filter[3];
+        float computations = dst_layout.total_nr_elems() * icpg * filter *
+                             filter * 2.0 / (1024 * 1024 * 1024) * 1e3;
+        auto used_int =
+                benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) /
+                RUN;
+        printf("%s %s: int: %f ms %f Gflops \n", arg.src.to_string().c_str(),
+               arg.filter.to_string().c_str(), used_int,
+               computations / used_int);
+    }
+}
+
 #if MEGDNN_X86_WITH_MKL_DNN
 TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN) {
     using namespace convolution;
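The `computations` expression in the new benchmark is pre-scaled so that dividing by a runtime in milliseconds yields Gflops. Worked numbers for `run(48, 96, 15, 15, 1, 1)` (annotation, illustrative only):

    #include <cstdio>

    int main() {
        // dst = {1, 48, 15, 15} -> 10800 elements; icpg = 96, filter = 1.
        const double ops = 10800.0 * 96 * 1 * 1 * 2;  // 2073600 flops
        const double computations = ops / (1024 * 1024 * 1024) * 1e3;
        std::printf("%.3f\n", computations);  // ~1.931; / used_ms -> Gflops
        return 0;
    }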
@@ -419,7 +477,6 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x32_MKLDNN) {
         float computations = dst_layout.total_nr_elems() * arg.filter[1] *
                              arg.filter[2] * arg.filter[3] * 2.0 /
                              (1024 * 1024 * 1024) * 1e3;
-
         auto used_int =
                 benchmark.set_param(arg.param).exec({arg.src, arg.filter, {}}) /
                 RUN;
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #include "test/x86/fixture.h"
@@ -47,6 +48,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                  handle(), "X86_INT8X8X32_AVX2_4X16X2");
 }
+TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
+    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
+                                 handle(), "X86_INT8X8X16_AVX2");
+}
 TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                  handle(), "X86_INT8X8X32_SSE_4X8X2");
@@ -116,6 +121,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
     benchmarker_avx2_4x16x2.set_before_exec_callback(
             AlgoChecker<MatrixMul>("X86_INT8X8X32_AVX2_4X16X2"));
 
+    Benchmarker<MatrixMul> benchmarker_avx2_4x16x2_8816(handle());
+    benchmarker_avx2_4x16x2_8816.set_display(false)
+            .set_times(RUNS)
+            .set_dtype(0, dtype::Int8{})
+            .set_dtype(1, dtype::Int8{})
+            .set_dtype(2, dtype::Int16{})
+            .set_rng(0, rng.get())
+            .set_rng(1, rng.get());
+    benchmarker_avx2_4x16x2_8816.set_before_exec_callback(
+            AlgoChecker<MatrixMul>("X86_INT8X8X16_AVX2"));
+
     Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle());
     benchmarker_avx2_2x4x16.set_display(false)
             .set_times(RUNS)
@@ -183,6 +199,12 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
                   << "k2_speed_up " << float_used / avx2_used_4x16x2
                   << ", k16_speed_up " << float_used / avx2_used_2x4x16
                   << ",";
+        auto avx2_used_4x16x2_8816 =
+                benchmarker_avx2_4x16x2_8816.exec({{M, K}, {K, N}, {}}) /
+                RUNS;
+        std::cout << "avx2_8816: " << avx2_used_4x16x2_8816
+                  << " ms, 8816 throughput "
+                  << computations / avx2_used_4x16x2_8816 << " Gflops,";
     }
     if (is_supported(SIMDType::SSE4_1)) {
         auto sse_used =