From 95657d54cf1e08893d0769d4e6d662c75e74cadd Mon Sep 17 00:00:00 2001 From: jieli-matrix Date: Mon, 6 Dec 2021 23:15:28 +0800 Subject: [PATCH 1/5] docs(mge/functional): update functional.math.svd docstring --- imperative/python/megengine/functional/math.py | 58 ++++++++++++++++---------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index facd8206..87343847 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -1151,36 +1151,52 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor: return result -def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: - r"""Computes the singular value decompositions of input matrix. +def svd(x: Tensor, full_matrices=False, compute_uv=True) -> Tensor: + r"""Returns a singular value decomposition ``A = USVh`` of a matrix (or a stack of matrices) ``x`` , where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows. Args: - inp: input matrix, must has shape `[..., M, N]`. + x (Tensor): A input real tensor having the shape ``(..., M, N)`` with ``x.ndim >= 2`` . + full_matrices (bool, optional): If ``False`` , ``U`` and ``Vh`` have the shapes ``(..., M, K)`` and ``(..., K, N)`` , respectively, where ``K = min(M, N)`` . If ``True`` , the shapes are ``(..., M, M)`` and ``(..., N, N)`` , respectively. Default: ``False`` . + compute_uv (bool, optional): Whether or not to compute ``U`` and ``Vh`` in addition to ``S`` . Default: ``True`` . Returns: - output matrices, `(U, sigma, V)`. + Returns a tuple ( ``U`` , ``S`` , ``Vh`` ), which are SVD factors ``U`` , ``S``, ``Vh`` of input matrix ``x``. ( ``U`` , ``Vh`` only returned when ``compute_uv`` is True). + ``U`` contains matrices orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True`` , the array must have shape ``(..., M, M)`` . If ``full_matrices`` is ``False`` , the array must have shape ``(..., M, K)`` , where ``K = min(M, N)`` . - Examples: - - .. testcode:: - - import numpy as np - from megengine import tensor - import megengine.functional as F - - x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3)) - _, y, _ = F.svd(x) - print(y.numpy().round(decimals=3)) - - Outputs: + ``S`` contains the vector(s) of singular values of length ``K`` , where ``K = min(M, N)`` . For each vector, the singular values must be sorted in descending order by magnitude, such that ``s[..., 0]`` is the largest value, ``s[..., 1]`` is the second largest value, etc. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x`` . - .. testoutput:: + ``Vh`` contains orthonormal rows (i.e., the rows are the right singular vectors and the array is the adjoint). If ``full_matrices`` is ``True`` , the array must have shape ``(..., N, N)`` . If ``full_matrices`` is ``False`` , the array must have shape ``(..., K, N)`` where ``K = min(M, N)`` . The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x`` . + Each returned array must have the same floating-point data type as ``x`` . - [7.348 1. ] + Examples: + >>> import numpy as np + >>> x = Tensor(np.random.randn(9, 6)) + >>> y = Tensor(np.random.randn(2, 7, 8, 3)) + + Reconstruction based on full SVD, 2D case: + >>> U, S, Vh = F.svd(x, full_matrices=True) + >>> U.shape, S.shape, Vh.shape + ((9, 9), (6,), (6, 6)) + + Reconstruction based on reduced SVD, 2D case: + >>> U, S, Vh = F.svd(x, full_matrices=False) + >>> U.shape, S.shape, Vh.shape + ((9, 6), (6,), (6, 6)) + + Reconsturction based on full SVD, 4D case: + >>> u, s, vh = F.svd(y, full_matrices=True) + >>> u.shape, s.shape, vh.shape + ((2, 7, 8, 8), (2, 7, 3), (2, 7, 3, 3)) + + Reconsturction based on reduced SVD, 4D case: + >>> u, s, vh = F.svd(y, full_matrices=False) + >>> u.shape, s.shape, vh.shape + ((2, 7, 8, 3), (2, 7, 3), (2, 7, 3, 3)) + """ op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) - U, sigma, V = apply(op, inp) - return U, sigma, V + U, S, Vh = apply(op, x) + return U, S, Vh def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor: From 3159eecadd93f4803dea5c36fbd823ec344e63a6 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 26 Sep 2021 17:23:32 +0800 Subject: [PATCH 2/5] fix(init): fix fan_in and fan_out for group conv2d GitOrigin-RevId: a6f41063f081c06710dd0c157ff9794bae57bab9 --- imperative/python/megengine/module/init.py | 29 ++++++++++++++++--------- imperative/python/test/unit/module/test_init.py | 29 ++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/imperative/python/megengine/module/init.py b/imperative/python/megengine/module/init.py index 84834755..2bf73fde 100644 --- a/imperative/python/megengine/module/init.py +++ b/imperative/python/megengine/module/init.py @@ -74,7 +74,7 @@ def calculate_gain( ) -> float: r"""Returns a recommended gain value (see the table below) for the given nonlinearity function. - + ================= ==================================================== nonlinearity gain ================= ==================================================== @@ -126,6 +126,11 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: r"""Calculates fan_in / fan_out value for given weight tensor. This function assumes input tensor is stored in ``NCHW`` format. + Note: + The group conv2d kernel shape in MegEngine is ``(G, O/G, I/G, K, K)``. This + function calculates ``fan_out = O/G * K * K`` as default, but PyTorch uses + ``fan_out = O * K * K``. + Args: tensor: weight tensor in ``NCHW`` format. """ @@ -141,6 +146,10 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: fan_in = shape[1] fan_out = shape[0] else: + if ndim >= 5: + # ignore the groups dimension of group conv2d and group conv3d + # FIXME: will be wrong for conv3d + shape = shape[1:] num_input_fmaps = shape[1] num_output_fmaps = shape[0] receptive_field_size = 1 @@ -154,7 +163,7 @@ def calculate_fan_in_and_fan_out(tensor: Tensor) -> Tuple[float, float]: def calculate_correct_fan(tensor: Tensor, mode: str) -> float: r"""Calculates fan_in / fan_out value for given weight tensor, depending on given ``mode``. - + See :func:`calculate_fan_in_and_fan_out` for details. Args: @@ -175,11 +184,11 @@ def calculate_correct_fan(tensor: Tensor, mode: str) -> float: def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: r"""Fills tensor with random values sampled from :math:`\mathcal{U}(-a, a)` where - + .. math:: a = \text{gain} \times \sqrt{\frac{6}{\text{fan_in} + \text{fan_out}}} - + Also known as Glorot initialization. Detailed information can be retrieved from `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). @@ -197,11 +206,11 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.0) -> None: def xavier_normal_(tensor: Tensor, gain: float = 1.0) -> None: r"""Fills tensor with random values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where - + .. math:: \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}} - + Also known as Glorot initialization. Detailed information can be retrieved from `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). @@ -220,11 +229,11 @@ def msra_uniform_( ) -> None: r"""Fills tensor wilth random values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where - + .. math:: \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan_in}}} - + Detailed information can be retrieved from `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` @@ -251,11 +260,11 @@ def msra_normal_( ) -> None: r"""Fills tensor wilth random values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where - + .. math:: \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}} - + Detailed information can be retrieved from `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` diff --git a/imperative/python/test/unit/module/test_init.py b/imperative/python/test/unit/module/test_init.py index 9f3a019e..b28f60e1 100644 --- a/imperative/python/test/unit/module/test_init.py +++ b/imperative/python/test/unit/module/test_init.py @@ -10,7 +10,7 @@ import numpy as np import pytest from megengine import tensor -from megengine.module import Conv2d, Linear +from megengine.module import Conv1d, Conv2d, Conv3d, Linear from megengine.module.init import calculate_fan_in_and_fan_out, fill_ @@ -32,7 +32,34 @@ def test_calculate_fan_in_and_fan_out(): with pytest.raises(ValueError): calculate_fan_in_and_fan_out(l.bias) + l = Conv1d(in_channels=2, out_channels=3, kernel_size=5) + fanin, fanout = calculate_fan_in_and_fan_out(l.weight) + assert fanin == 2 * 5 + assert fanout == 3 * 5 + + # FIXME: will be wrong for group conv1d + # l = Conv1d(in_channels=2, out_channels=4, kernel_size=5, groups=2) + # fanin, fanout = calculate_fan_in_and_fan_out(l.weight) + # assert fanin == 2 // 2 * 5 + # assert fanout == 4 // 2 * 5 + l = Conv2d(in_channels=2, out_channels=3, kernel_size=(5, 7)) fanin, fanout = calculate_fan_in_and_fan_out(l.weight) assert fanin == 2 * 5 * 7 assert fanout == 3 * 5 * 7 + + l = Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 7), groups=2) + fanin, fanout = calculate_fan_in_and_fan_out(l.weight) + assert fanin == 2 // 2 * 5 * 7 + assert fanout == 4 // 2 * 5 * 7 + + # FIXME: will be wrong for conv3d + # l = Conv3d(in_channels=2, out_channels=3, kernel_size=(5, 7, 9)) + # fanin, fanout = calculate_fan_in_and_fan_out(l.weight) + # assert fanin == 2 * 5 * 7 * 9 + # assert fanout == 3 * 5 * 7 * 9 + + l = Conv3d(in_channels=2, out_channels=4, kernel_size=(5, 7, 9), groups=2) + fanin, fanout = calculate_fan_in_and_fan_out(l.weight) + assert fanin == 2 // 2 * 5 * 7 * 9 + assert fanout == 4 // 2 * 5 * 7 * 9 From 17f2dffb5bd7cfe70d3f8072d5b312470df85d7b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 21 Jan 2022 16:30:06 +0800 Subject: [PATCH 3/5] fix(imperative/cpu/multithread): fix multithread at imperative GitOrigin-RevId: 9120d5cb48e1087f38a826d377ca3c6bae984730 --- imperative/python/megengine/device.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/device.py b/imperative/python/megengine/device.py index 9ba1c4d6..d9a8f2a7 100644 --- a/imperative/python/megengine/device.py +++ b/imperative/python/megengine/device.py @@ -50,7 +50,9 @@ _sh = _stream_helper() def _valid_device(inp): - if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp): + if isinstance(inp, str) and re.match( + "^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$", inp + ): return True return False From b04c3d145671981ab478f3d4d80741ebb37d22cc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 21 Jan 2022 11:24:40 +0800 Subject: [PATCH 4/5] feat(lite): add set address ptr pair interface GitOrigin-RevId: 285dacb4da51cb1e23f411967d612bb520d611d8 --- .gitattributes | 3 +++ lite/include/lite/global.h | 15 +++++++++++ lite/lite-c/include/lite-c/global_c.h | 17 +++++++++++- lite/lite-c/src/global.cpp | 15 +++++++++++ lite/pylite/megenginelite/global_setting.py | 20 ++++++++++++++ ...test_network_cuda.py => test_network_device.py} | 0 lite/src/global.cpp | 31 ++++++++++++++++++++++ lite/test/test_network.cpp | 1 + 8 files changed, 101 insertions(+), 1 deletion(-) rename lite/pylite/test/{test_network_cuda.py => test_network_device.py} (100%) diff --git a/.gitattributes b/.gitattributes index 458eb5aa..0b84a4c2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -21,3 +21,6 @@ ci/resource/prof/model_with_err_assert.mdl filter=lfs diff=lfs merge=lfs -text ci/resource/prof/test_mge.mge filter=lfs diff=lfs merge=lfs -text lite/test/resource/lite/ax_models/64-58063ce2.axe filter=lfs diff=lfs merge=lfs -text imperative/python/test/unit/module/MagicMindRuntimeOprTest.GraphShapeMutable.mlu filter=lfs diff=lfs merge=lfs -text +lite/test/resource/lite/ax_data_input.npy filter=lfs diff=lfs merge=lfs -text +lite/test/resource/lite/ax_data_output.npy filter=lfs diff=lfs merge=lfs -text +lite/test/resource/lite/ax_model.mge filter=lfs diff=lfs merge=lfs -text diff --git a/lite/include/lite/global.h b/lite/include/lite/global.h index e681ee7e..f9c70777 100644 --- a/lite/include/lite/global.h +++ b/lite/include/lite/global.h @@ -154,6 +154,21 @@ LITE_API void set_tensor_rt_cache(std::string tensorrt_cache_path); */ LITE_API void dump_tensor_rt_cache(); +/** + * register the physical and virtual address pair to the mge, some device + * need the map from physical to virtual. + */ +LITE_API bool register_memory_pair( + void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, + LiteBackend backend = LiteBackend::LITE_DEFAULT); + +/** + * clear the physical and virtual address pair in mge. + */ +LITE_API bool clear_memory_pair( + void* vir_ptr, void* phy_ptr, LiteDeviceType device, + LiteBackend backend = LiteBackend::LITE_DEFAULT); + } // namespace lite // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/lite-c/include/lite-c/global_c.h b/lite/lite-c/include/lite-c/global_c.h index a895f28c..42eed593 100644 --- a/lite/lite-c/include/lite-c/global_c.h +++ b/lite/lite-c/include/lite-c/global_c.h @@ -160,9 +160,24 @@ LITE_API int LITE_dump_persistent_cache(const char* cache_path); * \brief dump the tensorrt policy cache to file */ LITE_API int LITE_dump_tensor_rt_cache(); -#endif + +/** + * register the physical and virtual address pair to the mge, some device + * need the map from physical to virtual. + */ +LITE_API int LITE_register_memory_pair( + void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, + LiteBackend backend); + +/** + * clear the physical and virtual address pair in mge. + */ +LITE_API int LITE_clear_memory_pair( + void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend); + #ifdef __cplusplus } #endif +#endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/lite-c/src/global.cpp b/lite/lite-c/src/global.cpp index c686b1f3..8be2644c 100644 --- a/lite/lite-c/src/global.cpp +++ b/lite/lite-c/src/global.cpp @@ -189,4 +189,19 @@ int LITE_dump_tensor_rt_cache() { LITE_CAPI_END(); } +int LITE_register_memory_pair( + void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, + LiteBackend backend) { + LITE_CAPI_BEGIN(); + lite::register_memory_pair(vir_ptr, phy_ptr, length, device, backend); + LITE_CAPI_END(); +} + +int LITE_clear_memory_pair( + void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend) { + LITE_CAPI_BEGIN(); + lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend); + LITE_CAPI_END(); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/pylite/megenginelite/global_setting.py b/lite/pylite/megenginelite/global_setting.py index c39cdf62..89615e6b 100644 --- a/lite/pylite/megenginelite/global_setting.py +++ b/lite/pylite/megenginelite/global_setting.py @@ -42,6 +42,8 @@ class _GlobalAPI(_LiteCObjBase): # ('LITE_set_tensor_rt_cache', [c_char_p]), ("LITE_dump_persistent_cache", [c_char_p]), ("LITE_dump_tensor_rt_cache", [c_char_p]), + ("LITE_register_memory_pair", [c_void_p, c_void_p, c_size_t, c_int, c_int]), + ("LITE_clear_memory_pair", [c_void_p, c_void_p, c_int, c_int]), ] @@ -121,3 +123,21 @@ class LiteGlobal(object): @staticmethod def try_coalesce_all_free_memory(): LiteGlobal._api.LITE_try_coalesce_all_free_memory() + + @staticmethod + def register_memory_pair( + vir_ptr, phy_ptr, length, device, backend=LiteBackend.LITE_DEFAULT + ): + assert isinstance(vir_ptr, c_void_p) and isinstance( + phy_ptr, c_void_p + ), "clear memory pair only accept c_void_p type." + LiteGlobal._api.LITE_register_memory_pair( + vir_ptr, phy_ptr, length, device, backend + ) + + @staticmethod + def clear_memory_pair(vir_ptr, phy_ptr, device, backend=LiteBackend.LITE_DEFAULT): + assert isinstance(vir_ptr, c_void_p) and isinstance( + phy_ptr, c_void_p + ), "clear memory pair only accept c_void_p type." + LiteGlobal._api.LITE_clear_memory_pair(vir_ptr, phy_ptr, device, backend) diff --git a/lite/pylite/test/test_network_cuda.py b/lite/pylite/test/test_network_device.py similarity index 100% rename from lite/pylite/test/test_network_cuda.py rename to lite/pylite/test/test_network_device.py diff --git a/lite/src/global.cpp b/lite/src/global.cpp index 5aa973a7..9f3e9fab 100644 --- a/lite/src/global.cpp +++ b/lite/src/global.cpp @@ -212,6 +212,26 @@ void lite::dump_tensor_rt_cache() { #endif } +bool lite::register_memory_pair( + void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, + LiteBackend backend) { + LITE_MARK_USED_VAR(vir_ptr); + LITE_MARK_USED_VAR(phy_ptr); + LITE_MARK_USED_VAR(length); + LITE_MARK_USED_VAR(device); + LITE_MARK_USED_VAR(backend); + LITE_THROW("register_memory_pair is not implement yet!"); +} + +bool lite::clear_memory_pair( + void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) { + LITE_MARK_USED_VAR(vir_ptr); + LITE_MARK_USED_VAR(phy_ptr); + LITE_MARK_USED_VAR(device); + LITE_MARK_USED_VAR(backend); + LITE_THROW("clear_memory_pair is not implement yet!"); +} + #else // LITE_BUILD_WITH_MGE void lite::try_coalesce_all_free_memory() {} @@ -235,6 +255,17 @@ void lite::set_tensor_rt_cache(std::string) { void lite::dump_tensor_rt_cache() { LITE_THROW("mge is disbale at build time, please build with mge"); } + +bool lite::register_memory_pair( + void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device, + LiteBackend beckend) { + LITE_THROW("register_memory_pair is not implement yet!"); +} + +bool lite::clear_memory_pair( + void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend beckend) { + LITE_THROW("clear_memory_pair is not implement yet!"); +} #endif namespace lite { REGIST_DECRYPTION_METHOD( diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index c7cab766..8734e8ee 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -1357,5 +1357,6 @@ TEST(TestNetWork, CambriconDeviceID) { load_device_id(LiteDeviceType::LITE_CAMBRICON, 0, "./model_magicmind.mgb"); } #endif + #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} From 260923e11cba2943762b9c137e76865ddcb1ddf3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 24 Jan 2022 11:12:27 +0800 Subject: [PATCH 5/5] perf(aarch64): optimize aarch64 uint16 relayout with block_w==3 GitOrigin-RevId: fe6aaaac0cbd594ad80dcf2e9763fce9c99f5a4e --- dnn/src/aarch64/relayout/opr_impl.cpp | 54 +++++++++++++++++++++++++++++++++++ dnn/test/aarch64/relayout.cpp | 3 ++ 2 files changed, 57 insertions(+) diff --git a/dnn/src/aarch64/relayout/opr_impl.cpp b/dnn/src/aarch64/relayout/opr_impl.cpp index dfcf5036..c23cbaa3 100644 --- a/dnn/src/aarch64/relayout/opr_impl.cpp +++ b/dnn/src/aarch64/relayout/opr_impl.cpp @@ -363,6 +363,58 @@ static inline void trans_8x4_u16( vst1q_u16(dst_ptr + 3 * dst_step, row_3); } +static inline void trans_8x3_u16( + const void* src, void* dst, const size_t src_step, const size_t dst_step) { + uint16_t* src_ptr = (uint16_t*)src; + uint16_t* dst_ptr = (uint16_t*)dst; + uint16x4_t src0 = vld1_u16(src_ptr + 0 * src_step); // A0A1A2A3 + uint16x4_t src1 = vld1_u16(src_ptr + 1 * src_step); // B0B1B2B3 + uint16x4_t src2 = vld1_u16(src_ptr + 2 * src_step); // C0C1C2C3 + uint16x4_t src3 = vld1_u16(src_ptr + 3 * src_step); // D0D1D2D3 + uint16x4_t src4 = vld1_u16(src_ptr + 4 * src_step); // E0E1E2E3 + uint16x4_t src5 = vld1_u16(src_ptr + 5 * src_step); // F0F1F2F3 + uint16x4_t src6 = vld1_u16(src_ptr + 6 * src_step); // G0G1G2G3 + // H0H1H2 + uint16x4_t src7 = + vreinterpret_u16_u32(vld1_dup_u32((uint32_t*)(src_ptr + 7 * src_step))); + src7 = vld1_lane_u16(src_ptr + 7 * src_step + 2, src7, 2); + + uint16x4_t ab_low = vzip1_u16(src0, src1); // A0B0A1B1 + uint16x4_t ab_high = vzip2_u16(src0, src1); // A2B2A3B3 + uint16x4_t cd_low = vzip1_u16(src2, src3); // C0D0C1D1 + uint16x4_t cd_high = vzip2_u16(src2, src3); // C2D2C3D3 + uint16x4_t ef_low = vzip1_u16(src4, src5); // E0F0E1F1 + uint16x4_t ef_high = vzip2_u16(src4, src5); // E2F2E3F3 + uint16x4_t gh_low = vzip1_u16(src6, src7); // G0H0G1H1 + uint16x4_t gh_high = vzip2_u16(src6, src7); // G2H2G3 + + uint16x4_t abcd_0 = vreinterpret_u16_u32(vzip1_u32( + vreinterpret_u32_u16(ab_low), + vreinterpret_u32_u16(cd_low))); // A0B0C0D0 + uint16x4_t abcd_1 = vreinterpret_u16_u32(vzip2_u32( + vreinterpret_u32_u16(ab_low), + vreinterpret_u32_u16(cd_low))); // A1B1C1D1 + uint16x4_t abcd_2 = vreinterpret_u16_u32(vzip1_u32( + vreinterpret_u32_u16(ab_high), + vreinterpret_u32_u16(cd_high))); // A2B2C2D2 + uint16x4_t efgh_0 = vreinterpret_u16_u32(vzip1_u32( + vreinterpret_u32_u16(ef_low), + vreinterpret_u32_u16(gh_low))); // E0F0G0H0 + uint16x4_t efgh_1 = vreinterpret_u16_u32(vzip2_u32( + vreinterpret_u32_u16(ef_low), + vreinterpret_u32_u16(gh_low))); // E1F1G1H1 + uint16x4_t efgh_2 = vreinterpret_u16_u32(vzip1_u32( + vreinterpret_u32_u16(ef_high), + vreinterpret_u32_u16(gh_high))); // E2F2G2H2 + + uint16x8_t row_0 = vcombine_u16(abcd_0, efgh_0); + uint16x8_t row_1 = vcombine_u16(abcd_1, efgh_1); + uint16x8_t row_2 = vcombine_u16(abcd_2, efgh_2); + + vst1q_u16(dst_ptr + 0 * dst_step, row_0); + vst1q_u16(dst_ptr + 1 * dst_step, row_1); + vst1q_u16(dst_ptr + 2 * dst_step, row_2); +} } // anonymous namespace namespace megdnn { @@ -410,6 +462,8 @@ void transpose_block( const size_t dst_stride, size_t block_h, size_t block_w) { if (block_h == 8 && block_w == 4) { trans_8x4_u16(src, dst, src_stride, dst_stride); + } else if (block_h == 8 && block_w == 3) { + trans_8x3_u16(src, dst, src_stride, dst_stride); } else { transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, block_w); } diff --git a/dnn/test/aarch64/relayout.cpp b/dnn/test/aarch64/relayout.cpp index 3a604501..3e5a37b1 100644 --- a/dnn/test/aarch64/relayout.cpp +++ b/dnn/test/aarch64/relayout.cpp @@ -40,6 +40,9 @@ TEST_F(AARCH64, Relayout) { TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); checker.execl({src, dst}); } + TensorLayout src_4_3({1, 3, 112, 256}, {3, 1, 1024, 4}, dtype::Uint16()); + TensorLayout dst_4_3({1, 3, 112, 256}, {86016, 28672, 256, 1}, dtype::Uint16()); + checker.execl({src_4_3, dst_4_3}); } TEST_F(AARCH64, RelayoutNonContig) {