GitOrigin-RevId: c74193a23d
release-1.7
@@ -14,6 +14,7 @@ | |||||
#include "src/aarch64/handle.h" | #include "src/aarch64/handle.h" | ||||
#include "src/aarch64/relayout/opr_impl.h" | #include "src/aarch64/relayout/opr_impl.h" | ||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace relayout; | using namespace relayout; | ||||
@@ -131,6 +132,179 @@ void trans_16x16_u8( | |||||
"d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"); | "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"); | ||||
} | } | ||||
struct Transpose4Byte { | |||||
uint32_t v; | |||||
}; | |||||
static inline void trans_8x8_u32( | |||||
const void* src, void* dst, const size_t src_step, const size_t dst_step) { | |||||
uint32_t* src_ptr = (uint32_t*)src; | |||||
uint32_t* dst_ptr = (uint32_t*)dst; | |||||
uint32x4x2_t src0 = vld1q_u32_x2(src_ptr + 0 * src_step); // A0A1A2A3 | |||||
uint32x4x2_t src1 = vld1q_u32_x2(src_ptr + 1 * src_step); // B0B1B2B3 | |||||
uint32x4x2_t src2 = vld1q_u32_x2(src_ptr + 2 * src_step); // C0C1C2C3 | |||||
uint32x4x2_t src3 = vld1q_u32_x2(src_ptr + 3 * src_step); // D0D1D2D3 | |||||
uint32x4x2_t src4 = vld1q_u32_x2(src_ptr + 4 * src_step); // E0E1E2E3 | |||||
uint32x4x2_t src5 = vld1q_u32_x2(src_ptr + 5 * src_step); // F0F1F2F3 | |||||
uint32x4x2_t src6 = vld1q_u32_x2(src_ptr + 6 * src_step); // G0G1G2G3 | |||||
uint32x4x2_t src7 = vld1q_u32_x2(src_ptr + 7 * src_step); // H0H1H2H3 | |||||
uint32x4_t ab_low = vzip1q_u32(src0.val[0], src1.val[0]); // A0B0A1B1 | |||||
uint32x4_t ab_high = vzip2q_u32(src0.val[0], src1.val[0]); // A2B2A3B3 | |||||
uint32x4_t cd_low = vzip1q_u32(src2.val[0], src3.val[0]); // C0D0C1D1 | |||||
uint32x4_t cd_high = vzip2q_u32(src2.val[0], src3.val[0]); // C2D2C3D3 | |||||
uint32x4_t ef_low = vzip1q_u32(src4.val[0], src5.val[0]); // E0F0E1F1 | |||||
uint32x4_t ef_high = vzip2q_u32(src4.val[0], src5.val[0]); // E2F2E3F3 | |||||
uint32x4_t gh_low = vzip1q_u32(src6.val[0], src7.val[0]); // G0H0G1H1 | |||||
uint32x4_t gh_high = vzip2q_u32(src6.val[0], src7.val[0]); // G2H2G3H3 | |||||
uint32x4_t abcd_0 = vreinterpretq_u32_u64(vzip1q_u64( | |||||
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A0B0C0D0 | |||||
uint32x4_t abcd_1 = vreinterpretq_u32_u64(vzip2q_u64( | |||||
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A1B1C1D1 | |||||
uint32x4_t abcd_2 = vreinterpretq_u32_u64(vzip1q_u64( | |||||
vreinterpretq_u64_u32(ab_high), | |||||
vreinterpretq_u64_u32(cd_high))); // A2B2C2D2 | |||||
uint32x4_t abcd_3 = vreinterpretq_u32_u64(vzip2q_u64( | |||||
vreinterpretq_u64_u32(ab_high), | |||||
vreinterpretq_u64_u32(cd_high))); // A3B3C3D3 | |||||
uint32x4_t efgh_0 = vreinterpretq_u32_u64(vzip1q_u64( | |||||
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E0F0G0H0 | |||||
uint32x4_t efgh_1 = vreinterpretq_u32_u64(vzip2q_u64( | |||||
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E1F1G1H1 | |||||
uint32x4_t efgh_2 = vreinterpretq_u32_u64(vzip1q_u64( | |||||
vreinterpretq_u64_u32(ef_high), | |||||
vreinterpretq_u64_u32(gh_high))); // E2F2G2H2 | |||||
uint32x4_t efgh_3 = vreinterpretq_u32_u64(vzip2q_u64( | |||||
vreinterpretq_u64_u32(ef_high), | |||||
vreinterpretq_u64_u32(gh_high))); // E3F3G3H3 | |||||
vst1q_u32(dst_ptr + 0 * dst_step, abcd_0); | |||||
vst1q_u32(dst_ptr + 0 * dst_step + 4, efgh_0); | |||||
vst1q_u32(dst_ptr + 1 * dst_step, abcd_1); | |||||
vst1q_u32(dst_ptr + 1 * dst_step + 4, efgh_1); | |||||
vst1q_u32(dst_ptr + 2 * dst_step, abcd_2); | |||||
vst1q_u32(dst_ptr + 2 * dst_step + 4, efgh_2); | |||||
vst1q_u32(dst_ptr + 3 * dst_step, abcd_3); | |||||
vst1q_u32(dst_ptr + 3 * dst_step + 4, efgh_3); | |||||
ab_low = vzip1q_u32(src0.val[1], src1.val[1]); // A0B0A1B1 | |||||
ab_high = vzip2q_u32(src0.val[1], src1.val[1]); // A2B2A3B3 | |||||
cd_low = vzip1q_u32(src2.val[1], src3.val[1]); // C0D0C1D1 | |||||
cd_high = vzip2q_u32(src2.val[1], src3.val[1]); // C2D2C3D3 | |||||
ef_low = vzip1q_u32(src4.val[1], src5.val[1]); // E0F0E1F1 | |||||
ef_high = vzip2q_u32(src4.val[1], src5.val[1]); // E2F2E3F3 | |||||
gh_low = vzip1q_u32(src6.val[1], src7.val[1]); // G0H0G1H1 | |||||
gh_high = vzip2q_u32(src6.val[1], src7.val[1]); // G2H2G3H3 | |||||
abcd_0 = vreinterpretq_u32_u64(vzip1q_u64( | |||||
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A0B0C0D0 | |||||
abcd_1 = vreinterpretq_u32_u64(vzip2q_u64( | |||||
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A1B1C1D1 | |||||
abcd_2 = vreinterpretq_u32_u64(vzip1q_u64( | |||||
vreinterpretq_u64_u32(ab_high), | |||||
vreinterpretq_u64_u32(cd_high))); // A2B2C2D2 | |||||
abcd_3 = vreinterpretq_u32_u64(vzip2q_u64( | |||||
vreinterpretq_u64_u32(ab_high), | |||||
vreinterpretq_u64_u32(cd_high))); // A3B3C3D3 | |||||
efgh_0 = vreinterpretq_u32_u64(vzip1q_u64( | |||||
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E0F0G0H0 | |||||
efgh_1 = vreinterpretq_u32_u64(vzip2q_u64( | |||||
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E1F1G1H1 | |||||
efgh_2 = vreinterpretq_u32_u64(vzip1q_u64( | |||||
vreinterpretq_u64_u32(ef_high), | |||||
vreinterpretq_u64_u32(gh_high))); // E2F2G2H2 | |||||
efgh_3 = vreinterpretq_u32_u64(vzip2q_u64( | |||||
vreinterpretq_u64_u32(ef_high), | |||||
vreinterpretq_u64_u32(gh_high))); // E3F3G3H3 | |||||
vst1q_u32(dst_ptr + 4 * dst_step, abcd_0); | |||||
vst1q_u32(dst_ptr + 4 * dst_step + 4, efgh_0); | |||||
vst1q_u32(dst_ptr + 5 * dst_step, abcd_1); | |||||
vst1q_u32(dst_ptr + 5 * dst_step + 4, efgh_1); | |||||
vst1q_u32(dst_ptr + 6 * dst_step, abcd_2); | |||||
vst1q_u32(dst_ptr + 6 * dst_step + 4, efgh_2); | |||||
vst1q_u32(dst_ptr + 7 * dst_step, abcd_3); | |||||
vst1q_u32(dst_ptr + 7 * dst_step + 4, efgh_3); | |||||
} | |||||
struct Transpose2Byte { | |||||
uint16_t v; | |||||
}; | |||||
static inline void trans_8x8_u16( | |||||
const void* src, void* dst, const size_t src_step, const size_t dst_step) { | |||||
uint16_t* src_ptr = (uint16_t*)src; | |||||
uint16_t* dst_ptr = (uint16_t*)dst; | |||||
uint16x8_t src0 = vld1q_u16(src_ptr + 0 * src_step); // A0A1A2A3A4A5A6A7 | |||||
uint16x8_t src1 = vld1q_u16(src_ptr + 1 * src_step); // B0B1B2B3B4B5B6B7 | |||||
uint16x8_t src2 = vld1q_u16(src_ptr + 2 * src_step); // C0C1C2C3C4C5C6C7 | |||||
uint16x8_t src3 = vld1q_u16(src_ptr + 3 * src_step); // D0D1D2D3D4D5D6D7 | |||||
uint16x8_t src4 = vld1q_u16(src_ptr + 4 * src_step); // E0E1E2E3E4E5E6E7 | |||||
uint16x8_t src5 = vld1q_u16(src_ptr + 5 * src_step); // F0F1F2F3F4F5F6F7 | |||||
uint16x8_t src6 = vld1q_u16(src_ptr + 6 * src_step); // G0G1G2G3G4G5G6G7 | |||||
uint16x8_t src7 = vld1q_u16(src_ptr + 7 * src_step); // H0H1H2H3H4H5H6H7 | |||||
uint16x8_t ab_low = vzip1q_u16(src0, src1); // A0B0A1B1A2B2A3B3 | |||||
uint16x8_t ab_high = vzip2q_u16(src0, src1); // A4B4A5B5A6B6A7B7 | |||||
uint16x8_t cd_low = vzip1q_u16(src2, src3); // C0D0C1D1C2D2C3D3 | |||||
uint16x8_t cd_high = vzip2q_u16(src2, src3); // C4D4C5D5C6D6C7D7 | |||||
uint16x8_t ef_low = vzip1q_u16(src4, src5); // E0F0E1F1E2F2E3F3 | |||||
uint16x8_t ef_high = vzip2q_u16(src4, src5); // E4F4E5F5E6F6E7F7 | |||||
uint16x8_t gh_low = vzip1q_u16(src6, src7); // G0H0G1H1G2H2G3H3 | |||||
uint16x8_t gh_high = vzip2q_u16(src6, src7); // G4H4G5H5G6H6G7H7 | |||||
uint16x8_t abcd_0 = vreinterpretq_u16_u32(vzip1q_u32( | |||||
vreinterpretq_u32_u16(ab_low), | |||||
vreinterpretq_u32_u16(cd_low))); // A0B0C0D0A1B1C1D1 | |||||
uint16x8_t abcd_2 = vreinterpretq_u16_u32(vzip2q_u32( | |||||
vreinterpretq_u32_u16(ab_low), | |||||
vreinterpretq_u32_u16(cd_low))); // A2B2C2D2A3B3C3D3 | |||||
uint16x8_t abcd_4 = vreinterpretq_u16_u32(vzip1q_u32( | |||||
vreinterpretq_u32_u16(ab_high), | |||||
vreinterpretq_u32_u16(cd_high))); // A4B4C4D4A5B5C5D5 | |||||
uint16x8_t abcd_6 = vreinterpretq_u16_u32(vzip2q_u32( | |||||
vreinterpretq_u32_u16(ab_high), | |||||
vreinterpretq_u32_u16(cd_high))); // A6B6C6D6A7B7C7D7 | |||||
uint16x8_t efgh_0 = vreinterpretq_u16_u32(vzip1q_u32( | |||||
vreinterpretq_u32_u16(ef_low), | |||||
vreinterpretq_u32_u16(gh_low))); // E0F0G0H0E1F1G1H1 | |||||
uint16x8_t efgh_2 = vreinterpretq_u16_u32(vzip2q_u32( | |||||
vreinterpretq_u32_u16(ef_low), | |||||
vreinterpretq_u32_u16(gh_low))); // E2F2G2H2E3F3G3H3 | |||||
uint16x8_t efgh_4 = vreinterpretq_u16_u32(vzip1q_u32( | |||||
vreinterpretq_u32_u16(ef_high), | |||||
vreinterpretq_u32_u16(gh_high))); // E4F4G4H4E5F5G5H5 | |||||
uint16x8_t efgh_6 = vreinterpretq_u16_u32(vzip2q_u32( | |||||
vreinterpretq_u32_u16(ef_high), | |||||
vreinterpretq_u32_u16(gh_high))); // E6F6G6H6E7F7G7H7 | |||||
uint16x8_t row_0 = vreinterpretq_u16_u64( | |||||
vzip1q_u64(vreinterpretq_u64_u16(abcd_0), vreinterpretq_u64_u16(efgh_0))); | |||||
uint16x8_t row_1 = vreinterpretq_u16_u64( | |||||
vzip2q_u64(vreinterpretq_u64_u16(abcd_0), vreinterpretq_u64_u16(efgh_0))); | |||||
uint16x8_t row_2 = vreinterpretq_u16_u64( | |||||
vzip1q_u64(vreinterpretq_u64_u16(abcd_2), vreinterpretq_u64_u16(efgh_2))); | |||||
uint16x8_t row_3 = vreinterpretq_u16_u64( | |||||
vzip2q_u64(vreinterpretq_u64_u16(abcd_2), vreinterpretq_u64_u16(efgh_2))); | |||||
uint16x8_t row_4 = vreinterpretq_u16_u64( | |||||
vzip1q_u64(vreinterpretq_u64_u16(abcd_4), vreinterpretq_u64_u16(efgh_4))); | |||||
uint16x8_t row_5 = vreinterpretq_u16_u64( | |||||
vzip2q_u64(vreinterpretq_u64_u16(abcd_4), vreinterpretq_u64_u16(efgh_4))); | |||||
uint16x8_t row_6 = vreinterpretq_u16_u64( | |||||
vzip1q_u64(vreinterpretq_u64_u16(abcd_6), vreinterpretq_u64_u16(efgh_6))); | |||||
uint16x8_t row_7 = vreinterpretq_u16_u64( | |||||
vzip2q_u64(vreinterpretq_u64_u16(abcd_6), vreinterpretq_u64_u16(efgh_6))); | |||||
vst1q_u16(dst_ptr + 0 * dst_step, row_0); | |||||
vst1q_u16(dst_ptr + 1 * dst_step, row_1); | |||||
vst1q_u16(dst_ptr + 2 * dst_step, row_2); | |||||
vst1q_u16(dst_ptr + 3 * dst_step, row_3); | |||||
vst1q_u16(dst_ptr + 4 * dst_step, row_4); | |||||
vst1q_u16(dst_ptr + 5 * dst_step, row_5); | |||||
vst1q_u16(dst_ptr + 6 * dst_step, row_6); | |||||
vst1q_u16(dst_ptr + 7 * dst_step, row_7); | |||||
} | |||||
} // anonymous namespace | } // anonymous namespace | ||||
namespace megdnn { | namespace megdnn { | ||||
@@ -148,6 +322,30 @@ void transpose_block<TransposeByte>( | |||||
trans_16x16_u8(src, dst, src_stride, dst_stride); | trans_16x16_u8(src, dst, src_stride, dst_stride); | ||||
} | } | ||||
template <> | |||||
struct transpose_traits<Transpose4Byte> { | |||||
static constexpr size_t block_size = 8; | |||||
}; | |||||
template <> | |||||
void transpose_block<Transpose4Byte>( | |||||
const Transpose4Byte* src, Transpose4Byte* dst, const size_t src_stride, | |||||
const size_t dst_stride) { | |||||
trans_8x8_u32(src, dst, src_stride, dst_stride); | |||||
} | |||||
template <> | |||||
struct transpose_traits<Transpose2Byte> { | |||||
static constexpr size_t block_size = 8; | |||||
}; | |||||
template <> | |||||
void transpose_block<Transpose2Byte>( | |||||
const Transpose2Byte* src, Transpose2Byte* dst, const size_t src_stride, | |||||
const size_t dst_stride) { | |||||
trans_8x8_u16(src, dst, src_stride, dst_stride); | |||||
} | |||||
} // namespace transpose_fallback | } // namespace transpose_fallback | ||||
} // namespace relayout | } // namespace relayout | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -164,16 +362,33 @@ void aarch64::RelayoutForwardImpl::exec( | |||||
fallback::RelayoutForwardImpl::exec(src0, dst0, src_handle); | fallback::RelayoutForwardImpl::exec(src0, dst0, src_handle); | ||||
return; | return; | ||||
} | } | ||||
relayout::TransposeParam trans_param; | relayout::TransposeParam trans_param; | ||||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | |||||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true); | |||||
if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { | if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { | ||||
auto sptr = static_cast<TransposeByte*>(src.raw_ptr), | auto sptr = static_cast<TransposeByte*>(src.raw_ptr), | ||||
dptr = static_cast<TransposeByte*>(dst.raw_ptr); | dptr = static_cast<TransposeByte*>(dst.raw_ptr); | ||||
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>( | MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>( | ||||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr)); | |||||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, | |||||
trans_param.stride_m)); | |||||
return; | |||||
} else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 2) { | |||||
auto sptr = static_cast<Transpose2Byte*>(src.raw_ptr), | |||||
dptr = static_cast<Transpose2Byte*>(dst.raw_ptr); | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose2Byte>( | |||||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, | |||||
trans_param.stride_m)); | |||||
return; | |||||
} else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 4) { | |||||
auto sptr = static_cast<Transpose4Byte*>(src.raw_ptr), | |||||
dptr = static_cast<Transpose4Byte*>(dst.raw_ptr); | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose4Byte>( | |||||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, | |||||
trans_param.stride_m)); | |||||
return; | return; | ||||
} | } | ||||
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | ||||
} | } | ||||
@@ -321,6 +321,12 @@ __ai void vst1q_f32_x2(const float* p, float32x4x2_t v) { | |||||
} | } | ||||
#endif | #endif | ||||
#if !defined(vld1q_u32_x2) && (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 3)) | |||||
__ai uint32x4x2_t vld1q_u32_x2(const uint32_t* p) { | |||||
return {{vld1q_u32(p), vld1q_u32(p + 4)}}; | |||||
} | |||||
#endif | |||||
__ai int8x16_t vtranslq_s8(int8x8_t a) { | __ai int8x16_t vtranslq_s8(int8x8_t a) { | ||||
int8x16_t ret; | int8x16_t ret; | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -23,7 +23,8 @@ namespace { | |||||
//! whether current shape is [b][n][m][c] and is a transpose of contig | //! whether current shape is [b][n][m][c] and is a transpose of contig | ||||
//! [b][m][n][c] | //! [b][m][n][c] | ||||
bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { | |||||
bool is_transpose_single( | |||||
const TensorLayout& layout, TransposeParam& p, bool allow_no_contig) { | |||||
/* | /* | ||||
* assuming contig layout is: | * assuming contig layout is: | ||||
* shape: b, m, n, c | * shape: b, m, n, c | ||||
@@ -42,8 +43,9 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { | |||||
* | * | ||||
* if b == 1 && c == 1: | * if b == 1 && c == 1: | ||||
* shape: n, m | * shape: n, m | ||||
* stride: 1, n | |||||
* stride: 1, n(stride_m for no-contig) | |||||
*/ | */ | ||||
p.stride_m = 0; | |||||
auto strd = [&](size_t idx, ptrdiff_t v) { return layout.stride[idx] == v; }; | auto strd = [&](size_t idx, ptrdiff_t v) { return layout.stride[idx] == v; }; | ||||
if (layout.ndim == 4) { | if (layout.ndim == 4) { | ||||
p.batch = layout[0]; | p.batch = layout[0]; | ||||
@@ -80,7 +82,15 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { | |||||
p.n = layout.shape[0]; | p.n = layout.shape[0]; | ||||
p.m = layout.shape[1]; | p.m = layout.shape[1]; | ||||
p.c = 1; | p.c = 1; | ||||
return strd(0, 1) && strd(1, p.n); | |||||
if (strd(0, 1) && strd(1, p.n)) { | |||||
return true; | |||||
} else if ( | |||||
strd(0, 1) && layout.stride[1] > 0 && | |||||
(size_t)(layout.stride[1]) >= p.n && allow_no_contig) { | |||||
//! stride_m used in no-contig mode, stride_m >= p.n | |||||
p.stride_m = layout.stride[1]; | |||||
return true; | |||||
} | |||||
} | } | ||||
return false; | return false; | ||||
} | } | ||||
@@ -98,15 +108,16 @@ void RelayoutForward::check_layout_and_canonize(TensorLayout& src, TensorLayout& | |||||
} | } | ||||
bool relayout::is_transpose( | bool relayout::is_transpose( | ||||
const TensorLayout& src, const TensorLayout& dst, TransposeParam& p) { | |||||
if (is_contig(dst) && is_transpose_single(src, p)) { | |||||
const TensorLayout& src, const TensorLayout& dst, TransposeParam& p, | |||||
bool allow_non_contig) { | |||||
if (is_contig(dst) && is_transpose_single(src, p, allow_non_contig)) { | |||||
// if the original intention is to transpose (m, n) to (n, m), | // if the original intention is to transpose (m, n) to (n, m), | ||||
// then we should use (n, m) as the contig dst and use a corrsponding | // then we should use (n, m) as the contig dst and use a corrsponding | ||||
// non-contig src with the same (n, m) shape (remember relayout is | // non-contig src with the same (n, m) shape (remember relayout is | ||||
// defined on element correspondence on the logical view) | // defined on element correspondence on the logical view) | ||||
return true; | return true; | ||||
} | } | ||||
if (is_contig(src) && is_transpose_single(dst, p)) { | |||||
if (is_contig(src) && is_transpose_single(dst, p, allow_non_contig)) { | |||||
std::swap(p.m, p.n); | std::swap(p.m, p.n); | ||||
return true; | return true; | ||||
} | } | ||||
@@ -27,7 +27,7 @@ static inline bool is_contig(const TensorLayout& layout) { | |||||
//! [b][m][n][c] to [b][n][m][c] | //! [b][m][n][c] to [b][n][m][c] | ||||
struct TransposeParam { | struct TransposeParam { | ||||
size_t batch, m, n, c; | |||||
size_t batch, m, n, c, stride_m; | |||||
}; | }; | ||||
/** | /** | ||||
@@ -36,7 +36,9 @@ struct TransposeParam { | |||||
* Note that \p src and \p dst should have been processed by | * Note that \p src and \p dst should have been processed by | ||||
* RelayoutForward::check_layout_and_canonize | * RelayoutForward::check_layout_and_canonize | ||||
*/ | */ | ||||
bool is_transpose(const TensorLayout& src, const TensorLayout& dst, TransposeParam& p); | |||||
bool is_transpose( | |||||
const TensorLayout& src, const TensorLayout& dst, TransposeParam& p, | |||||
bool allow_non_contig = false); | |||||
namespace transpose_fallback { | namespace transpose_fallback { | ||||
@@ -105,20 +107,23 @@ void transpose_block( | |||||
* \brief transpose contiguous (batch, m, n) to (batch, n, m) | * \brief transpose contiguous (batch, m, n) to (batch, n, m) | ||||
*/ | */ | ||||
template <typename T> | template <typename T> | ||||
void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) { | |||||
void transpose(size_t batch, size_t m, size_t n, T* src, T* dst, size_t stride_m = 0) { | |||||
if (stride_m == 0) { | |||||
stride_m = n; | |||||
} | |||||
auto batch_src = src; | auto batch_src = src; | ||||
auto batch_dst = dst; | auto batch_dst = dst; | ||||
constexpr size_t B = transpose_traits<T>::block_size; | constexpr size_t B = transpose_traits<T>::block_size; | ||||
auto work_block = [m, n, &batch_src, &batch_dst]( | |||||
auto work_block = [m, stride_m, &batch_src, &batch_dst]( | |||||
const size_t i, const size_t j, const size_t h, | const size_t i, const size_t j, const size_t h, | ||||
const size_t w) { | const size_t w) { | ||||
auto src = batch_src + i * n + j, dst = batch_dst + j * m + i; | |||||
auto src = batch_src + i * stride_m + j, dst = batch_dst + j * m + i; | |||||
MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) { | MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) { | ||||
if (h == B && w == B) { | if (h == B && w == B) { | ||||
transpose_block(src, dst, n, m); | |||||
transpose_block(src, dst, stride_m, m); | |||||
} else { | } else { | ||||
transpose_block(src, dst, n, m, h, w); | |||||
transpose_block(src, dst, stride_m, m, h, w); | |||||
} | } | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -141,7 +146,7 @@ void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) { | |||||
if (i < m) { | if (i < m) { | ||||
work_row(i, m - i); | work_row(i, m - i); | ||||
} | } | ||||
batch_src += m * n; | |||||
batch_src += m * stride_m; | |||||
batch_dst += m * n; | batch_dst += m * n; | ||||
} | } | ||||
} | } | ||||
@@ -48,10 +48,12 @@ void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) { | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
void call_transpose(size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst) { | |||||
void call_transpose( | |||||
size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst, | |||||
size_t stride_m) { | |||||
megdnn_assert(ch == 1); | megdnn_assert(ch == 1); | ||||
relayout::transpose_fallback::transpose<T>( | relayout::transpose_fallback::transpose<T>( | ||||
batch, m, n, static_cast<T*>(src), static_cast<T*>(dst)); | |||||
batch, m, n, static_cast<T*>(src), static_cast<T*>(dst), stride_m); | |||||
} | } | ||||
//! one operand contiguous, and the other non-contiguous | //! one operand contiguous, and the other non-contiguous | ||||
@@ -186,7 +188,10 @@ void transpose_cv_row( | |||||
} | } | ||||
template <typename ctype> | template <typename ctype> | ||||
void transpose_cv(size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst) { | |||||
void transpose_cv( | |||||
size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst, | |||||
size_t stride_m) { | |||||
megdnn_assert(stride_m == 0); | |||||
constexpr size_t B = BLOCK_SIZE; | constexpr size_t B = BLOCK_SIZE; | ||||
auto batch_src = static_cast<ctype*>(src); | auto batch_src = static_cast<ctype*>(src); | ||||
auto batch_dst = static_cast<ctype*>(dst); | auto batch_dst = static_cast<ctype*>(dst); | ||||
@@ -237,7 +242,7 @@ void RelayoutForwardImpl::exec( | |||||
} | } | ||||
relayout::TransposeParam trans_param; | relayout::TransposeParam trans_param; | ||||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | |||||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true); | |||||
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | ||||
} | } | ||||
@@ -245,7 +250,7 @@ void RelayoutForwardImpl::exec_after_preprocess( | |||||
const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) { | const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) { | ||||
if (transpose) { | if (transpose) { | ||||
auto dsize = src.layout.dtype.size() * transpose->c; | auto dsize = src.layout.dtype.size() * transpose->c; | ||||
void (*kptr)(size_t, size_t, size_t, size_t, void*, void*) = nullptr; | |||||
void (*kptr)(size_t, size_t, size_t, size_t, void*, void*, size_t) = nullptr; | |||||
auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr), | auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr), | ||||
dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr); | dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr); | ||||
if (dsize == 1) { | if (dsize == 1) { | ||||
@@ -293,7 +298,9 @@ void RelayoutForwardImpl::exec_after_preprocess( | |||||
if (kptr) { | if (kptr) { | ||||
auto kern = [t = *transpose, sptr = src.raw_ptr, dptr = dst.raw_ptr, | auto kern = [t = *transpose, sptr = src.raw_ptr, dptr = dst.raw_ptr, | ||||
kptr]() { kptr(t.batch, t.m, t.n, t.c, sptr, dptr); }; | |||||
kptr]() { | |||||
kptr(t.batch, t.m, t.n, t.c, sptr, dptr, t.stride_m); | |||||
}; | |||||
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern); | static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern); | ||||
return; | return; | ||||
} else { | } else { | ||||
@@ -0,0 +1,29 @@ | |||||
/** | |||||
* \file dnn/test/aarch64/fixture.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "test/aarch64/fixture.h" | |||||
#include "test/common/memory_manager.h" | |||||
#include "test/common/random_state.h" | |||||
#include "test/common/utils.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
Handle* AARCH64::fallback_handle() { | |||||
if (!m_fallback_handle) { | |||||
m_fallback_handle = create_cpu_handle(1); | |||||
} | |||||
return m_fallback_handle.get(); | |||||
} | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -19,7 +19,13 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace test { | namespace test { | ||||
class AARCH64 : public ARM_COMMON {}; | |||||
class AARCH64 : public ARM_COMMON { | |||||
public: | |||||
Handle* fallback_handle(); | |||||
private: | |||||
std::unique_ptr<Handle> m_handle, m_fallback_handle; | |||||
}; | |||||
class AARCH64_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {}; | class AARCH64_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {}; | ||||
@@ -0,0 +1,118 @@ | |||||
/** | |||||
* \file dnn/test/aarch64/relayout.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "test/aarch64/fixture.h" | |||||
#include "test/common/benchmarker.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/common/relayout.h" | |||||
#include "test/common/rng.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
namespace { | |||||
template <typename tag> | |||||
class AARCH64_RELAYOUT : public AARCH64 {}; | |||||
TYPED_TEST_CASE(AARCH64_RELAYOUT, relayout::test_types); | |||||
TYPED_TEST(AARCH64_RELAYOUT, run) { | |||||
relayout::run_test<TypeParam>(this->handle()); | |||||
} | |||||
} // namespace | |||||
TEST_F(AARCH64, Relayout) { | |||||
Checker<Relayout> checker(handle()); | |||||
std::vector<::megdnn::DType> dtype_vec; | |||||
dtype_vec.push_back(dtype::Float32()); | |||||
dtype_vec.push_back(dtype::Int16()); | |||||
dtype_vec.push_back(dtype::Uint16()); | |||||
dtype_vec.push_back(dtype::Int8()); | |||||
for (auto dtype : dtype_vec) { | |||||
TensorLayout src({1, 54, 112, 256}, {54, 1, 16384, 64}, dtype); | |||||
TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); | |||||
checker.execl({src, dst}); | |||||
} | |||||
} | |||||
TEST_F(AARCH64, RelayoutBig) { | |||||
Checker<Relayout> checker(handle()); | |||||
ConsecutiveRNG rng; | |||||
checker.set_rng(0, &rng); | |||||
int m = 512; | |||||
int n = 512; | |||||
TensorLayout src({(size_t)m, (size_t)n}, {1, n}, dtype::Float32()); | |||||
TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype::Float32()); | |||||
checker.execl({src, dst}); | |||||
} | |||||
#if MEGDNN_WITH_BENCHMARK | |||||
TEST_F(AARCH64, BENCHMARK_Relayout) { | |||||
constexpr size_t WARM_RUNS = 100; | |||||
constexpr size_t RUNS = 600; | |||||
auto dtype = dtype::Float32(); | |||||
Benchmarker<Relayout> benchmarker_relayout(handle()); | |||||
Benchmarker<Relayout> benchmarker_fbk_relayout(fallback_handle()); | |||||
benchmarker_relayout.set_times(WARM_RUNS); | |||||
benchmarker_fbk_relayout.set_times(WARM_RUNS); | |||||
int m = 512; | |||||
int n = 512; | |||||
TensorLayout src({(size_t)m, (size_t)n}, {1, n}, dtype); | |||||
TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype); | |||||
TensorLayoutArray tensor_case; | |||||
tensor_case.push_back(src); | |||||
tensor_case.push_back(dst); | |||||
benchmarker_relayout.exec(tensor_case); | |||||
benchmarker_fbk_relayout.exec(tensor_case); | |||||
benchmarker_relayout.set_times(RUNS); | |||||
benchmarker_fbk_relayout.set_times(RUNS); | |||||
auto used = benchmarker_relayout.exec(tensor_case) / RUNS; | |||||
auto fbk_used = benchmarker_fbk_relayout.exec(tensor_case) / RUNS; | |||||
float bw = 2.f * m * n * 1e-6 / used * dtype.size(); | |||||
float fbk_bw = 2.f * m * n * 1e-6 / fbk_used * dtype.size(); | |||||
printf("run: %s -> %s , %f GB/s, fbk %f GB/s, speedup %f\n", | |||||
src.to_string().c_str(), dst.to_string().c_str(), bw, fbk_bw, bw / fbk_bw); | |||||
} | |||||
TEST_F(AARCH64, BENCHMARK_Relayout_2) { | |||||
constexpr size_t WARM_RUNS = 100; | |||||
constexpr size_t RUNS = 600; | |||||
auto dtype = dtype::Float32(); | |||||
Benchmarker<Relayout> benchmarker_relayout(handle()); | |||||
Benchmarker<Relayout> benchmarker_fbk_relayout(fallback_handle()); | |||||
benchmarker_relayout.set_times(WARM_RUNS); | |||||
benchmarker_fbk_relayout.set_times(WARM_RUNS); | |||||
int m = 54; | |||||
int n = 28762; | |||||
TensorLayout src({1, 54, 112, 256}, {54, 1, 16384, 64}, dtype); | |||||
TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); | |||||
TensorLayoutArray tensor_case; | |||||
tensor_case.push_back(src); | |||||
tensor_case.push_back(dst); | |||||
benchmarker_relayout.exec(tensor_case); | |||||
benchmarker_fbk_relayout.exec(tensor_case); | |||||
benchmarker_relayout.set_times(RUNS); | |||||
benchmarker_fbk_relayout.set_times(RUNS); | |||||
auto used = benchmarker_relayout.exec(tensor_case) / RUNS; | |||||
auto fbk_used = benchmarker_fbk_relayout.exec(tensor_case) / RUNS; | |||||
float bw = 2.f * m * n * 1e-6 / used * dtype.size(); | |||||
float fbk_bw = 2.f * m * n * 1e-6 / fbk_used * dtype.size(); | |||||
printf("run: %s -> %s , %f GB/s, fbk %f GB/s, speedup %f\n", | |||||
src.to_string().c_str(), dst.to_string().c_str(), bw, fbk_bw, bw / fbk_bw); | |||||
} | |||||
#endif | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -180,11 +180,11 @@ TEST(RELAYOUT, TRANSPOSE_DET) { | |||||
ASSERT_EQ(p_get.c, p.c); | ASSERT_EQ(p_get.c, p.c); | ||||
} | } | ||||
}; | }; | ||||
run({2, 3}, {1, 0}, true, {1, 2, 3, 1}); | |||||
run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5}); | |||||
run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1}); | |||||
run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5}); | |||||
run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1}); | |||||
run({2, 3}, {1, 0}, true, {1, 2, 3, 1, 0}); | |||||
run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5, 0}); | |||||
run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1, 0}); | |||||
run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5, 0}); | |||||
run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1, 0}); | |||||
run({2, 3, 5}, {2, 1, 0}, false); | run({2, 3, 5}, {2, 1, 0}, false); | ||||
run({3, 2, 3, 5}, {3, 2, 1, 0}, false); | run({3, 2, 3, 5}, {3, 2, 1, 0}, false); | ||||
} | } | ||||