GitOrigin-RevId: 288792de42
tags/v0.5.0
@@ -6,16 +6,18 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/x86/conv_bias/int8/algos.h"
 #include "src/common/opr_delegate.h"
 #include "src/common/utils.h"
 #include "src/fallback/convolution/img2col_helper.h"
+#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h"
+#include "src/x86/conv_bias/int8/avx2_chanwise_stride2.h"
 #include "src/x86/conv_bias/int8/avx2_direct_conv_stride1.h"
 #include "src/x86/conv_bias/int8/avx2_direct_conv_stride2.h"
-#include "src/x86/conv_bias/int8/avx2_chanwise_stride1.h"
 #include "src/x86/conv_bias/opr_impl.h"
 #include "src/x86/conv_bias/postprocess_helper.h"
 #include "src/x86/handle.h"
@@ -38,6 +40,7 @@ bool ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::usable(
     auto&& fm = param.filter_meta;
     auto FH = fm.spatial[0];
     bool aviliable =
+            (param.bias_mode != BiasMode::BIAS) &&
             ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
               param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
               param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
@@ -61,12 +64,12 @@ WorkspaceBundle ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_bundle(
     size_t IH2, IW2, OH2, OW2;
     size_t src_size = 0, dst_size = 0, int32_temp = 0;
-    avx2_chanwise_stride1::get_rectified_size(param, IH2, IW2, OH2, OW2);
+    get_rectified_size(param, IH2, IW2, OH2, OW2);
-    if (avx2_chanwise_stride1::need_src_copy(param)) {
+    if (need_src_copy(param)) {
         src_size = IH2 * IW2 * sizeof(int8_t) * nr_threads;
     }
-    if (avx2_chanwise_stride1::need_dst_copy(param)) {
+    if (need_dst_copy(param)) {
         dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads;
     }
     bool dst_need_convert = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
@@ -91,6 +94,66 @@ ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::get_kimpls(
     return avx2_chanwise_stride1::get_kimpls(param, bundle);
 }
+
+bool ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::usable(
+        FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
+        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
+    auto&& fm = param.filter_meta;
+    auto FH = fm.spatial[0];
+    bool available =
+            (param.bias_mode != BiasMode::BIAS) &&
+            ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
+              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
+              param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
+             (((param.src_type.enumv() == DTypeEnum::Int8 &&
+                param.filter_type.enumv() == DTypeEnum::Int8 &&
+                param.dst_type.enumv() == DTypeEnum::Int32) ||
+               (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
+                param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
+                param.dst_type.enumv() == DTypeEnum::QuantizedS32)))) &&
+            fm.format == Param::Format::NCHW && fm.spatial_ndim == 2 &&
+            fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
+            (FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.stride[0] == 2 &&
+            fm.stride[1] == 2 && (fm.icpg == 1) && (fm.ocpg == 1) &&
+            is_supported(SIMDType::AVX2);
+    return available;
+}
+
+WorkspaceBundle ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::get_bundle(
+        const NCBKernSizeParam& param) {
+    size_t nr_threads = param.nr_threads;
+    size_t IH2, IW2, OH2, OW2;
+    size_t src_size = 0, dst_size = 0, int32_temp = 0;
+    get_rectified_size(param, IH2, IW2, OH2, OW2);
+    if (need_src_copy(param)) {
+        src_size = IH2 * IW2 * sizeof(int8_t) * nr_threads;
+    }
+    if (need_dst_copy(param)) {
+        dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads;
+    }
+    bool dst_need_convert = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
+    if (dst_need_convert) {
+        int32_temp = OH2 * OW2 * sizeof(int32_t) * nr_threads;
+    }
+    return dst_need_convert
+                   ? WorkspaceBundle(nullptr, {src_size, dst_size, int32_temp})
+                   : WorkspaceBundle(nullptr, {src_size, dst_size});
+}
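+//! Workspace layout, one slice per worker thread: slot 0 holds the
+//! zero-padded source copy, slot 1 the padded destination rows, and slot 2
+//! the int32 accumulators that are only needed when the result must be
+//! re-quantized to QuantizedS8.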
+
+size_t ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::get_workspace(
+        FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
+    return get_bundle(param).total_size_in_bytes();
+}
+
+SmallVector<fallback::ConvBiasImpl::NCBKern>
+ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::get_kimpls(
+        const NCBKernSizeParam& param) const {
+    auto bundle = get_bundle(param);
+    return avx2_chanwise_stride2::get_kimpls(param, bundle);
+}
+
 bool ConvBiasImpl::AlgoDirectAvx2Stride1Int8::usable(
         FallbackConvBiasImpl* /*opr*/, const NCBKernSizeParam& param,
         AlgoSelectionStrategy /*algo_selection_strategy*/) const {
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "src/x86/conv_bias/opr_impl.h"
@@ -36,6 +37,28 @@ public:
     void* type() const override;
 };

+/* ===================== avx2 stride2 chanwise algo ===================== */
+class ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8 final : public AlgoBase {
+    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
+    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
+
+public:
+    bool is_reproducible() const override { return true; }
+    const char* name() const override {
+        return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2";
+    }
+    bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
+                AlgoSelectionStrategy algo_selection_strategy) const override;
+    size_t get_workspace(FallbackConvBiasImpl* opr,
+                         const NCBKernSizeParam& param) const override;
+    virtual SmallVector<NCBKern> dispatch_kerns(
+            fallback::ConvBiasImpl*,
+            const NCBKernSizeParam& param) const override {
+        return get_kimpls(param);
+    }
+    void* type() const override;
+};
+
 /* ===================== avx2 stride1 direct algo ===================== */
 class ConvBiasImpl::AlgoDirectAvx2Stride1Int8 final : public AlgoBase {
     SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
@@ -125,7 +148,7 @@ public:
     void* type() const override;
 };
 #endif

-/* ===================== avx2 int8 direct conv stride2 algo ===================== */
+/* ================== avx2 int8 direct conv stride2 algo ================== */
 class ConvBiasImpl::AlgoAVX2DirectConvStride2 final : public AlgoBase {
     SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
     static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
@@ -21,8 +21,6 @@
 namespace megdnn {
 namespace x86 {

-namespace avx2_chanwise_stride1 {
-
 #define load_filter(i) __m128i k_##i = _mm_set1_epi8(*(filter + i));
 #define load_src0(i) \
     __m256i cvt16_src##i##0 = _mm256_cvtepi8_epi16_from_ptr(r##i);
@@ -40,6 +38,15 @@ namespace avx2_chanwise_stride1 {
     __m256i cvt16_src##i##6 = _mm256_cvtepi8_epi16_from_ptr(r##i + 6);
 #define load_src7(i) \
     __m256i cvt16_src##i##7 = _mm256_cvtepi8_epi16_from_ptr(r##i + 7);
+#define load_src16(i) \
+    __m256i cvt16_src##i##16 = _mm256_cvtepi8_epi16_from_ptr(r##i + 16);
+#define load_src18(i) \
+    __m256i cvt16_src##i##18 = _mm256_cvtepi8_epi16_from_ptr(r##i + 18);
+#define load_src20(i) \
+    __m256i cvt16_src##i##20 = _mm256_cvtepi8_epi16_from_ptr(r##i + 20);
+#define load_src22(i) \
+    __m256i cvt16_src##i##22 = _mm256_cvtepi8_epi16_from_ptr(r##i + 22);
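+//! These load_* macros are shared by the stride1 and stride2 kernels in this
+//! file (they are #undef'd at the very end); load_src16/18/20/22 fetch the
+//! right half of the 32-byte input window that a stride-2 kernel consumes
+//! per iteration.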
+
+namespace avx2_chanwise_stride1 {

 template <BiasMode bias_mode, bool is_quantized, typename Op>
 void avx2_chanwise_direct_stride1_2x2_int8(const int8_t* src,
@@ -1534,16 +1541,6 @@ void avx2_chanwise_direct_stride1_7x7_int8(const int8_t* src,
         r6 += tail_step;
     }
 }

-#undef load_filter
-#undef load_src0
-#undef load_src1
-#undef load_src2
-#undef load_src3
-#undef load_src4
-#undef load_src5
-#undef load_src6
-#undef load_src7
-
 #define INSTANTIATION(stride, i, bias, is_quantized, Op)            \
     template void avx2_chanwise_direct_##stride##_##i##x##i##_int8< \
             bias, is_quantized, Op>(const int8_t*, const int8_t*,   \
@@ -1587,6 +1584,697 @@ FOR_STRIDE
 #undef FOR_OP
 #undef INSTANTIATION

 }  // namespace avx2_chanwise_stride1
+
+namespace avx2_chanwise_stride2 {
+
+template <BiasMode bias_mode, bool is_quantized, typename Op>
+void avx2_chanwise_direct_stride2_2x2_int8(const int8_t* src,
+                                           const int8_t* filter,
+                                           const int32_t* bias, int32_t* temp,
+                                           int8_t* dst, const size_t IH,
+                                           const size_t IW, const size_t OH,
+                                           const size_t OW, const Op& op) {
+    size_t tail_step = IW - OW * 2;
+    int8_t* dst0 = dst;
+    int32_t* out_ptr0 = temp;
+    const int8_t* r0 = src;
+    const int8_t* r1 = src + IW;
+
+    UNROLL_CALL0(4, load_filter)
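+    //! Each 2x2 filter tap is broadcast to a vector, then adjacent taps are
+    //! interleaved so every 16-bit lane pair holds (k_even, k_odd);
+    //! _mm256_madd_epi16 multiplies the widened (src_even, src_odd) pairs
+    //! and sums each pair of products into a single int32 lane.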
+#define pack_filter(i, j) __m128i k_##i##j = _mm_unpacklo_epi8(k_##i, k_##j)
+    pack_filter(0, 1);
+    pack_filter(2, 3);
+
+    __m256i bias_val;
+    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
+        bias_val = _mm256_set1_epi32(*(bias));
+    } else {
+        bias_val = _mm256_set1_epi32(0);
+    }
+
+#define cvt_filter(i, j) __m256i filter_##i##j = _mm256_cvtepi8_epi16(k_##i##j)
+    cvt_filter(0, 1);
+    cvt_filter(2, 3);
+
+    size_t width = OW >> 4;
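+    //! Every inner iteration consumes 32 source bytes per row and produces
+    //! 16 outputs; OW here is the rectified width, already rounded up to a
+    //! multiple of 16 by get_rectified_size, so no scalar tail is needed.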
+    for (size_t h = 0; h < OH; h++) {
+        for (size_t w = 0; w < width; w++) {
+            UNROLL_CALL0(2, load_src0)
+            UNROLL_CALL0(2, load_src16)
+            __m256i t0_left, t0_right, t1_left, t1_right, sum_left, sum_right;
+            t0_left = _mm256_madd_epi16(cvt16_src00, filter_01);
+            t0_right = _mm256_madd_epi16(cvt16_src016, filter_01);
+            t1_left = _mm256_madd_epi16(cvt16_src10, filter_23);
+            t1_right = _mm256_madd_epi16(cvt16_src116, filter_23);
+            sum_left = _mm256_add_epi32(t0_left, t1_left);
+            sum_right = _mm256_add_epi32(t0_right, t1_right);
+            sum_left = _mm256_add_epi32(sum_left, bias_val);
+            sum_right = _mm256_add_epi32(sum_right, bias_val);
+            if (is_quantized) {
+                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
+            } else {
+                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
+                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
+            }
+            r0 += 32;
+            r1 += 32;
+            dst0 += 16;
+            out_ptr0 += 16;
+        }
+        r0 += tail_step + IW;
+        r1 += tail_step + IW;
+    }
+    MEGDNN_MARK_USED_VAR(IH);
+#undef pack_filter
+#undef cvt_filter
+}
+
+template <BiasMode bias_mode, bool is_quantized, typename Op>
+void avx2_chanwise_direct_stride2_3x3_int8(const int8_t* src,
+                                           const int8_t* filter,
+                                           const int32_t* bias, int32_t* temp,
+                                           int8_t* dst, const size_t IH,
+                                           const size_t IW, const size_t OH,
+                                           const size_t OW, const Op& op) {
+    MEGDNN_MARK_USED_VAR(IH);
+    size_t tail_step = IW - OW * 2;
+    int32_t* out_ptr0 = temp;
+    int8_t* dst0 = dst;
+    const int8_t* r0 = src;
+    const int8_t* r1 = src + IW;
+    const int8_t* r2 = src + 2 * IW;
+
+    uint8_t fill_zero = 0;
+    UNROLL_CALL0(9, load_filter)
+    __m128i k_fill = _mm_set1_epi8(fill_zero);
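+    //! A 3x3 row has an odd tap count, so the trailing tap of each row is
+    //! paired with a zero lane (k20, k50, k80) to preserve the pairwise
+    //! structure that _mm256_madd_epi16 expects.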
+    __m128i k01 = _mm_unpacklo_epi8(k_0, k_1);
+    __m128i k20 = _mm_unpacklo_epi8(k_2, k_fill);
+    __m128i k34 = _mm_unpacklo_epi8(k_3, k_4);
+    __m128i k50 = _mm_unpacklo_epi8(k_5, k_fill);
+    __m128i k67 = _mm_unpacklo_epi8(k_6, k_7);
+    __m128i k80 = _mm_unpacklo_epi8(k_8, k_fill);
+
+    __m256i bias_val;
+    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
+        bias_val = _mm256_set1_epi32(*(bias));
+    } else {
+        bias_val = _mm256_set1_epi32(0);
+    }
+
+    //! cvt i8 --> i16
+    __m256i filter_01 = _mm256_cvtepi8_epi16(k01);
+    __m256i filter_20 = _mm256_cvtepi8_epi16(k20);
+    __m256i filter_34 = _mm256_cvtepi8_epi16(k34);
+    __m256i filter_50 = _mm256_cvtepi8_epi16(k50);
+    __m256i filter_67 = _mm256_cvtepi8_epi16(k67);
+    __m256i filter_80 = _mm256_cvtepi8_epi16(k80);
+
+    size_t width = OW >> 4;
+    for (size_t h = 0; h < OH; h++) {
+        for (size_t w = 0; w < width; w++) {
+            UNROLL_CALL0(3, load_src0)
+            UNROLL_CALL0(3, load_src2)
+            UNROLL_CALL0(3, load_src16)
+            UNROLL_CALL0(3, load_src18)
+            __m256i temp, t0_left, t0_right, t1_left, t1_right, t2_left,
+                    t2_right, sum_left, sum_right;
+            t0_left = _mm256_madd_epi16(cvt16_src00, filter_01);
+            temp = _mm256_madd_epi16(cvt16_src02, filter_20);
+            t0_left = _mm256_add_epi32(t0_left, temp);
+            t0_right = _mm256_madd_epi16(cvt16_src016, filter_01);
+            temp = _mm256_madd_epi16(cvt16_src018, filter_20);
+            t0_right = _mm256_add_epi32(t0_right, temp);
+            t1_left = _mm256_madd_epi16(cvt16_src10, filter_34);
+            temp = _mm256_madd_epi16(cvt16_src12, filter_50);
+            t1_left = _mm256_add_epi32(t1_left, temp);
+            t1_right = _mm256_madd_epi16(cvt16_src116, filter_34);
+            temp = _mm256_madd_epi16(cvt16_src118, filter_50);
+            t1_right = _mm256_add_epi32(t1_right, temp);
+            t2_left = _mm256_madd_epi16(cvt16_src20, filter_67);
+            temp = _mm256_madd_epi16(cvt16_src22, filter_80);
+            t2_left = _mm256_add_epi32(t2_left, temp);
+            t2_right = _mm256_madd_epi16(cvt16_src216, filter_67);
+            temp = _mm256_madd_epi16(cvt16_src218, filter_80);
+            t2_right = _mm256_add_epi32(t2_right, temp);
+            sum_left = _mm256_add_epi32(t0_left, t1_left);
+            sum_left = _mm256_add_epi32(sum_left, t2_left);
+            sum_right = _mm256_add_epi32(t0_right, t1_right);
+            sum_right = _mm256_add_epi32(sum_right, t2_right);
+            sum_left = _mm256_add_epi32(sum_left, bias_val);
+            sum_right = _mm256_add_epi32(sum_right, bias_val);
+            if (is_quantized) {
+                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
+            } else {
+                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
+                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
+            }
+            r0 += 32;
+            r1 += 32;
+            r2 += 32;
+            dst0 += 16;
+            out_ptr0 += 16;
+        }
+        r0 += tail_step + IW;
+        r1 += tail_step + IW;
+        r2 += tail_step + IW;
+    }
+}
+
+template <BiasMode bias_mode, bool is_quantized, typename Op>
+void avx2_chanwise_direct_stride2_5x5_int8(const int8_t* src,
+                                           const int8_t* filter,
+                                           const int32_t* bias, int32_t* temp,
+                                           int8_t* dst, const size_t IH,
+                                           const size_t IW, const size_t OH,
+                                           const size_t OW, const Op& op) {
+    MEGDNN_MARK_USED_VAR(IH);
+    size_t tail_step = IW - OW * 2;
+    int8_t* dst0 = dst;
+    int32_t* out_ptr0 = temp;
+    const int8_t* r0 = src;
+    const int8_t* r1 = src + IW;
+    const int8_t* r2 = src + 2 * IW;
+    const int8_t* r3 = src + 3 * IW;
+    const int8_t* r4 = src + 4 * IW;
+
+    uint8_t fill_zero = 0;
+    UNROLL_CALL0(25, load_filter)
+    __m128i k_fill = _mm_set1_epi8(fill_zero);
+    __m128i k01 = _mm_unpacklo_epi8(k_0, k_1);
+    __m128i k23 = _mm_unpacklo_epi8(k_2, k_3);
+    __m128i k40 = _mm_unpacklo_epi8(k_4, k_fill);
+    __m128i k56 = _mm_unpacklo_epi8(k_5, k_6);
+    __m128i k78 = _mm_unpacklo_epi8(k_7, k_8);
+    __m128i k90 = _mm_unpacklo_epi8(k_9, k_fill);
+    __m128i k1011 = _mm_unpacklo_epi8(k_10, k_11);
+    __m128i k1213 = _mm_unpacklo_epi8(k_12, k_13);
+    __m128i k140 = _mm_unpacklo_epi8(k_14, k_fill);
+    __m128i k1516 = _mm_unpacklo_epi8(k_15, k_16);
+    __m128i k1718 = _mm_unpacklo_epi8(k_17, k_18);
+    __m128i k190 = _mm_unpacklo_epi8(k_19, k_fill);
+    __m128i k2021 = _mm_unpacklo_epi8(k_20, k_21);
+    __m128i k2223 = _mm_unpacklo_epi8(k_22, k_23);
+    __m128i k240 = _mm_unpacklo_epi8(k_24, k_fill);
+
+    __m256i bias_val;
+    //! load bias
+    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
+        bias_val = _mm256_set1_epi32(*(bias));
+    } else {
+        bias_val = _mm256_set1_epi32(0);
+    }
+
+    //! cvt i8 --> i16
+    __m256i filter_01 = _mm256_cvtepi8_epi16(k01);
+    __m256i filter_23 = _mm256_cvtepi8_epi16(k23);
+    __m256i filter_40 = _mm256_cvtepi8_epi16(k40);
+    __m256i filter_56 = _mm256_cvtepi8_epi16(k56);
+    __m256i filter_78 = _mm256_cvtepi8_epi16(k78);
+    __m256i filter_90 = _mm256_cvtepi8_epi16(k90);
+    __m256i filter_1011 = _mm256_cvtepi8_epi16(k1011);
+    __m256i filter_1213 = _mm256_cvtepi8_epi16(k1213);
+    __m256i filter_140 = _mm256_cvtepi8_epi16(k140);
+    __m256i filter_1516 = _mm256_cvtepi8_epi16(k1516);
+    __m256i filter_1718 = _mm256_cvtepi8_epi16(k1718);
+    __m256i filter_190 = _mm256_cvtepi8_epi16(k190);
+    __m256i filter_2021 = _mm256_cvtepi8_epi16(k2021);
+    __m256i filter_2223 = _mm256_cvtepi8_epi16(k2223);
+    __m256i filter_240 = _mm256_cvtepi8_epi16(k240);
+
+    size_t width = OW >> 4;
+    for (size_t h = 0; h < OH; h++) {
+        for (size_t w = 0; w < width; w++) {
+            UNROLL_CALL0(5, load_src0)
+            UNROLL_CALL0(5, load_src2)
+            UNROLL_CALL0(5, load_src4)
+            UNROLL_CALL0(5, load_src16)
+            UNROLL_CALL0(5, load_src18)
+            UNROLL_CALL0(5, load_src20)
+            __m256i temp, t0_left, t0_right, t1_left, t1_right, t2_left,
+                    t2_right, t3_left, t3_right, t4_left, t4_right, sum_left,
+                    sum_right;
+            t0_left = _mm256_madd_epi16(cvt16_src00, filter_01);
+            temp = _mm256_madd_epi16(cvt16_src02, filter_23);
+            t0_left = _mm256_add_epi32(t0_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src04, filter_40);
+            t0_left = _mm256_add_epi32(t0_left, temp);
+            t0_right = _mm256_madd_epi16(cvt16_src016, filter_01);
+            temp = _mm256_madd_epi16(cvt16_src018, filter_23);
+            t0_right = _mm256_add_epi32(t0_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src020, filter_40);
+            t0_right = _mm256_add_epi32(t0_right, temp);
+            t1_left = _mm256_madd_epi16(cvt16_src10, filter_56);
+            temp = _mm256_madd_epi16(cvt16_src12, filter_78);
+            t1_left = _mm256_add_epi32(t1_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src14, filter_90);
+            t1_left = _mm256_add_epi32(t1_left, temp);
+            t1_right = _mm256_madd_epi16(cvt16_src116, filter_56);
+            temp = _mm256_madd_epi16(cvt16_src118, filter_78);
+            t1_right = _mm256_add_epi32(t1_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src120, filter_90);
+            t1_right = _mm256_add_epi32(t1_right, temp);
+            t2_left = _mm256_madd_epi16(cvt16_src20, filter_1011);
+            temp = _mm256_madd_epi16(cvt16_src22, filter_1213);
+            t2_left = _mm256_add_epi32(t2_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src24, filter_140);
+            t2_left = _mm256_add_epi32(t2_left, temp);
+            t2_right = _mm256_madd_epi16(cvt16_src216, filter_1011);
+            temp = _mm256_madd_epi16(cvt16_src218, filter_1213);
+            t2_right = _mm256_add_epi32(t2_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src220, filter_140);
+            t2_right = _mm256_add_epi32(t2_right, temp);
+            t3_left = _mm256_madd_epi16(cvt16_src30, filter_1516);
+            temp = _mm256_madd_epi16(cvt16_src32, filter_1718);
+            t3_left = _mm256_add_epi32(t3_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src34, filter_190);
+            t3_left = _mm256_add_epi32(t3_left, temp);
+            t3_right = _mm256_madd_epi16(cvt16_src316, filter_1516);
+            temp = _mm256_madd_epi16(cvt16_src318, filter_1718);
+            t3_right = _mm256_add_epi32(t3_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src320, filter_190);
+            t3_right = _mm256_add_epi32(t3_right, temp);
+            t4_left = _mm256_madd_epi16(cvt16_src40, filter_2021);
+            temp = _mm256_madd_epi16(cvt16_src42, filter_2223);
+            t4_left = _mm256_add_epi32(t4_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src44, filter_240);
+            t4_left = _mm256_add_epi32(t4_left, temp);
+            t4_right = _mm256_madd_epi16(cvt16_src416, filter_2021);
+            temp = _mm256_madd_epi16(cvt16_src418, filter_2223);
+            t4_right = _mm256_add_epi32(t4_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src420, filter_240);
+            t4_right = _mm256_add_epi32(t4_right, temp);
+            sum_left = _mm256_add_epi32(t0_left, t1_left);
+            sum_left = _mm256_add_epi32(sum_left, t2_left);
+            sum_left = _mm256_add_epi32(sum_left, t3_left);
+            sum_left = _mm256_add_epi32(sum_left, t4_left);
+            sum_right = _mm256_add_epi32(t0_right, t1_right);
+            sum_right = _mm256_add_epi32(sum_right, t2_right);
+            sum_right = _mm256_add_epi32(sum_right, t3_right);
+            sum_right = _mm256_add_epi32(sum_right, t4_right);
+            sum_left = _mm256_add_epi32(sum_left, bias_val);
+            sum_right = _mm256_add_epi32(sum_right, bias_val);
+            if (is_quantized) {
+                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
+            } else {
+                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
+                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
+            }
+            r0 += 32;
+            r1 += 32;
+            r2 += 32;
+            r3 += 32;
+            r4 += 32;
+            dst0 += 16;
+            out_ptr0 += 16;
+        }
+        r0 += tail_step + IW;
+        r1 += tail_step + IW;
+        r2 += tail_step + IW;
+        r3 += tail_step + IW;
+        r4 += tail_step + IW;
+    }
+}
+
+template <BiasMode bias_mode, bool is_quantized, typename Op>
+void avx2_chanwise_direct_stride2_7x7_int8(const int8_t* src,
+                                           const int8_t* filter,
+                                           const int32_t* bias, int32_t* temp,
+                                           int8_t* dst, const size_t IH,
+                                           const size_t IW, const size_t OH,
+                                           const size_t OW, const Op& op) {
+    MEGDNN_MARK_USED_VAR(IH);
+    size_t tail_step = IW - OW * 2;
+    int8_t* dst0 = dst;
+    int32_t* out_ptr0 = temp;
+    const int8_t* r0 = src;
+    const int8_t* r1 = src + IW;
+    const int8_t* r2 = src + 2 * IW;
+    const int8_t* r3 = src + 3 * IW;
+    const int8_t* r4 = src + 4 * IW;
+    const int8_t* r5 = src + 5 * IW;
+    const int8_t* r6 = src + 6 * IW;
+
+    uint8_t fill_zero = 0;
+    UNROLL_CALL0(49, load_filter)
+    __m128i k_fill = _mm_set1_epi8(fill_zero);
+    __m128i k01 = _mm_unpacklo_epi8(k_0, k_1);
+    __m128i k23 = _mm_unpacklo_epi8(k_2, k_3);
+    __m128i k45 = _mm_unpacklo_epi8(k_4, k_5);
+    __m128i k60 = _mm_unpacklo_epi8(k_6, k_fill);
+    __m128i k78 = _mm_unpacklo_epi8(k_7, k_8);
+    __m128i k910 = _mm_unpacklo_epi8(k_9, k_10);
+    __m128i k1112 = _mm_unpacklo_epi8(k_11, k_12);
+    __m128i k130 = _mm_unpacklo_epi8(k_13, k_fill);
+    __m128i k1415 = _mm_unpacklo_epi8(k_14, k_15);
+    __m128i k1617 = _mm_unpacklo_epi8(k_16, k_17);
+    __m128i k1819 = _mm_unpacklo_epi8(k_18, k_19);
+    __m128i k200 = _mm_unpacklo_epi8(k_20, k_fill);
+    __m128i k2122 = _mm_unpacklo_epi8(k_21, k_22);
+    __m128i k2324 = _mm_unpacklo_epi8(k_23, k_24);
+    __m128i k2526 = _mm_unpacklo_epi8(k_25, k_26);
+    __m128i k270 = _mm_unpacklo_epi8(k_27, k_fill);
+    __m128i k2829 = _mm_unpacklo_epi8(k_28, k_29);
+    __m128i k3031 = _mm_unpacklo_epi8(k_30, k_31);
+    __m128i k3233 = _mm_unpacklo_epi8(k_32, k_33);
+    __m128i k340 = _mm_unpacklo_epi8(k_34, k_fill);
+    __m128i k3536 = _mm_unpacklo_epi8(k_35, k_36);
+    __m128i k3738 = _mm_unpacklo_epi8(k_37, k_38);
+    __m128i k3940 = _mm_unpacklo_epi8(k_39, k_40);
+    __m128i k410 = _mm_unpacklo_epi8(k_41, k_fill);
+    __m128i k4243 = _mm_unpacklo_epi8(k_42, k_43);
+    __m128i k4445 = _mm_unpacklo_epi8(k_44, k_45);
+    __m128i k4647 = _mm_unpacklo_epi8(k_46, k_47);
+    __m128i k480 = _mm_unpacklo_epi8(k_48, k_fill);
+
+    __m256i bias_val;
+    //! load bias
+    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
+        bias_val = _mm256_set1_epi32(*(bias));
+    } else {
+        bias_val = _mm256_set1_epi32(0);
+    }
+
+    //! cvt i8 --> i16
+    __m256i filter_01 = _mm256_cvtepi8_epi16(k01);
+    __m256i filter_23 = _mm256_cvtepi8_epi16(k23);
+    __m256i filter_45 = _mm256_cvtepi8_epi16(k45);
+    __m256i filter_60 = _mm256_cvtepi8_epi16(k60);
+    __m256i filter_78 = _mm256_cvtepi8_epi16(k78);
+    __m256i filter_910 = _mm256_cvtepi8_epi16(k910);
+    __m256i filter_1112 = _mm256_cvtepi8_epi16(k1112);
+    __m256i filter_130 = _mm256_cvtepi8_epi16(k130);
+    __m256i filter_1415 = _mm256_cvtepi8_epi16(k1415);
+    __m256i filter_1617 = _mm256_cvtepi8_epi16(k1617);
+    __m256i filter_1819 = _mm256_cvtepi8_epi16(k1819);
+    __m256i filter_200 = _mm256_cvtepi8_epi16(k200);
+    __m256i filter_2122 = _mm256_cvtepi8_epi16(k2122);
+    __m256i filter_2324 = _mm256_cvtepi8_epi16(k2324);
+    __m256i filter_2526 = _mm256_cvtepi8_epi16(k2526);
+    __m256i filter_270 = _mm256_cvtepi8_epi16(k270);
+    __m256i filter_2829 = _mm256_cvtepi8_epi16(k2829);
+    __m256i filter_3031 = _mm256_cvtepi8_epi16(k3031);
+    __m256i filter_3233 = _mm256_cvtepi8_epi16(k3233);
+    __m256i filter_340 = _mm256_cvtepi8_epi16(k340);
+    __m256i filter_3536 = _mm256_cvtepi8_epi16(k3536);
+    __m256i filter_3738 = _mm256_cvtepi8_epi16(k3738);
+    __m256i filter_3940 = _mm256_cvtepi8_epi16(k3940);
+    __m256i filter_410 = _mm256_cvtepi8_epi16(k410);
+    __m256i filter_4243 = _mm256_cvtepi8_epi16(k4243);
+    __m256i filter_4445 = _mm256_cvtepi8_epi16(k4445);
+    __m256i filter_4647 = _mm256_cvtepi8_epi16(k4647);
+    __m256i filter_480 = _mm256_cvtepi8_epi16(k480);
+
+    size_t width = OW >> 4;
+    for (size_t h = 0; h < OH; h++) {
+        for (size_t w = 0; w < width; w++) {
+            UNROLL_CALL0(7, load_src0)
+            UNROLL_CALL0(7, load_src2)
+            UNROLL_CALL0(7, load_src4)
+            UNROLL_CALL0(7, load_src6)
+            UNROLL_CALL0(7, load_src16)
+            UNROLL_CALL0(7, load_src18)
+            UNROLL_CALL0(7, load_src20)
+            UNROLL_CALL0(7, load_src22)
+            __m256i temp, t0_left, t0_right, t1_left, t1_right, t2_left,
+                    t2_right, t3_left, t3_right, t4_left, t4_right, sum_left,
+                    t5_left, t5_right, t6_left, t6_right, sum_right;
+            t0_left = _mm256_madd_epi16(cvt16_src00, filter_01);
+            temp = _mm256_madd_epi16(cvt16_src02, filter_23);
+            t0_left = _mm256_add_epi32(t0_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src04, filter_45);
+            t0_left = _mm256_add_epi32(t0_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src06, filter_60);
+            t0_left = _mm256_add_epi32(t0_left, temp);
+            t0_right = _mm256_madd_epi16(cvt16_src016, filter_01);
+            temp = _mm256_madd_epi16(cvt16_src018, filter_23);
+            t0_right = _mm256_add_epi32(t0_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src020, filter_45);
+            t0_right = _mm256_add_epi32(t0_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src022, filter_60);
+            t0_right = _mm256_add_epi32(t0_right, temp);
+            t1_left = _mm256_madd_epi16(cvt16_src10, filter_78);
+            temp = _mm256_madd_epi16(cvt16_src12, filter_910);
+            t1_left = _mm256_add_epi32(t1_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src14, filter_1112);
+            t1_left = _mm256_add_epi32(t1_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src16, filter_130);
+            t1_left = _mm256_add_epi32(t1_left, temp);
+            t1_right = _mm256_madd_epi16(cvt16_src116, filter_78);
+            temp = _mm256_madd_epi16(cvt16_src118, filter_910);
+            t1_right = _mm256_add_epi32(t1_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src120, filter_1112);
+            t1_right = _mm256_add_epi32(t1_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src122, filter_130);
+            t1_right = _mm256_add_epi32(t1_right, temp);
+            t2_left = _mm256_madd_epi16(cvt16_src20, filter_1415);
+            temp = _mm256_madd_epi16(cvt16_src22, filter_1617);
+            t2_left = _mm256_add_epi32(t2_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src24, filter_1819);
+            t2_left = _mm256_add_epi32(t2_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src26, filter_200);
+            t2_left = _mm256_add_epi32(t2_left, temp);
+            t2_right = _mm256_madd_epi16(cvt16_src216, filter_1415);
+            temp = _mm256_madd_epi16(cvt16_src218, filter_1617);
+            t2_right = _mm256_add_epi32(t2_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src220, filter_1819);
+            t2_right = _mm256_add_epi32(t2_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src222, filter_200);
+            t2_right = _mm256_add_epi32(t2_right, temp);
+            t3_left = _mm256_madd_epi16(cvt16_src30, filter_2122);
+            temp = _mm256_madd_epi16(cvt16_src32, filter_2324);
+            t3_left = _mm256_add_epi32(t3_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src34, filter_2526);
+            t3_left = _mm256_add_epi32(t3_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src36, filter_270);
+            t3_left = _mm256_add_epi32(t3_left, temp);
+            t3_right = _mm256_madd_epi16(cvt16_src316, filter_2122);
+            temp = _mm256_madd_epi16(cvt16_src318, filter_2324);
+            t3_right = _mm256_add_epi32(t3_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src320, filter_2526);
+            t3_right = _mm256_add_epi32(t3_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src322, filter_270);
+            t3_right = _mm256_add_epi32(t3_right, temp);
+            t4_left = _mm256_madd_epi16(cvt16_src40, filter_2829);
+            temp = _mm256_madd_epi16(cvt16_src42, filter_3031);
+            t4_left = _mm256_add_epi32(t4_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src44, filter_3233);
+            t4_left = _mm256_add_epi32(t4_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src46, filter_340);
+            t4_left = _mm256_add_epi32(t4_left, temp);
+            t4_right = _mm256_madd_epi16(cvt16_src416, filter_2829);
+            temp = _mm256_madd_epi16(cvt16_src418, filter_3031);
+            t4_right = _mm256_add_epi32(t4_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src420, filter_3233);
+            t4_right = _mm256_add_epi32(t4_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src422, filter_340);
+            t4_right = _mm256_add_epi32(t4_right, temp);
+            t5_left = _mm256_madd_epi16(cvt16_src50, filter_3536);
+            temp = _mm256_madd_epi16(cvt16_src52, filter_3738);
+            t5_left = _mm256_add_epi32(t5_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src54, filter_3940);
+            t5_left = _mm256_add_epi32(t5_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src56, filter_410);
+            t5_left = _mm256_add_epi32(t5_left, temp);
+            t5_right = _mm256_madd_epi16(cvt16_src516, filter_3536);
+            temp = _mm256_madd_epi16(cvt16_src518, filter_3738);
+            t5_right = _mm256_add_epi32(t5_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src520, filter_3940);
+            t5_right = _mm256_add_epi32(t5_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src522, filter_410);
+            t5_right = _mm256_add_epi32(t5_right, temp);
+            t6_left = _mm256_madd_epi16(cvt16_src60, filter_4243);
+            temp = _mm256_madd_epi16(cvt16_src62, filter_4445);
+            t6_left = _mm256_add_epi32(t6_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src64, filter_4647);
+            t6_left = _mm256_add_epi32(t6_left, temp);
+            temp = _mm256_madd_epi16(cvt16_src66, filter_480);
+            t6_left = _mm256_add_epi32(t6_left, temp);
+            t6_right = _mm256_madd_epi16(cvt16_src616, filter_4243);
+            temp = _mm256_madd_epi16(cvt16_src618, filter_4445);
+            t6_right = _mm256_add_epi32(t6_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src620, filter_4647);
+            t6_right = _mm256_add_epi32(t6_right, temp);
+            temp = _mm256_madd_epi16(cvt16_src622, filter_480);
+            t6_right = _mm256_add_epi32(t6_right, temp);
+            sum_left = _mm256_add_epi32(t0_left, t1_left);
+            sum_left = _mm256_add_epi32(sum_left, t2_left);
+            sum_left = _mm256_add_epi32(sum_left, t3_left);
+            sum_left = _mm256_add_epi32(sum_left, t4_left);
+            sum_left = _mm256_add_epi32(sum_left, t5_left);
+            sum_left = _mm256_add_epi32(sum_left, t6_left);
+            sum_right = _mm256_add_epi32(t0_right, t1_right);
+            sum_right = _mm256_add_epi32(sum_right, t2_right);
+            sum_right = _mm256_add_epi32(sum_right, t3_right);
+            sum_right = _mm256_add_epi32(sum_right, t4_right);
+            sum_right = _mm256_add_epi32(sum_right, t5_right);
+            sum_right = _mm256_add_epi32(sum_right, t6_right);
+            sum_left = _mm256_add_epi32(sum_left, bias_val);
+            sum_right = _mm256_add_epi32(sum_right, bias_val);
+            if (is_quantized) {
+                op({{sum_left, sum_right}}, reinterpret_cast<dt_qint8*>(dst0));
+            } else {
+                _mm256_storeu_si256((__m256i*)(out_ptr0), sum_left);
+                _mm256_storeu_si256((__m256i*)(out_ptr0 + 8), sum_right);
+            }
+            r0 += 32;
+            r1 += 32;
+            r2 += 32;
+            r3 += 32;
+            r4 += 32;
+            r5 += 32;
+            r6 += 32;
+            dst0 += 16;
+            out_ptr0 += 16;
+        }
+        r0 += tail_step + IW;
+        r1 += tail_step + IW;
+        r2 += tail_step + IW;
+        r3 += tail_step + IW;
+        r4 += tail_step + IW;
+        r5 += tail_step + IW;
+        r6 += tail_step + IW;
+    }
+}
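+
+//! Explicitly instantiate every (filter size, bias mode, quantized flag,
+//! post-process op) combination of the stride-2 kernels so all needed
+//! specializations are compiled in this translation unit.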
+#define INSTANTIATION(stride, i, bias, is_quantized, Op)                 \
+    template void avx2_chanwise_direct_##stride##_##i##x##i##_int8<      \
+            bias, is_quantized, Op>(const int8_t*, const int8_t*,        \
+                                    const int32_t*, int32_t*, int8_t*,   \
+                                    const size_t, const size_t,          \
+                                    const size_t, const size_t, const Op&);
+
+#define FOR_OP(stride, i, is_quantized, bias)                             \
+    INSTANTIATION(stride, i, bias, is_quantized,                          \
+                  TypeCvtOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32         \
+                                    MEGDNN_COMMA dt_qint8>)               \
+    INSTANTIATION(stride, i, bias, is_quantized,                          \
+                  ReluOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32            \
+                                 MEGDNN_COMMA dt_qint8>)                  \
+    INSTANTIATION(stride, i, bias, is_quantized,                          \
+                  HSwishOp<SIMDType::AVX2 MEGDNN_COMMA dt_qint32          \
+                                   MEGDNN_COMMA dt_qint8>)
+
+#define FOR_BIAS(stride, i, is_quantized)              \
+    FOR_OP(stride, i, is_quantized, BiasMode::NO_BIAS) \
+    FOR_OP(stride, i, is_quantized, BiasMode::BROADCAST_CHANNEL_BIAS)
+
+#define FOR_QUANTIZED(stride, i) \
+    FOR_BIAS(stride, i, true)    \
+    FOR_BIAS(stride, i, false)
+
+#define FOR_FILTER(stride)   \
+    FOR_QUANTIZED(stride, 2) \
+    FOR_QUANTIZED(stride, 3) \
+    FOR_QUANTIZED(stride, 5) \
+    FOR_QUANTIZED(stride, 7)
+
+#define FOR_STRIDE FOR_FILTER(stride2)
+
+FOR_STRIDE
+
+#undef FOR_STRIDE
+#undef FOR_FILTER
+#undef FOR_QUANTIZED
+#undef FOR_BIAS
+#undef FOR_OP
+#undef INSTANTIATION
+
+}  // namespace avx2_chanwise_stride2
+
+#undef load_filter
+#undef load_src0
+#undef load_src1
+#undef load_src2
+#undef load_src3
+#undef load_src4
+#undef load_src5
+#undef load_src6
+#undef load_src7
+#undef load_src16
+#undef load_src18
+#undef load_src20
+#undef load_src22
+
 }  // namespace x86
 }  // namespace megdnn
@@ -33,6 +33,25 @@ KERN(stride1, 7)
 #undef KERN

 }  // namespace avx2_chanwise_stride1
+
+namespace avx2_chanwise_stride2 {
+
+#define KERN(stride, i)                                                   \
+    template <BiasMode bias_mode, bool is_quantized, typename Op>         \
+    MEGDNN_ATTRIBUTE_TARGET("avx2")                                       \
+    void avx2_chanwise_direct_##stride##_##i##x##i##_int8(                \
+            const int8_t* src, const int8_t* filter, const int32_t* bias, \
+            int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, \
+            const size_t OH, const size_t OW, const Op& op);
+
+KERN(stride2, 2)
+KERN(stride2, 3)
+KERN(stride2, 5)
+KERN(stride2, 7)
+
+#undef KERN
+
+}  // namespace avx2_chanwise_stride2
+
 }  // namespace x86
 }  // namespace megdnn
@@ -18,57 +18,6 @@ namespace megdnn {
 namespace x86 {
 namespace avx2_chanwise_stride1 {

-bool need_dst_copy(const NCBKernSizeParam& param) {
-    return param.osz[1] % 16;
-}
-
-bool need_src_copy(const NCBKernSizeParam& param) {
-    auto&& fm = param.filter_meta;
-    return (fm.padding[0] != 0 || fm.padding[1] != 0) ? true
-                                                      : need_dst_copy(param);
-}
-
-void get_rectified_size(const NCBKernSizeParam& param, size_t& IH2, size_t& IW2,
-                        size_t& OH2, size_t& OW2) {
-    auto&& fm = param.filter_meta;
-    auto SW = fm.stride[1];
-    auto OH = param.osz[0];
-    auto OW = param.osz[1];
-    auto FH = fm.spatial[0];
-    auto FW = fm.spatial[1];
-    OH2 = OH;
-    OW2 = (OW + 15) & ~15;
-    IH2 = SW * OH + FH - SW;
-    IW2 = SW * OW2 + FW - SW;
-}
-
-void copy_padding_kern(WorkspaceBundle bundle,
-                       const ConvBiasImpl::NCBKernParam& kern_param,
-                       const ConvBiasImpl::NCBKernIndex& ncb_index) {
-    size_t IH = kern_param.isz[0];
-    size_t IW = kern_param.isz[1];
-    size_t PH = kern_param.filter_meta.padding[0];
-    size_t PW = kern_param.filter_meta.padding[1];
-    size_t IH2, IW2, OH2, OW2;
-    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
-    bool need_src_copy_var = need_src_copy(kern_param);
-    size_t padding_group_size = IH2 * IW2;
-    bundle.set(kern_param.workspace_ptr);
-    size_t group_id = ncb_index.ndrange_id[0],
-           batch_id = ncb_index.ndrange_id[1],
-           channel_id = ncb_index.ndrange_id[2];
-    size_t workspace_group_id = ncb_index.thread_id;
-    const int8_t* sptr = kern_param.src<int8_t>(batch_id, group_id, channel_id);
-    if (need_src_copy_var) {
-        int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
-                            workspace_group_id * padding_group_size;
-        std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2);
-        rep(ih, IH) {
-            std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW,
-                        sizeof(int8_t) * IW);
-        }
-    }
-};
-
 template <size_t filter, BiasMode bias_mode, bool is_quantized, typename Op>
 void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param,
                 const NCBKernIndex& ncb_index) {
@@ -97,8 +46,7 @@ void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param,
            batch_id = ncb_index.ndrange_id[1];
     const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
-    const int8_t* fptr =
-            kern_param.filter<dt_int8>(group_id);
+    const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
     void* dst = kern_param.dst<void>(batch_id, group_id);
     const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);
     if (need_src_copy_var) {
@@ -130,9 +78,9 @@ void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param,
     if (need_post_process) {
         tptr = static_cast<int32_t*>(bundle.get(2)) +
                ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size();
-            DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
+        DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
     } else {
-            DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
+        DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
     }
 #undef KERN_NEED_POST_PROCESS
@@ -11,27 +11,15 @@
  */
 #pragma once

-#include "src/x86/conv_bias/int8/common_helper.h"
 #include "src/x86/conv_bias/opr_impl.h"

 namespace megdnn {
 namespace x86 {
 namespace avx2_chanwise_stride1 {

-using NCBKern = fallback::ConvBiasImpl::NCBKern;
-using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
-using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
-using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
-
 using conv_fun = std::function<void(WorkspaceBundle bundle,
                                     const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index)>;

-bool need_dst_copy(const NCBKernSizeParam& param);
-
-bool need_src_copy(const NCBKernSizeParam& param);
-
-void get_rectified_size(const NCBKernSizeParam& param, size_t& IH2, size_t& IW2,
-                        size_t& OH2, size_t& OW2);
-
 SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param,
                                 WorkspaceBundle bundle);
@@ -0,0 +1,204 @@
+/**
+ * \file src/x86/conv_bias/int8/avx2_chanwise_stride2.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "src/x86/conv_bias/int8/avx2_chanwise_stride2.h"
+#include "src/x86/conv_bias/int8/avx2_chanwise_kern.h"
+#include "src/x86/elemwise_op.h"
+
+namespace megdnn {
+namespace x86 {
+namespace avx2_chanwise_stride2 {
+
+template <size_t filter, BiasMode bias_mode, bool is_quantized, typename Op>
+void conv_kimpl(WorkspaceBundle bundle, const NCBKernParam& kern_param,
+                const NCBKernIndex& ncb_index) {
+    size_t OH = kern_param.osz[0];
+    size_t OW = kern_param.osz[1];
+    size_t IH2, IW2, OH2, OW2;
+    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
+    bool need_src_copy_var = need_src_copy(kern_param);
+    bool need_dst_copy_var = need_dst_copy(kern_param);
+    bool need_post_process =
+            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
+    Op op = Op(1.0f, 4.0f);
+    if (need_post_process) {
+        float scale_bias =
+                kern_param.bias_type.param<dtype::QuantizedS32>().scale;
+        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
+        op = Op(scale_bias, scale_dst);
+    }
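+    //! For a QuantizedS8 destination the op re-quantizes the int32
+    //! accumulators with the real bias/dst scales; the (1.0f, 4.0f) values
+    //! above are only placeholders for the raw int32 path, where the op is
+    //! never applied by the kernels.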
+    size_t padding_group_size = IH2 * IW2;
+    bundle.set(kern_param.workspace_ptr);
+
+    size_t workspace_group_id = ncb_index.thread_id;
+    size_t group_id = ncb_index.ndrange_id[0],
+           batch_id = ncb_index.ndrange_id[1];
+    const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
+    const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
+    void* dst = kern_param.dst<void>(batch_id, group_id);
+    const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);
+    if (need_src_copy_var) {
+        sptr = static_cast<int8_t*>(bundle.get(0)) +
+               workspace_group_id * padding_group_size;
+    }
+    void* dptr = nullptr;
+    int32_t* tptr = nullptr;
+    if (need_dst_copy_var) {
+        dptr = reinterpret_cast<void*>(
+                reinterpret_cast<ptrdiff_t>(bundle.get(1)) +
+                ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size());
+    } else {
+        dptr = dst;
+    }
+
+#define KERN_NEED_POST_PROCESS(filter)                                        \
+    avx2_chanwise_direct_stride2_##filter##x##filter##_int8<bias_mode, true,  \
+                                                            Op>(              \
+            sptr, fptr, bptr, tptr, static_cast<int8_t*>(dptr), IH2, IW2,     \
+            OH2, OW2, op)
+
+#define KERN_NO_POST_PROCESS(filter)                                          \
+    avx2_chanwise_direct_stride2_##filter##x##filter##_int8<bias_mode, false, \
+                                                            Op>(              \
+            sptr, fptr, bptr, static_cast<int32_t*>(dptr), nullptr, IH2,      \
+            IW2, OH2, OW2, op)
+
+    if (need_post_process) {
+        tptr = static_cast<int32_t*>(bundle.get(2)) +
+               ncb_index.thread_id * OH2 * OW2 * kern_param.dst_type.size();
+        DISPATCH_FILTER(filter, KERN_NEED_POST_PROCESS)
+    } else {
+        DISPATCH_FILTER(filter, KERN_NO_POST_PROCESS)
+    }
+#undef KERN_NEED_POST_PROCESS
+#undef KERN_NO_POST_PROCESS
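+
+    //! When OW was padded up to OW2 the kernels wrote OW2-wide rows into the
+    //! workspace; copy just the leading OW elements of each row back to the
+    //! real destination.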
+    if (need_dst_copy_var) {
+        rep(oh, OH) {
+            std::memcpy(reinterpret_cast<void*>(
+                                reinterpret_cast<ptrdiff_t>(dst) +
+                                oh * OW * kern_param.dst_type.size()),
+                        reinterpret_cast<void*>(
+                                reinterpret_cast<ptrdiff_t>(dptr) +
+                                oh * OW2 * kern_param.dst_type.size()),
+                        kern_param.dst_type.size() * OW);
+        }
+    }
+}
+
+SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& kern_param,
+                                WorkspaceBundle bundle) {
+    auto fm = kern_param.filter_meta;
+    size_t group = fm.group;
+    size_t n = kern_param.n;
+    SmallVector<NCBKern> ncb_kerns;
+    conv_fun do_conv_fun = nullptr;
+
+#define DO_CONV_KERN_FUN(filter, bias_mode, is_quantized, op) \
+    do_conv_fun = conv_kimpl<filter, bias_mode, is_quantized, op>;
+
+#define GET_OP_PARAM(i, bias_mode, is_quantized)                         \
+    switch (kern_param.nonlineMode) {                                    \
+        case param::ConvBias::NonlineMode::IDENTITY:                     \
+            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                 \
+                             TypeCvtOp<SIMDType::AVX2 MEGDNN_COMMA       \
+                                       dt_qint32 MEGDNN_COMMA dt_qint8>) \
+            break;                                                       \
+        case param::ConvBias::NonlineMode::RELU:                         \
+            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                 \
+                             ReluOp<SIMDType::AVX2 MEGDNN_COMMA          \
+                                    dt_qint32 MEGDNN_COMMA dt_qint8>)    \
+            break;                                                       \
+        case param::ConvBias::NonlineMode::H_SWISH:                      \
+            DO_CONV_KERN_FUN(i, bias_mode, is_quantized,                 \
+                             HSwishOp<SIMDType::AVX2 MEGDNN_COMMA        \
+                                      dt_qint32 MEGDNN_COMMA dt_qint8>)  \
+            break;                                                       \
+        default:                                                         \
+            megdnn_assert(0, "do not support nonlineMode: %d",           \
+                          static_cast<int>(kern_param.nonlineMode));     \
+            break;                                                       \
+    }
+
+#define GET_BIAS_MODE_PARAM(i, is_quantized)                            \
+    switch (kern_param.bias_mode) {                                     \
+        case BiasMode::NO_BIAS:                                         \
+            GET_OP_PARAM(i, BiasMode::NO_BIAS, is_quantized)            \
+            break;                                                      \
+        case BiasMode::BROADCAST_CHANNEL_BIAS:                          \
+            GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS,           \
+                         is_quantized)                                  \
+            break;                                                      \
+        default:                                                        \
+            megdnn_assert(0, "do not support bias mode: %d",            \
+                          static_cast<int>(kern_param.bias_mode));      \
+            break;                                                      \
+    }
+
+#define GET_QUANTIZED(i)                                                \
+    switch (kern_param.dst_type.enumv()) {                              \
+        case DTypeEnum::QuantizedS8:                                    \
+            GET_BIAS_MODE_PARAM(i, true)                                \
+            break;                                                      \
+        case DTypeEnum::QuantizedS32:                                   \
+            GET_BIAS_MODE_PARAM(i, false)                               \
+            break;                                                      \
+        case DTypeEnum::Int32:                                          \
+            GET_BIAS_MODE_PARAM(i, false)                               \
+            break;                                                      \
+        default:                                                        \
+            megdnn_assert(0, "do not support dtype: %d",                \
+                          static_cast<int>(kern_param.dst_type.enumv()));\
+            break;                                                      \
+    }
+
+#define DISPATCH_CONV_KERN()                                            \
+    switch (kern_param.filter_meta.spatial[0]) {                        \
+        case 2:                                                         \
+            GET_QUANTIZED(2)                                            \
+            break;                                                      \
+        case 3:                                                         \
+            GET_QUANTIZED(3)                                            \
+            break;                                                      \
+        case 5:                                                         \
+            GET_QUANTIZED(5)                                            \
+            break;                                                      \
+        case 7:                                                         \
+            GET_QUANTIZED(7)                                            \
+            break;                                                      \
+        default:                                                        \
+            megdnn_assert(                                              \
+                    0, "do not support kernel: %d",                     \
+                    static_cast<int>(kern_param.filter_meta.spatial[0]));\
+            break;                                                      \
+    }
+
+    DISPATCH_CONV_KERN();
+
+    auto exec_one_group = [bundle, do_conv_fun](const NCBKernParam& kern_param,
+                                                const NCBKernIndex& ncb_index) {
+        copy_padding_kern(bundle, kern_param, ncb_index);
+        do_conv_fun(bundle, kern_param, ncb_index);
+    };
+    ncb_kerns.push_back({exec_one_group, {group, n, 1_z}});
+    return ncb_kerns;
+}
+
+}  // namespace avx2_chanwise_stride2
+}  // namespace x86
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
@@ -0,0 +1,30 @@
+/**
+ * \file src/x86/conv_bias/int8/avx2_chanwise_stride2.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+
+#include "src/x86/conv_bias/int8/common_helper.h"
+#include "src/x86/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace x86 {
+namespace avx2_chanwise_stride2 {
+
+using conv_fun = std::function<void(WorkspaceBundle bundle,
+                                    const NCBKernParam& kern_param,
+                                    const NCBKernIndex& ncb_index)>;
+
+SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param,
+                                WorkspaceBundle bundle);
+
+}  // namespace avx2_chanwise_stride2
+}  // namespace x86
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
@@ -0,0 +1,83 @@
+/**
+ * \file dnn/src/x86/conv_bias/int8/chanwise_helper.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+
+#include "megdnn/arch.h"
+#include "src/x86/conv_bias/opr_impl.h"
+
+namespace megdnn {
+namespace x86 {
+
+using NCBKern = fallback::ConvBiasImpl::NCBKern;
+using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
+using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
+using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
+
+static inline bool need_dst_copy(const NCBKernSizeParam& param) {
+    return param.osz[1] % 16;
+}
+
+static inline bool need_src_copy(const NCBKernSizeParam& param) {
+    auto&& fm = param.filter_meta;
+    return (fm.padding[0] != 0 || fm.padding[1] != 0) ? true
+                                                      : need_dst_copy(param);
+}
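+
+//! The destination needs a bounce buffer whenever OW is not a multiple of 16
+//! (the kernels always store 16 outputs at a time); the source needs one when
+//! there is spatial padding, or when the widened rows implied by the padded
+//! OW would otherwise read past the real source row.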
+
+static inline void get_rectified_size(const NCBKernSizeParam& param,
+                                      size_t& IH2, size_t& IW2, size_t& OH2,
+                                      size_t& OW2) {
+    auto&& fm = param.filter_meta;
+    auto SW = fm.stride[1];
+    auto OH = param.osz[0];
+    auto OW = param.osz[1];
+    auto FH = fm.spatial[0];
+    auto FW = fm.spatial[1];
+
+    OH2 = OH;
+    OW2 = (OW + 15) & ~15;
+    IH2 = SW * OH + FH - SW;
+    IW2 = SW * OW2 + FW - SW;
+}
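+
+//! OW2 rounds OW up to the next multiple of 16 (one AVX2 output tile); the
+//! input sizes follow from the convolution relation I = S * (O - 1) + F,
+//! written here as S * O + F - S.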
static inline void copy_padding_kern( | |||||
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
const ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
size_t IW = kern_param.isz[1]; | |||||
size_t IH = kern_param.isz[0]; | |||||
size_t PH = kern_param.filter_meta.padding[0]; | |||||
size_t PW = kern_param.filter_meta.padding[1]; | |||||
size_t IH2, IW2, OH2, OW2; | |||||
get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||||
bool need_src_copy_var = need_src_copy(kern_param); | |||||
size_t padding_group_size = IH2 * IW2; | |||||
bundle.set(kern_param.workspace_ptr); | |||||
size_t group_id = ncb_index.ndrange_id[0], | |||||
batch_id = ncb_index.ndrange_id[1], | |||||
channel_id = ncb_index.ndrange_id[2]; | |||||
size_t workspace_group_id = ncb_index.thread_id; | |||||
const int8_t* sptr = kern_param.src<int8_t>(batch_id, group_id, channel_id); | |||||
if (need_src_copy_var) { | |||||
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) + | |||||
workspace_group_id * padding_group_size; | |||||
std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); | |||||
rep(ih, std::min(IH, IH2)) { | |||||
std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, | |||||
sizeof(int8_t) * IW); | |||||
} | |||||
} | |||||
} | |||||
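//! typical schedule (a sketch, assuming the stride1 convention carries | |||||
//! over): copy_padding_kern runs over (group, batch, channel) first, then | |||||
//! the conv kernels consume the padded planes from workspace slot 0. | |||||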
} // namespace x86 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -6,13 +6,15 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include <immintrin.h> | #include <immintrin.h> | ||||
#include "src/common/unroll_macro.h" | |||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
#include "src/common/unroll_macro.h" | |||||
#include "src/x86/conv_bias/int8/chanwise_helper.h" | |||||
#ifdef WIN32 | #ifdef WIN32 | ||||
#include <smmintrin.h> | #include <smmintrin.h> | ||||
#endif | #endif | ||||
@@ -6,17 +6,18 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/x86/conv_bias/opr_impl.h" | #include "src/x86/conv_bias/opr_impl.h" | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <memory> | #include <memory> | ||||
#include "src/x86/matrix_mul/opr_impl.h" | |||||
#include "src/common/metahelper.h" | #include "src/common/metahelper.h" | ||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/x86/conv_bias/f32/algos.h" | #include "src/x86/conv_bias/f32/algos.h" | ||||
#include "src/x86/conv_bias/int8/algos.h" | #include "src/x86/conv_bias/int8/algos.h" | ||||
#include "src/x86/matrix_mul/opr_impl.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace x86; | using namespace x86; | ||||
@@ -69,6 +70,10 @@ void* ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::type() const { | |||||
return x86_algo_type; | return x86_algo_type; | ||||
} | } | ||||
void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | class ConvBiasImpl::AlgoPack : NonCopyableObj { | ||||
AlgoDirect stride1_direct_large_group{true}; | AlgoDirect stride1_direct_large_group{true}; | ||||
AlgoDirect stride1_direct_small_group{false}; | AlgoDirect stride1_direct_small_group{false}; | ||||
@@ -77,6 +82,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8; | AlgoDirectAvx2Stride1Int8 avx2_stride1_direct_int8; | ||||
AlgoAVX2DirectConvStride2 avx2_stride2_direct; | AlgoAVX2DirectConvStride2 avx2_stride2_direct; | ||||
AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8; | AlgoChanWiseAvx2Stride1Qint8 avx2_stride1_chanwsie_qint8; | ||||
AlgoChanWiseAvx2Stride2Qint8 avx2_stride2_chanwsie_qint8; | |||||
AlgoMatrixMul matmul; | AlgoMatrixMul matmul; | ||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
AlgoMkldnnMatmulQint8 mkldnn_matmul_qint8; | AlgoMkldnnMatmulQint8 mkldnn_matmul_qint8; | ||||
@@ -85,6 +91,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
AlgoMkldnnConv mkldnn_conv_fp32; | AlgoMkldnnConv mkldnn_conv_fp32; | ||||
#endif | #endif | ||||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | SmallVector<std::unique_ptr<AlgoBase>> refhold; | ||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
@@ -100,6 +107,7 @@ public: | |||||
all_algos.emplace_back(&avx2_stride1_direct_int8); | all_algos.emplace_back(&avx2_stride1_direct_int8); | ||||
all_algos.emplace_back(&avx2_stride2_direct); | all_algos.emplace_back(&avx2_stride2_direct); | ||||
all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); | all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); | ||||
all_algos.emplace_back(&avx2_stride2_chanwsie_qint8); | |||||
all_algos.emplace_back(&matmul); | all_algos.emplace_back(&matmul); | ||||
static CpuOprDelegationStorage<> storage; | static CpuOprDelegationStorage<> storage; | ||||
@@ -107,7 +115,8 @@ public: | |||||
auto&& matmul_algos = | auto&& matmul_algos = | ||||
static_cast<MatrixMulImpl*>(matmul_opr)->algo_pack(); | static_cast<MatrixMulImpl*>(matmul_opr)->algo_pack(); | ||||
for (auto&& algo : matmul_algos) { | for (auto&& algo : matmul_algos) { | ||||
if (algo->type() == nullptr) continue; | |||||
if (algo->type() == nullptr) | |||||
continue; | |||||
for (uint32_t tile_size : {8, 16, 24}) { | for (uint32_t tile_size : {8, 16, 24}) { | ||||
refhold.emplace_back(new AlgoFP32WinogradF63_8x8( | refhold.emplace_back(new AlgoFP32WinogradF63_8x8( | ||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
@@ -32,6 +33,7 @@ public: | |||||
class AlgoDirectAvx2Stride1Int8; | class AlgoDirectAvx2Stride1Int8; | ||||
class AlgoAVX2DirectConvStride2; | class AlgoAVX2DirectConvStride2; | ||||
class AlgoChanWiseAvx2Stride1Qint8; | class AlgoChanWiseAvx2Stride1Qint8; | ||||
class AlgoChanWiseAvx2Stride2Qint8; | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
class AlgoMkldnnConv; | class AlgoMkldnnConv; | ||||
class AlgoMkldnnQint8; | class AlgoMkldnnQint8; | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/x86/utils.h" | #include "src/x86/utils.h" | ||||
#include "test/x86/fixture.h" | #include "test/x86/fixture.h" | ||||
@@ -41,7 +42,8 @@ TEST_F(X86, CONV_BIAS_FORWARD) { | |||||
} | } | ||||
} | } | ||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { | |||||
static void avx2_chanwise_direct_int8x8x32(Handle* handle, uint32_t stride, | |||||
const char* algo) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
@@ -50,8 +52,8 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { | |||||
if (w + 2 * p < kernel || h + 2 * p < kernel) | if (w + 2 * p < kernel || h + 2 * p < kernel) | ||||
return; | return; | ||||
param::ConvBias param; | param::ConvBias param; | ||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.stride_h = stride; | |||||
param.stride_w = stride; | |||||
param.pad_h = p; | param.pad_h = p; | ||||
param.pad_w = p; | param.pad_w = p; | ||||
param.nonlineMode = nonline_mode; | param.nonlineMode = nonline_mode; | ||||
@@ -74,7 +76,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { | |||||
for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) | for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) | ||||
run(ic, w, h, kernel, pad, nonline_mode); | run(ic, w, h, kernel, pad, nonline_mode); | ||||
Checker<ConvBias> checker(handle()); | |||||
Checker<ConvBias> checker(handle); | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
checker.set_dtype(0, dtype::Int8()) | checker.set_dtype(0, dtype::Int8()) | ||||
.set_dtype(1, dtype::Int8()) | .set_dtype(1, dtype::Int8()) | ||||
@@ -85,15 +87,25 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { | |||||
.set_rng(2, &rng) | .set_rng(2, &rng) | ||||
.set_epsilon(1e-3); | .set_epsilon(1e-3); | ||||
checker.set_before_exec_callback( | checker.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
"X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1")); | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
checker.set_param(arg.param).exec( | checker.set_param(arg.param).exec( | ||||
{arg.src, arg.filter, arg.bias, {}, {}}); | {arg.src, arg.filter, arg.bias, {}, {}}); | ||||
} | } | ||||
} | } | ||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { | |||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_INT8x8x32) { | |||||
avx2_chanwise_direct_int8x8x32(handle(), 1, | |||||
"X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"); | |||||
} | |||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE2_INT8x8x32) { | |||||
avx2_chanwise_direct_int8x8x32(handle(), 2, | |||||
"X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2"); | |||||
} | |||||
static void avx2_chanwise_direct_quantizeds32(Handle* handle, uint32_t stride, | |||||
const char* algo) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
@@ -102,8 +114,8 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { | |||||
if (w + 2 * p < kernel || h + 2 * p < kernel) | if (w + 2 * p < kernel || h + 2 * p < kernel) | ||||
return; | return; | ||||
param::ConvBias param; | param::ConvBias param; | ||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.stride_h = stride; | |||||
param.stride_w = stride; | |||||
param.pad_h = p; | param.pad_h = p; | ||||
param.pad_w = p; | param.pad_w = p; | ||||
param.nonlineMode = nonline_mode; | param.nonlineMode = nonline_mode; | ||||
@@ -126,7 +138,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { | |||||
for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) | for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) | ||||
run(ic, w, h, kernel, pad, nonline_mode); | run(ic, w, h, kernel, pad, nonline_mode); | ||||
Checker<ConvBias> checker(handle()); | |||||
Checker<ConvBias> checker(handle); | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | ||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | .set_dtype(1, dtype::QuantizedS8(2.5f)) | ||||
@@ -137,15 +149,26 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { | |||||
.set_rng(2, &rng) | .set_rng(2, &rng) | ||||
.set_epsilon(1e-3); | .set_epsilon(1e-3); | ||||
checker.set_before_exec_callback( | checker.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
"X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1")); | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
checker.set_param(arg.param).exec( | checker.set_param(arg.param).exec( | ||||
{arg.src, arg.filter, arg.bias, {}, {}}); | {arg.src, arg.filter, arg.bias, {}, {}}); | ||||
} | } | ||||
} | } | ||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) { | |||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS32) { | |||||
avx2_chanwise_direct_quantizeds32( | |||||
handle(), 1, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"); | |||||
} | |||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE2_QuantizedS32) { | |||||
avx2_chanwise_direct_quantizeds32( | |||||
handle(), 2, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2"); | |||||
} | |||||
static void avx2_chanwise_direct_quantizeds8x8x8(Handle* handle, | |||||
uint32_t stride, | |||||
const char* algo) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
@@ -154,8 +177,8 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) { | |||||
if (w + 2 * p < kernel || h + 2 * p < kernel) | if (w + 2 * p < kernel || h + 2 * p < kernel) | ||||
return; | return; | ||||
param::ConvBias param; | param::ConvBias param; | ||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.stride_h = stride; | |||||
param.stride_w = stride; | |||||
param.pad_h = p; | param.pad_h = p; | ||||
param.pad_w = p; | param.pad_w = p; | ||||
param.nonlineMode = nonline_mode; | param.nonlineMode = nonline_mode; | ||||
@@ -180,7 +203,7 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) { | |||||
NonlineMode::RELU}) | NonlineMode::RELU}) | ||||
run(ic, w, h, kernel, pad, nonline_mode); | run(ic, w, h, kernel, pad, nonline_mode); | ||||
Checker<ConvBias> checker(handle()); | |||||
Checker<ConvBias> checker(handle); | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | ||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | .set_dtype(1, dtype::QuantizedS8(2.5f)) | ||||
@@ -191,14 +214,23 @@ TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) { | |||||
.set_rng(2, &rng) | .set_rng(2, &rng) | ||||
.set_epsilon(1e-3); | .set_epsilon(1e-3); | ||||
checker.set_before_exec_callback( | checker.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
"X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1")); | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
checker.set_param(arg.param).exec( | checker.set_param(arg.param).exec( | ||||
{arg.src, arg.filter, arg.bias, {}, {}}); | {arg.src, arg.filter, arg.bias, {}, {}}); | ||||
} | } | ||||
} | } | ||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE1_QuantizedS8x8x8) { | |||||
avx2_chanwise_direct_quantizeds8x8x8( | |||||
handle(), 1, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"); | |||||
} | |||||
TEST_F(X86_MULTI_THREADS, AVX2_CHANWISE_DIRECT_STRIDE2_QuantizedS8x8x8) { | |||||
avx2_chanwise_direct_quantizeds8x8x8( | |||||
handle(), 2, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2"); | |||||
} | |||||
TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE1_INT8x8x32) { | TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE1_INT8x8x32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
@@ -343,7 +375,6 @@ TEST_F(X86_MULTI_THREADS, AVX2_CONV_BIAS_DIRECT_STRIDE1_S8S8S8) { | |||||
args.emplace_back(param, TensorShape{2, 2 * ic, h, w}, | args.emplace_back(param, TensorShape{2, 2 * ic, h, w}, | ||||
TensorShape{2, oc / 2, ic, kernel, kernel}, | TensorShape{2, oc / 2, ic, kernel, kernel}, | ||||
TensorShape{1, oc, 1, 1}); | TensorShape{1, oc, 1, 1}); | ||||
}; | }; | ||||
for (size_t kernel : {2, 3, 5, 7}) | for (size_t kernel : {2, 3, 5, 7}) | ||||
@@ -967,8 +998,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
if (x86::is_supported(x86::SIMDType::VNNI)) { | if (x86::is_supported(x86::SIMDType::VNNI)) { | ||||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | ||||
dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, | |||||
"CONV1x1:X86_INT8X8X32_MKLDNN:24"); | |||||
dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, | |||||
"CONV1x1:X86_INT8X8X32_MKLDNN:24"); | |||||
} | } | ||||
#endif | #endif | ||||
#if MEGDNN_X86_WITH_VNNI | #if MEGDNN_X86_WITH_VNNI | ||||
@@ -983,8 +1014,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { | |||||
dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, | dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, | ||||
"CONV1x1:X86_INT8X8X32_AVX2_4X16X2:24"); | "CONV1x1:X86_INT8X8X32_AVX2_4X16X2:24"); | ||||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | ||||
dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, | |||||
"CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24"); | |||||
dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, | |||||
"CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24"); | |||||
} | } | ||||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | ||||
dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, | dtype::Int8{}, dtype::Int32{}, dtype::Int32{}, | ||||
@@ -1231,7 +1262,7 @@ TEST_F(X86_MULTI_THREADS, BENCHMARK_CONVBIAS_FP32_MKLDNN) { | |||||
#endif | #endif | ||||
/************************* Winograd ****************************/ | /************************* Winograd ****************************/ | ||||
namespace{ | |||||
namespace { | |||||
std::vector<conv_bias::TestArg> get_winograd_mk_nchw88_args() { | std::vector<conv_bias::TestArg> get_winograd_mk_nchw88_args() { | ||||
std::vector<conv_bias::TestArg> args; | std::vector<conv_bias::TestArg> args; | ||||
param::ConvBias cur_param; | param::ConvBias cur_param; | ||||
@@ -1265,17 +1296,17 @@ std::vector<conv_bias::TestArg> get_winograd_mk_nchw88_args() { | |||||
TensorShape{2, oc, ic, 3, 3, 8, 8}, | TensorShape{2, oc, ic, 3, 3, 8, 8}, | ||||
TensorShape{1, 2 * oc, 1, 1, 8});*/ | TensorShape{1, 2 * oc, 1, 1, 8});*/ | ||||
}}} | }}} | ||||
// clang-format on | |||||
//! test for multi-thread OC parallel | |||||
cur_param.sparse = param::ConvBias::Sparse::DENSE; | |||||
cur_param.pad_h = cur_param.pad_w = 1; | |||||
args.emplace_back(cur_param, TensorShape{2, 1, 9, 9, 8}, | |||||
TensorShape{128, 1, 3, 3, 8, 8}, | |||||
TensorShape{1, 128, 1, 1, 8}); | |||||
/*cur_param.sparse = param::ConvBias::Sparse::GROUP; | |||||
args.emplace_back(cur_param, TensorShape{2, 2, 9, 9, 8}, | |||||
TensorShape{2, 128, 1, 3, 3, 8, 8}, | |||||
TensorShape{1, 2 * 128, 1, 1, 8});*/ | |||||
// clang-format on | |||||
//! test for multi-thread OC parallel | |||||
cur_param.sparse = param::ConvBias::Sparse::DENSE; | |||||
cur_param.pad_h = cur_param.pad_w = 1; | |||||
args.emplace_back(cur_param, TensorShape{2, 1, 9, 9, 8}, | |||||
TensorShape{128, 1, 3, 3, 8, 8}, | |||||
TensorShape{1, 128, 1, 1, 8}); | |||||
/*cur_param.sparse = param::ConvBias::Sparse::GROUP; | |||||
args.emplace_back(cur_param, TensorShape{2, 2, 9, 9, 8}, | |||||
TensorShape{2, 128, 1, 3, 3, 8, 8}, | |||||
TensorShape{1, 2 * 128, 1, 1, 8});*/ | |||||
} | } | ||||
return args; | return args; | ||||
} | } | ||||
@@ -1329,7 +1360,8 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_WINOGRAD_WEIGHT_PREPROCESS) { | |||||
auto conv_bias_opr = handle->create_operator<ConvBias>(); | auto conv_bias_opr = handle->create_operator<ConvBias>(); | ||||
conv_bias_opr->param() = param; | conv_bias_opr->param() = param; | ||||
conv_bias_opr->param().format = param::ConvBias::Format::NCHW88_WINOGRAD; | |||||
conv_bias_opr->param().format = | |||||
param::ConvBias::Format::NCHW88_WINOGRAD; | |||||
conv_bias_opr->param().output_block_size = m; | conv_bias_opr->param().output_block_size = m; | ||||
size_t conv_bias_workspace_in_bytes = | size_t conv_bias_workspace_in_bytes = | ||||
conv_bias_opr->get_workspace_in_bytes( | conv_bias_opr->get_workspace_in_bytes( | ||||
@@ -1720,17 +1752,16 @@ void benchmark_impl(const param::ConvBias param, | |||||
} | } | ||||
} | } | ||||
void benchmark_impl_comp(const param::ConvBias param, | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>>& | |||||
shapes_and_computation, | |||||
const std::string algo_name, const std::string algo_name1,size_t RUNS, | |||||
TaskExecutorConfig&& multi_thread_config, | |||||
TaskExecutorConfig&& single_thread_config,std::vector<DType> dtype_v) { | |||||
void benchmark_impl_comp( | |||||
const param::ConvBias param, | |||||
std::vector<std::pair<SmallVector<TensorShape>, float>>& | |||||
shapes_and_computation, | |||||
const std::string algo_name, const std::string algo_name1, size_t RUNS, | |||||
TaskExecutorConfig&& multi_thread_config, | |||||
TaskExecutorConfig&& single_thread_config, std::vector<DType> dtype_v) { | |||||
std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | ||||
dtype::Float32(), dtype::Float32()}; | dtype::Float32(), dtype::Float32()}; | ||||
std::vector<float> multi_thread_times, single_thread_times; | std::vector<float> multi_thread_times, single_thread_times; | ||||
{ | { | ||||
auto multi_thread_hanle = | auto multi_thread_hanle = | ||||
@@ -1738,10 +1769,10 @@ void benchmark_impl_comp(const param::ConvBias param, | |||||
auto benchmarker = Benchmarker<ConvBias>(multi_thread_hanle.get()); | auto benchmarker = Benchmarker<ConvBias>(multi_thread_hanle.get()); | ||||
benchmarker.set_times(RUNS) | benchmarker.set_times(RUNS) | ||||
.set_display(false) | .set_display(false) | ||||
.set_dtype(0,dtype_v[0]) | |||||
.set_dtype(1,dtype_v[1]) | |||||
.set_dtype(2,dtype_v[2]) | |||||
.set_dtype(4,dtype_v[3]) | |||||
.set_dtype(0, dtype_v[0]) | |||||
.set_dtype(1, dtype_v[1]) | |||||
.set_dtype(2, dtype_v[2]) | |||||
.set_dtype(4, dtype_v[3]) | |||||
.set_param(param) | .set_param(param) | ||||
.set_before_exec_callback( | .set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBias>( | conv_bias::ConvBiasAlgoChecker<ConvBias>( | ||||
@@ -1756,10 +1787,10 @@ void benchmark_impl_comp(const param::ConvBias param, | |||||
auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get()); | auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get()); | ||||
benchmarker.set_times(RUNS) | benchmarker.set_times(RUNS) | ||||
.set_display(false) | .set_display(false) | ||||
.set_dtype(0,dtype_v[0]) | |||||
.set_dtype(1,dtype_v[1]) | |||||
.set_dtype(2,dtype_v[2]) | |||||
.set_dtype(4,dtype_v[3]) | |||||
.set_dtype(0, dtype_v[0]) | |||||
.set_dtype(1, dtype_v[1]) | |||||
.set_dtype(2, dtype_v[2]) | |||||
.set_dtype(4, dtype_v[3]) | |||||
.set_param(param) | .set_param(param) | ||||
.set_before_exec_callback( | .set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBias>( | conv_bias::ConvBiasAlgoChecker<ConvBias>( | ||||
@@ -1789,11 +1820,13 @@ void benchmark_impl_comp(const param::ConvBias param, | |||||
} | } | ||||
} // namespace | } // namespace | ||||
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8) { | |||||
static void benchmark_convbias_chanwise_avx2_int8(uint32_t stride, | |||||
const char* algo) { | |||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
param::ConvBias param; | param::ConvBias param; | ||||
param.stride_h = 1; | |||||
param.stride_w = 1; | |||||
param.stride_h = stride; | |||||
param.stride_w = stride; | |||||
param.sparse = param::ConvBias::Sparse::GROUP; | param.sparse = param::ConvBias::Sparse::GROUP; | ||||
std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(), | std::vector<DType> data_type = {dtype::Int8(), dtype::Int8(), | ||||
@@ -1841,14 +1874,23 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8) { | |||||
bench_case(1, 576, 14, 14, 2); | bench_case(1, 576, 14, 14, 2); | ||||
bench_case(1, 960, 7, 7, 2); | bench_case(1, 960, 7, 7, 2); | ||||
std::string algo_name = "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"; | |||||
printf("Benchmark X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1\n"); | |||||
std::string algo_name = algo; | |||||
printf("Benchmark %s\n", algo); | |||||
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
{4, {4, 5, 6, 7}}, {1, {4}}, data_type); | {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | ||||
benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | ||||
{1, {4}}, data_type); | {1, {4}}, data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
} | } | ||||
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8_S1) { | |||||
benchmark_convbias_chanwise_avx2_int8( | |||||
1, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"); | |||||
} | |||||
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_CHANWISE_AVX2_INT8_S2) { | |||||
benchmark_convbias_chanwise_avx2_int8( | |||||
2, "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2"); | |||||
} | |||||
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) { | TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_AVX2_INT8) { | ||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
@@ -2129,7 +2171,8 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32) { | |||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
} | } | ||||
TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) { | |||||
TEST_F(X86_BENCHMARK_MULTI_THREADS, | |||||
BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) { | |||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
param::ConvBias param; | param::ConvBias param; | ||||
@@ -2143,9 +2186,8 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) | |||||
dtype::Float32(), dtype::Float32()}; | dtype::Float32(), dtype::Float32()}; | ||||
std::vector<std::pair<SmallVector<TensorShape>, float>> | std::vector<std::pair<SmallVector<TensorShape>, float>> | ||||
shapes_and_computation; | shapes_and_computation; | ||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, | |||||
size_t W, size_t FS, | |||||
size_t group) { | |||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, | |||||
size_t FS, size_t group) { | |||||
SmallVector<TensorShape> shapes{{N, IC, H, W}, | SmallVector<TensorShape> shapes{{N, IC, H, W}, | ||||
{OC / group, IC / group, FS, FS}, | {OC / group, IC / group, FS, FS}, | ||||
{1, OC, 1, 1}, | {1, OC, 1, 1}, | ||||
@@ -2167,7 +2209,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) | |||||
bench_case(1, 32, 32, 100, 100, 3, 1); | bench_case(1, 32, 32, 100, 100, 3, 1); | ||||
bench_case(1, 32, 32, 80, 80, 3, 1); | bench_case(1, 32, 32, 80, 80, 3, 1); | ||||
bench_case(1, 32, 32, 80, 80, 3, 1); | bench_case(1, 32, 32, 80, 80, 3, 1); | ||||
bench_case(1, 64, 32, 7, 7, 3, 1); | bench_case(1, 64, 32, 7, 7, 3, 1); | ||||
bench_case(1, 64, 64, 7, 7, 3, 1); | bench_case(1, 64, 64, 7, 7, 3, 1); | ||||
bench_case(1, 64, 128, 7, 7, 3, 1); | bench_case(1, 64, 128, 7, 7, 3, 1); | ||||
@@ -2192,10 +2234,10 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_F32_single_thread) | |||||
std::string algo_name = "IM2COLMATMUL:X86_F32_MKL_PACKA:192"; | std::string algo_name = "IM2COLMATMUL:X86_F32_MKL_PACKA:192"; | ||||
std::string algo_name1 = "IM2COLMATMUL:X86_F32_BLAS:192"; | std::string algo_name1 = "IM2COLMATMUL:X86_F32_BLAS:192"; | ||||
printf("Benchmark IM2COLMATMUL:X86_F32_BLAS algo\n"); | printf("Benchmark IM2COLMATMUL:X86_F32_BLAS algo\n"); | ||||
benchmark_impl_comp(param, shapes_and_computation, algo_name,algo_name1, RUNS, | |||||
{1, {4}}, {1, {4}},data_type); | |||||
benchmark_impl_comp(param, shapes_and_computation, algo_name,algo_name1, RUNS, | |||||
{1, {7}}, {1, {7}},data_type); | |||||
benchmark_impl_comp(param, shapes_and_computation, algo_name, algo_name1, | |||||
RUNS, {1, {4}}, {1, {4}}, data_type); | |||||
benchmark_impl_comp(param, shapes_and_computation, algo_name, algo_name1, | |||||
RUNS, {1, {7}}, {1, {7}}, data_type); | |||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
} | } | ||||
@@ -2269,7 +2311,7 @@ TEST_F(X86_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_INT8X8X32) { | |||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
} | } | ||||
namespace{ | |||||
namespace { | |||||
std::vector<conv_bias::TestArg> get_winograd_benchmark_args(size_t kernel, | std::vector<conv_bias::TestArg> get_winograd_benchmark_args(size_t kernel, | ||||
size_t pack_size) { | size_t pack_size) { | ||||
std::vector<conv_bias::TestArg> args; | std::vector<conv_bias::TestArg> args; | ||||
@@ -2290,14 +2332,14 @@ std::vector<conv_bias::TestArg> get_winograd_benchmark_args(size_t kernel, | |||||
param.pad_h = p; | param.pad_h = p; | ||||
param.pad_w = p; | param.pad_w = p; | ||||
args.push_back(conv_bias::TestArg{param, | |||||
TensorShape{1, ic/8, h, w, 8}, | |||||
TensorShape{oc/8, ic/8, kernel, kernel, 8, 8}, | |||||
{1, oc/8, 1, 1, 8}}); | |||||
args.push_back(conv_bias::TestArg{ | |||||
param, | |||||
TensorShape{1, ic / 8, h, w, 8}, | |||||
TensorShape{oc / 8, ic / 8, kernel, kernel, 8, 8}, | |||||
{1, oc / 8, 1, 1, 8}}); | |||||
}; | }; | ||||
for (size_t ic : {64, 128, 256}) { | for (size_t ic : {64, 128, 256}) { | ||||
for (size_t oc : {64,128,256}) { | |||||
for (size_t oc : {64, 128, 256}) { | |||||
pack(oc, ic, 56, 56, kernel, kernel / 2); | pack(oc, ic, 56, 56, kernel, kernel / 2); | ||||
pack(oc, ic, 14, 14, kernel, kernel / 2); | pack(oc, ic, 14, 14, kernel, kernel / 2); | ||||
pack(oc, ic, 28, 28, kernel, kernel / 2); | pack(oc, ic, 28, 28, kernel, kernel / 2); | ||||
@@ -2317,8 +2359,8 @@ std::vector<conv_bias::TestArg> get_winograd_benchmark_args(size_t kernel, | |||||
return args; | return args; | ||||
} | } | ||||
void benchmark_winograd(const char* algo_name, Handle* handle, | |||||
size_t kernel, size_t pack_size) { | |||||
void benchmark_winograd(const char* algo_name, Handle* handle, size_t kernel, | |||||
size_t pack_size) { | |||||
auto&& args = get_winograd_benchmark_args(kernel, pack_size); | auto&& args = get_winograd_benchmark_args(kernel, pack_size); | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
constexpr size_t RUN = 10; | constexpr size_t RUN = 10; | ||||
@@ -2361,7 +2403,7 @@ void benchmark_winograd(const char* algo_name, Handle* handle, | |||||
computations / used_winograd, used / used_winograd); | computations / used_winograd, used / used_winograd); | ||||
} | } | ||||
} | } | ||||
} | |||||
} // namespace | |||||
TEST_F(X86, BENCHMARK_CONVBIAS_WINOGRAD_F63_8x8) { | TEST_F(X86, BENCHMARK_CONVBIAS_WINOGRAD_F63_8x8) { | ||||
benchmark_winograd("WINOGRAD:X86_F32MK8_8X8:8:6:8", handle(), 3, 8); | benchmark_winograd("WINOGRAD:X86_F32MK8_8X8:8:6:8", handle(), 3, 8); | ||||