GitOrigin-RevId: c74193a23d
release-1.7
@@ -14,6 +14,7 @@ | |||
#include "src/aarch64/handle.h" | |||
#include "src/aarch64/relayout/opr_impl.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
using namespace megdnn; | |||
using namespace relayout; | |||
@@ -131,6 +132,179 @@ void trans_16x16_u8( | |||
"d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"); | |||
} | |||
struct Transpose4Byte { | |||
uint32_t v; | |||
}; | |||
static inline void trans_8x8_u32( | |||
const void* src, void* dst, const size_t src_step, const size_t dst_step) { | |||
uint32_t* src_ptr = (uint32_t*)src; | |||
uint32_t* dst_ptr = (uint32_t*)dst; | |||
uint32x4x2_t src0 = vld1q_u32_x2(src_ptr + 0 * src_step); // A0A1A2A3 | |||
uint32x4x2_t src1 = vld1q_u32_x2(src_ptr + 1 * src_step); // B0B1B2B3 | |||
uint32x4x2_t src2 = vld1q_u32_x2(src_ptr + 2 * src_step); // C0C1C2C3 | |||
uint32x4x2_t src3 = vld1q_u32_x2(src_ptr + 3 * src_step); // D0D1D2D3 | |||
uint32x4x2_t src4 = vld1q_u32_x2(src_ptr + 4 * src_step); // E0E1E2E3 | |||
uint32x4x2_t src5 = vld1q_u32_x2(src_ptr + 5 * src_step); // F0F1F2F3 | |||
uint32x4x2_t src6 = vld1q_u32_x2(src_ptr + 6 * src_step); // G0G1G2G3 | |||
uint32x4x2_t src7 = vld1q_u32_x2(src_ptr + 7 * src_step); // H0H1H2H3 | |||
uint32x4_t ab_low = vzip1q_u32(src0.val[0], src1.val[0]); // A0B0A1B1 | |||
uint32x4_t ab_high = vzip2q_u32(src0.val[0], src1.val[0]); // A2B2A3B3 | |||
uint32x4_t cd_low = vzip1q_u32(src2.val[0], src3.val[0]); // C0D0C1D1 | |||
uint32x4_t cd_high = vzip2q_u32(src2.val[0], src3.val[0]); // C2D2C3D3 | |||
uint32x4_t ef_low = vzip1q_u32(src4.val[0], src5.val[0]); // E0F0E1F1 | |||
uint32x4_t ef_high = vzip2q_u32(src4.val[0], src5.val[0]); // E2F2E3F3 | |||
uint32x4_t gh_low = vzip1q_u32(src6.val[0], src7.val[0]); // G0H0G1H1 | |||
uint32x4_t gh_high = vzip2q_u32(src6.val[0], src7.val[0]); // G2H2G3H3 | |||
uint32x4_t abcd_0 = vreinterpretq_u32_u64(vzip1q_u64( | |||
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A0B0C0D0 | |||
uint32x4_t abcd_1 = vreinterpretq_u32_u64(vzip2q_u64( | |||
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A1B1C1D1 | |||
uint32x4_t abcd_2 = vreinterpretq_u32_u64(vzip1q_u64( | |||
vreinterpretq_u64_u32(ab_high), | |||
vreinterpretq_u64_u32(cd_high))); // A2B2C2D2 | |||
uint32x4_t abcd_3 = vreinterpretq_u32_u64(vzip2q_u64( | |||
vreinterpretq_u64_u32(ab_high), | |||
vreinterpretq_u64_u32(cd_high))); // A3B3C3D3 | |||
uint32x4_t efgh_0 = vreinterpretq_u32_u64(vzip1q_u64( | |||
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E0F0G0H0 | |||
uint32x4_t efgh_1 = vreinterpretq_u32_u64(vzip2q_u64( | |||
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E1F1G1H1 | |||
uint32x4_t efgh_2 = vreinterpretq_u32_u64(vzip1q_u64( | |||
vreinterpretq_u64_u32(ef_high), | |||
vreinterpretq_u64_u32(gh_high))); // E2F2G2H2 | |||
uint32x4_t efgh_3 = vreinterpretq_u32_u64(vzip2q_u64( | |||
vreinterpretq_u64_u32(ef_high), | |||
vreinterpretq_u64_u32(gh_high))); // E3F3G3H3 | |||
vst1q_u32(dst_ptr + 0 * dst_step, abcd_0); | |||
vst1q_u32(dst_ptr + 0 * dst_step + 4, efgh_0); | |||
vst1q_u32(dst_ptr + 1 * dst_step, abcd_1); | |||
vst1q_u32(dst_ptr + 1 * dst_step + 4, efgh_1); | |||
vst1q_u32(dst_ptr + 2 * dst_step, abcd_2); | |||
vst1q_u32(dst_ptr + 2 * dst_step + 4, efgh_2); | |||
vst1q_u32(dst_ptr + 3 * dst_step, abcd_3); | |||
vst1q_u32(dst_ptr + 3 * dst_step + 4, efgh_3); | |||
ab_low = vzip1q_u32(src0.val[1], src1.val[1]); // A0B0A1B1 | |||
ab_high = vzip2q_u32(src0.val[1], src1.val[1]); // A2B2A3B3 | |||
cd_low = vzip1q_u32(src2.val[1], src3.val[1]); // C0D0C1D1 | |||
cd_high = vzip2q_u32(src2.val[1], src3.val[1]); // C2D2C3D3 | |||
ef_low = vzip1q_u32(src4.val[1], src5.val[1]); // E0F0E1F1 | |||
ef_high = vzip2q_u32(src4.val[1], src5.val[1]); // E2F2E3F3 | |||
gh_low = vzip1q_u32(src6.val[1], src7.val[1]); // G0H0G1H1 | |||
gh_high = vzip2q_u32(src6.val[1], src7.val[1]); // G2H2G3H3 | |||
abcd_0 = vreinterpretq_u32_u64(vzip1q_u64( | |||
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A0B0C0D0 | |||
abcd_1 = vreinterpretq_u32_u64(vzip2q_u64( | |||
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A1B1C1D1 | |||
abcd_2 = vreinterpretq_u32_u64(vzip1q_u64( | |||
vreinterpretq_u64_u32(ab_high), | |||
vreinterpretq_u64_u32(cd_high))); // A2B2C2D2 | |||
abcd_3 = vreinterpretq_u32_u64(vzip2q_u64( | |||
vreinterpretq_u64_u32(ab_high), | |||
vreinterpretq_u64_u32(cd_high))); // A3B3C3D3 | |||
efgh_0 = vreinterpretq_u32_u64(vzip1q_u64( | |||
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E0F0G0H0 | |||
efgh_1 = vreinterpretq_u32_u64(vzip2q_u64( | |||
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E1F1G1H1 | |||
efgh_2 = vreinterpretq_u32_u64(vzip1q_u64( | |||
vreinterpretq_u64_u32(ef_high), | |||
vreinterpretq_u64_u32(gh_high))); // E2F2G2H2 | |||
efgh_3 = vreinterpretq_u32_u64(vzip2q_u64( | |||
vreinterpretq_u64_u32(ef_high), | |||
vreinterpretq_u64_u32(gh_high))); // E3F3G3H3 | |||
vst1q_u32(dst_ptr + 4 * dst_step, abcd_0); | |||
vst1q_u32(dst_ptr + 4 * dst_step + 4, efgh_0); | |||
vst1q_u32(dst_ptr + 5 * dst_step, abcd_1); | |||
vst1q_u32(dst_ptr + 5 * dst_step + 4, efgh_1); | |||
vst1q_u32(dst_ptr + 6 * dst_step, abcd_2); | |||
vst1q_u32(dst_ptr + 6 * dst_step + 4, efgh_2); | |||
vst1q_u32(dst_ptr + 7 * dst_step, abcd_3); | |||
vst1q_u32(dst_ptr + 7 * dst_step + 4, efgh_3); | |||
} | |||
struct Transpose2Byte { | |||
uint16_t v; | |||
}; | |||
static inline void trans_8x8_u16( | |||
const void* src, void* dst, const size_t src_step, const size_t dst_step) { | |||
uint16_t* src_ptr = (uint16_t*)src; | |||
uint16_t* dst_ptr = (uint16_t*)dst; | |||
uint16x8_t src0 = vld1q_u16(src_ptr + 0 * src_step); // A0A1A2A3A4A5A6A7 | |||
uint16x8_t src1 = vld1q_u16(src_ptr + 1 * src_step); // B0B1B2B3B4B5B6B7 | |||
uint16x8_t src2 = vld1q_u16(src_ptr + 2 * src_step); // C0C1C2C3C4C5C6C7 | |||
uint16x8_t src3 = vld1q_u16(src_ptr + 3 * src_step); // D0D1D2D3D4D5D6D7 | |||
uint16x8_t src4 = vld1q_u16(src_ptr + 4 * src_step); // E0E1E2E3E4E5E6E7 | |||
uint16x8_t src5 = vld1q_u16(src_ptr + 5 * src_step); // F0F1F2F3F4F5F6F7 | |||
uint16x8_t src6 = vld1q_u16(src_ptr + 6 * src_step); // G0G1G2G3G4G5G6G7 | |||
uint16x8_t src7 = vld1q_u16(src_ptr + 7 * src_step); // H0H1H2H3H4H5H6H7 | |||
uint16x8_t ab_low = vzip1q_u16(src0, src1); // A0B0A1B1A2B2A3B3 | |||
uint16x8_t ab_high = vzip2q_u16(src0, src1); // A4B4A5B5A6B6A7B7 | |||
uint16x8_t cd_low = vzip1q_u16(src2, src3); // C0D0C1D1C2D2C3D3 | |||
uint16x8_t cd_high = vzip2q_u16(src2, src3); // C4D4C5D5C6D6C7D7 | |||
uint16x8_t ef_low = vzip1q_u16(src4, src5); // E0F0E1F1E2F2E3F3 | |||
uint16x8_t ef_high = vzip2q_u16(src4, src5); // E4F4E5F5E6F6E7F7 | |||
uint16x8_t gh_low = vzip1q_u16(src6, src7); // G0H0G1H1G2H2G3H3 | |||
uint16x8_t gh_high = vzip2q_u16(src6, src7); // G4H4G5H5G6H6G7H7 | |||
uint16x8_t abcd_0 = vreinterpretq_u16_u32(vzip1q_u32( | |||
vreinterpretq_u32_u16(ab_low), | |||
vreinterpretq_u32_u16(cd_low))); // A0B0C0D0A1B1C1D1 | |||
uint16x8_t abcd_2 = vreinterpretq_u16_u32(vzip2q_u32( | |||
vreinterpretq_u32_u16(ab_low), | |||
vreinterpretq_u32_u16(cd_low))); // A2B2C2D2A3B3C3D3 | |||
uint16x8_t abcd_4 = vreinterpretq_u16_u32(vzip1q_u32( | |||
vreinterpretq_u32_u16(ab_high), | |||
vreinterpretq_u32_u16(cd_high))); // A4B4C4D4A5B5C5D5 | |||
uint16x8_t abcd_6 = vreinterpretq_u16_u32(vzip2q_u32( | |||
vreinterpretq_u32_u16(ab_high), | |||
vreinterpretq_u32_u16(cd_high))); // A6B6C6D6A7B7C7D7 | |||
uint16x8_t efgh_0 = vreinterpretq_u16_u32(vzip1q_u32( | |||
vreinterpretq_u32_u16(ef_low), | |||
vreinterpretq_u32_u16(gh_low))); // E0F0G0H0E1F1G1H1 | |||
uint16x8_t efgh_2 = vreinterpretq_u16_u32(vzip2q_u32( | |||
vreinterpretq_u32_u16(ef_low), | |||
vreinterpretq_u32_u16(gh_low))); // E2F2G2H2E3F3G3H3 | |||
uint16x8_t efgh_4 = vreinterpretq_u16_u32(vzip1q_u32( | |||
vreinterpretq_u32_u16(ef_high), | |||
vreinterpretq_u32_u16(gh_high))); // E4F4G4H4E5F5G5H5 | |||
uint16x8_t efgh_6 = vreinterpretq_u16_u32(vzip2q_u32( | |||
vreinterpretq_u32_u16(ef_high), | |||
vreinterpretq_u32_u16(gh_high))); // E6F6G6H6E7F7G7H7 | |||
uint16x8_t row_0 = vreinterpretq_u16_u64( | |||
vzip1q_u64(vreinterpretq_u64_u16(abcd_0), vreinterpretq_u64_u16(efgh_0))); | |||
uint16x8_t row_1 = vreinterpretq_u16_u64( | |||
vzip2q_u64(vreinterpretq_u64_u16(abcd_0), vreinterpretq_u64_u16(efgh_0))); | |||
uint16x8_t row_2 = vreinterpretq_u16_u64( | |||
vzip1q_u64(vreinterpretq_u64_u16(abcd_2), vreinterpretq_u64_u16(efgh_2))); | |||
uint16x8_t row_3 = vreinterpretq_u16_u64( | |||
vzip2q_u64(vreinterpretq_u64_u16(abcd_2), vreinterpretq_u64_u16(efgh_2))); | |||
uint16x8_t row_4 = vreinterpretq_u16_u64( | |||
vzip1q_u64(vreinterpretq_u64_u16(abcd_4), vreinterpretq_u64_u16(efgh_4))); | |||
uint16x8_t row_5 = vreinterpretq_u16_u64( | |||
vzip2q_u64(vreinterpretq_u64_u16(abcd_4), vreinterpretq_u64_u16(efgh_4))); | |||
uint16x8_t row_6 = vreinterpretq_u16_u64( | |||
vzip1q_u64(vreinterpretq_u64_u16(abcd_6), vreinterpretq_u64_u16(efgh_6))); | |||
uint16x8_t row_7 = vreinterpretq_u16_u64( | |||
vzip2q_u64(vreinterpretq_u64_u16(abcd_6), vreinterpretq_u64_u16(efgh_6))); | |||
vst1q_u16(dst_ptr + 0 * dst_step, row_0); | |||
vst1q_u16(dst_ptr + 1 * dst_step, row_1); | |||
vst1q_u16(dst_ptr + 2 * dst_step, row_2); | |||
vst1q_u16(dst_ptr + 3 * dst_step, row_3); | |||
vst1q_u16(dst_ptr + 4 * dst_step, row_4); | |||
vst1q_u16(dst_ptr + 5 * dst_step, row_5); | |||
vst1q_u16(dst_ptr + 6 * dst_step, row_6); | |||
vst1q_u16(dst_ptr + 7 * dst_step, row_7); | |||
} | |||
} // anonymous namespace | |||
namespace megdnn { | |||
@@ -148,6 +322,30 @@ void transpose_block<TransposeByte>( | |||
trans_16x16_u8(src, dst, src_stride, dst_stride); | |||
} | |||
template <> | |||
struct transpose_traits<Transpose4Byte> { | |||
static constexpr size_t block_size = 8; | |||
}; | |||
template <> | |||
void transpose_block<Transpose4Byte>( | |||
const Transpose4Byte* src, Transpose4Byte* dst, const size_t src_stride, | |||
const size_t dst_stride) { | |||
trans_8x8_u32(src, dst, src_stride, dst_stride); | |||
} | |||
template <> | |||
struct transpose_traits<Transpose2Byte> { | |||
static constexpr size_t block_size = 8; | |||
}; | |||
template <> | |||
void transpose_block<Transpose2Byte>( | |||
const Transpose2Byte* src, Transpose2Byte* dst, const size_t src_stride, | |||
const size_t dst_stride) { | |||
trans_8x8_u16(src, dst, src_stride, dst_stride); | |||
} | |||
} // namespace transpose_fallback | |||
} // namespace relayout | |||
} // namespace megdnn | |||
@@ -164,16 +362,33 @@ void aarch64::RelayoutForwardImpl::exec( | |||
fallback::RelayoutForwardImpl::exec(src0, dst0, src_handle); | |||
return; | |||
} | |||
relayout::TransposeParam trans_param; | |||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | |||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true); | |||
if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { | |||
auto sptr = static_cast<TransposeByte*>(src.raw_ptr), | |||
dptr = static_cast<TransposeByte*>(dst.raw_ptr); | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>( | |||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr)); | |||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, | |||
trans_param.stride_m)); | |||
return; | |||
} else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 2) { | |||
auto sptr = static_cast<Transpose2Byte*>(src.raw_ptr), | |||
dptr = static_cast<Transpose2Byte*>(dst.raw_ptr); | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose2Byte>( | |||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, | |||
trans_param.stride_m)); | |||
return; | |||
} else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 4) { | |||
auto sptr = static_cast<Transpose4Byte*>(src.raw_ptr), | |||
dptr = static_cast<Transpose4Byte*>(dst.raw_ptr); | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose4Byte>( | |||
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, | |||
trans_param.stride_m)); | |||
return; | |||
} | |||
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | |||
} | |||
@@ -321,6 +321,12 @@ __ai void vst1q_f32_x2(const float* p, float32x4x2_t v) { | |||
} | |||
#endif | |||
#if !defined(vld1q_u32_x2) && (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 3)) | |||
__ai uint32x4x2_t vld1q_u32_x2(const uint32_t* p) { | |||
return {{vld1q_u32(p), vld1q_u32(p + 4)}}; | |||
} | |||
#endif | |||
__ai int8x16_t vtranslq_s8(int8x8_t a) { | |||
int8x16_t ret; | |||
#if MEGDNN_AARCH64 | |||
@@ -23,7 +23,8 @@ namespace { | |||
//! whether current shape is [b][n][m][c] and is a transpose of contig | |||
//! [b][m][n][c] | |||
bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { | |||
bool is_transpose_single( | |||
const TensorLayout& layout, TransposeParam& p, bool allow_no_contig) { | |||
/* | |||
* assuming contig layout is: | |||
* shape: b, m, n, c | |||
@@ -42,8 +43,9 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { | |||
* | |||
* if b == 1 && c == 1: | |||
* shape: n, m | |||
* stride: 1, n | |||
* stride: 1, n(stride_m for no-contig) | |||
*/ | |||
p.stride_m = 0; | |||
auto strd = [&](size_t idx, ptrdiff_t v) { return layout.stride[idx] == v; }; | |||
if (layout.ndim == 4) { | |||
p.batch = layout[0]; | |||
@@ -80,7 +82,15 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { | |||
p.n = layout.shape[0]; | |||
p.m = layout.shape[1]; | |||
p.c = 1; | |||
return strd(0, 1) && strd(1, p.n); | |||
if (strd(0, 1) && strd(1, p.n)) { | |||
return true; | |||
} else if ( | |||
strd(0, 1) && layout.stride[1] > 0 && | |||
(size_t)(layout.stride[1]) >= p.n && allow_no_contig) { | |||
//! stride_m used in no-contig mode, stride_m >= p.n | |||
p.stride_m = layout.stride[1]; | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
@@ -98,15 +108,16 @@ void RelayoutForward::check_layout_and_canonize(TensorLayout& src, TensorLayout& | |||
} | |||
bool relayout::is_transpose( | |||
const TensorLayout& src, const TensorLayout& dst, TransposeParam& p) { | |||
if (is_contig(dst) && is_transpose_single(src, p)) { | |||
const TensorLayout& src, const TensorLayout& dst, TransposeParam& p, | |||
bool allow_non_contig) { | |||
if (is_contig(dst) && is_transpose_single(src, p, allow_non_contig)) { | |||
// if the original intention is to transpose (m, n) to (n, m), | |||
// then we should use (n, m) as the contig dst and use a corrsponding | |||
// non-contig src with the same (n, m) shape (remember relayout is | |||
// defined on element correspondence on the logical view) | |||
return true; | |||
} | |||
if (is_contig(src) && is_transpose_single(dst, p)) { | |||
if (is_contig(src) && is_transpose_single(dst, p, allow_non_contig)) { | |||
std::swap(p.m, p.n); | |||
return true; | |||
} | |||
@@ -27,7 +27,7 @@ static inline bool is_contig(const TensorLayout& layout) { | |||
//! [b][m][n][c] to [b][n][m][c] | |||
struct TransposeParam { | |||
size_t batch, m, n, c; | |||
size_t batch, m, n, c, stride_m; | |||
}; | |||
/** | |||
@@ -36,7 +36,9 @@ struct TransposeParam { | |||
* Note that \p src and \p dst should have been processed by | |||
* RelayoutForward::check_layout_and_canonize | |||
*/ | |||
bool is_transpose(const TensorLayout& src, const TensorLayout& dst, TransposeParam& p); | |||
bool is_transpose( | |||
const TensorLayout& src, const TensorLayout& dst, TransposeParam& p, | |||
bool allow_non_contig = false); | |||
namespace transpose_fallback { | |||
@@ -105,20 +107,23 @@ void transpose_block( | |||
* \brief transpose contiguous (batch, m, n) to (batch, n, m) | |||
*/ | |||
template <typename T> | |||
void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) { | |||
void transpose(size_t batch, size_t m, size_t n, T* src, T* dst, size_t stride_m = 0) { | |||
if (stride_m == 0) { | |||
stride_m = n; | |||
} | |||
auto batch_src = src; | |||
auto batch_dst = dst; | |||
constexpr size_t B = transpose_traits<T>::block_size; | |||
auto work_block = [m, n, &batch_src, &batch_dst]( | |||
auto work_block = [m, stride_m, &batch_src, &batch_dst]( | |||
const size_t i, const size_t j, const size_t h, | |||
const size_t w) { | |||
auto src = batch_src + i * n + j, dst = batch_dst + j * m + i; | |||
auto src = batch_src + i * stride_m + j, dst = batch_dst + j * m + i; | |||
MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) { | |||
if (h == B && w == B) { | |||
transpose_block(src, dst, n, m); | |||
transpose_block(src, dst, stride_m, m); | |||
} else { | |||
transpose_block(src, dst, n, m, h, w); | |||
transpose_block(src, dst, stride_m, m, h, w); | |||
} | |||
} | |||
MIDOUT_END(); | |||
@@ -141,7 +146,7 @@ void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) { | |||
if (i < m) { | |||
work_row(i, m - i); | |||
} | |||
batch_src += m * n; | |||
batch_src += m * stride_m; | |||
batch_dst += m * n; | |||
} | |||
} | |||
@@ -48,10 +48,12 @@ void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) { | |||
} | |||
template <typename T> | |||
void call_transpose(size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst) { | |||
void call_transpose( | |||
size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst, | |||
size_t stride_m) { | |||
megdnn_assert(ch == 1); | |||
relayout::transpose_fallback::transpose<T>( | |||
batch, m, n, static_cast<T*>(src), static_cast<T*>(dst)); | |||
batch, m, n, static_cast<T*>(src), static_cast<T*>(dst), stride_m); | |||
} | |||
//! one operand contiguous, and the other non-contiguous | |||
@@ -186,7 +188,10 @@ void transpose_cv_row( | |||
} | |||
template <typename ctype> | |||
void transpose_cv(size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst) { | |||
void transpose_cv( | |||
size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst, | |||
size_t stride_m) { | |||
megdnn_assert(stride_m == 0); | |||
constexpr size_t B = BLOCK_SIZE; | |||
auto batch_src = static_cast<ctype*>(src); | |||
auto batch_dst = static_cast<ctype*>(dst); | |||
@@ -237,7 +242,7 @@ void RelayoutForwardImpl::exec( | |||
} | |||
relayout::TransposeParam trans_param; | |||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | |||
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true); | |||
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | |||
} | |||
@@ -245,7 +250,7 @@ void RelayoutForwardImpl::exec_after_preprocess( | |||
const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) { | |||
if (transpose) { | |||
auto dsize = src.layout.dtype.size() * transpose->c; | |||
void (*kptr)(size_t, size_t, size_t, size_t, void*, void*) = nullptr; | |||
void (*kptr)(size_t, size_t, size_t, size_t, void*, void*, size_t) = nullptr; | |||
auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr), | |||
dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr); | |||
if (dsize == 1) { | |||
@@ -293,7 +298,9 @@ void RelayoutForwardImpl::exec_after_preprocess( | |||
if (kptr) { | |||
auto kern = [t = *transpose, sptr = src.raw_ptr, dptr = dst.raw_ptr, | |||
kptr]() { kptr(t.batch, t.m, t.n, t.c, sptr, dptr); }; | |||
kptr]() { | |||
kptr(t.batch, t.m, t.n, t.c, sptr, dptr, t.stride_m); | |||
}; | |||
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern); | |||
return; | |||
} else { | |||
@@ -0,0 +1,29 @@ | |||
/** | |||
* \file dnn/test/aarch64/fixture.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "test/aarch64/fixture.h" | |||
#include "test/common/memory_manager.h" | |||
#include "test/common/random_state.h" | |||
#include "test/common/utils.h" | |||
namespace megdnn { | |||
namespace test { | |||
Handle* AARCH64::fallback_handle() { | |||
if (!m_fallback_handle) { | |||
m_fallback_handle = create_cpu_handle(1); | |||
} | |||
return m_fallback_handle.get(); | |||
} | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -19,7 +19,13 @@ | |||
namespace megdnn { | |||
namespace test { | |||
class AARCH64 : public ARM_COMMON {}; | |||
class AARCH64 : public ARM_COMMON { | |||
public: | |||
Handle* fallback_handle(); | |||
private: | |||
std::unique_ptr<Handle> m_handle, m_fallback_handle; | |||
}; | |||
class AARCH64_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {}; | |||
@@ -0,0 +1,118 @@ | |||
/** | |||
* \file dnn/test/aarch64/relayout.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "test/aarch64/fixture.h" | |||
#include "test/common/benchmarker.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/relayout.h" | |||
#include "test/common/rng.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace { | |||
template <typename tag> | |||
class AARCH64_RELAYOUT : public AARCH64 {}; | |||
TYPED_TEST_CASE(AARCH64_RELAYOUT, relayout::test_types); | |||
TYPED_TEST(AARCH64_RELAYOUT, run) { | |||
relayout::run_test<TypeParam>(this->handle()); | |||
} | |||
} // namespace | |||
TEST_F(AARCH64, Relayout) { | |||
Checker<Relayout> checker(handle()); | |||
std::vector<::megdnn::DType> dtype_vec; | |||
dtype_vec.push_back(dtype::Float32()); | |||
dtype_vec.push_back(dtype::Int16()); | |||
dtype_vec.push_back(dtype::Uint16()); | |||
dtype_vec.push_back(dtype::Int8()); | |||
for (auto dtype : dtype_vec) { | |||
TensorLayout src({1, 54, 112, 256}, {54, 1, 16384, 64}, dtype); | |||
TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); | |||
checker.execl({src, dst}); | |||
} | |||
} | |||
TEST_F(AARCH64, RelayoutBig) { | |||
Checker<Relayout> checker(handle()); | |||
ConsecutiveRNG rng; | |||
checker.set_rng(0, &rng); | |||
int m = 512; | |||
int n = 512; | |||
TensorLayout src({(size_t)m, (size_t)n}, {1, n}, dtype::Float32()); | |||
TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype::Float32()); | |||
checker.execl({src, dst}); | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(AARCH64, BENCHMARK_Relayout) { | |||
constexpr size_t WARM_RUNS = 100; | |||
constexpr size_t RUNS = 600; | |||
auto dtype = dtype::Float32(); | |||
Benchmarker<Relayout> benchmarker_relayout(handle()); | |||
Benchmarker<Relayout> benchmarker_fbk_relayout(fallback_handle()); | |||
benchmarker_relayout.set_times(WARM_RUNS); | |||
benchmarker_fbk_relayout.set_times(WARM_RUNS); | |||
int m = 512; | |||
int n = 512; | |||
TensorLayout src({(size_t)m, (size_t)n}, {1, n}, dtype); | |||
TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype); | |||
TensorLayoutArray tensor_case; | |||
tensor_case.push_back(src); | |||
tensor_case.push_back(dst); | |||
benchmarker_relayout.exec(tensor_case); | |||
benchmarker_fbk_relayout.exec(tensor_case); | |||
benchmarker_relayout.set_times(RUNS); | |||
benchmarker_fbk_relayout.set_times(RUNS); | |||
auto used = benchmarker_relayout.exec(tensor_case) / RUNS; | |||
auto fbk_used = benchmarker_fbk_relayout.exec(tensor_case) / RUNS; | |||
float bw = 2.f * m * n * 1e-6 / used * dtype.size(); | |||
float fbk_bw = 2.f * m * n * 1e-6 / fbk_used * dtype.size(); | |||
printf("run: %s -> %s , %f GB/s, fbk %f GB/s, speedup %f\n", | |||
src.to_string().c_str(), dst.to_string().c_str(), bw, fbk_bw, bw / fbk_bw); | |||
} | |||
TEST_F(AARCH64, BENCHMARK_Relayout_2) { | |||
constexpr size_t WARM_RUNS = 100; | |||
constexpr size_t RUNS = 600; | |||
auto dtype = dtype::Float32(); | |||
Benchmarker<Relayout> benchmarker_relayout(handle()); | |||
Benchmarker<Relayout> benchmarker_fbk_relayout(fallback_handle()); | |||
benchmarker_relayout.set_times(WARM_RUNS); | |||
benchmarker_fbk_relayout.set_times(WARM_RUNS); | |||
int m = 54; | |||
int n = 28762; | |||
TensorLayout src({1, 54, 112, 256}, {54, 1, 16384, 64}, dtype); | |||
TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); | |||
TensorLayoutArray tensor_case; | |||
tensor_case.push_back(src); | |||
tensor_case.push_back(dst); | |||
benchmarker_relayout.exec(tensor_case); | |||
benchmarker_fbk_relayout.exec(tensor_case); | |||
benchmarker_relayout.set_times(RUNS); | |||
benchmarker_fbk_relayout.set_times(RUNS); | |||
auto used = benchmarker_relayout.exec(tensor_case) / RUNS; | |||
auto fbk_used = benchmarker_fbk_relayout.exec(tensor_case) / RUNS; | |||
float bw = 2.f * m * n * 1e-6 / used * dtype.size(); | |||
float fbk_bw = 2.f * m * n * 1e-6 / fbk_used * dtype.size(); | |||
printf("run: %s -> %s , %f GB/s, fbk %f GB/s, speedup %f\n", | |||
src.to_string().c_str(), dst.to_string().c_str(), bw, fbk_bw, bw / fbk_bw); | |||
} | |||
#endif | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -180,11 +180,11 @@ TEST(RELAYOUT, TRANSPOSE_DET) { | |||
ASSERT_EQ(p_get.c, p.c); | |||
} | |||
}; | |||
run({2, 3}, {1, 0}, true, {1, 2, 3, 1}); | |||
run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5}); | |||
run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1}); | |||
run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5}); | |||
run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1}); | |||
run({2, 3}, {1, 0}, true, {1, 2, 3, 1, 0}); | |||
run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5, 0}); | |||
run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1, 0}); | |||
run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5, 0}); | |||
run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1, 0}); | |||
run({2, 3, 5}, {2, 1, 0}, false); | |||
run({3, 2, 3, 5}, {3, 2, 1, 0}, false); | |||
} | |||