Browse Source

style(all): reformat c++ code

GitOrigin-RevId: 3ffd1b211f
release-1.7
Megvii Engine Team 3 years ago
parent
commit
369c2ccc5a
100 changed files with 6469 additions and 6810 deletions
  1. +0
    -0
      dnn/atlas-stub/include/acl/acl.h
  2. +0
    -0
      dnn/atlas-stub/include/acl/acl_base.h
  3. +0
    -0
      dnn/atlas-stub/include/acl/acl_mdl.h
  4. +0
    -0
      dnn/atlas-stub/include/acl/acl_op.h
  5. +0
    -0
      dnn/atlas-stub/include/acl/acl_rt.h
  6. +0
    -0
      dnn/atlas-stub/include/acl/ops/acl_cblas.h
  7. +0
    -0
      dnn/atlas-stub/include/acl/ops/acl_dvpp.h
  8. +0
    -0
      dnn/atlas-stub/include/acl/ops/acl_fv.h
  9. +2
    -2
      dnn/include/hip_header.h
  10. +52
    -66
      dnn/include/megcore.h
  11. +5
    -6
      dnn/include/megcore_atlas.h
  12. +2
    -3
      dnn/include/megcore_cambricon.h
  13. +1
    -2
      dnn/include/megcore_cdefs.h
  14. +3
    -4
      dnn/include/megcore_cuda.h
  15. +6
    -5
      dnn/include/megcore_rocm.h
  16. +1
    -1
      dnn/include/megdnn.h
  17. +73
    -74
      dnn/include/megdnn/arch.h
  18. +33
    -42
      dnn/include/megdnn/basic_types.h
  19. +6
    -7
      dnn/include/megdnn/common.h
  20. +4
    -4
      dnn/include/megdnn/cuda.h
  21. +353
    -463
      dnn/include/megdnn/dtype.h
  22. +18
    -13
      dnn/include/megdnn/dtype/half_common_epilogue.h
  23. +147
    -144
      dnn/include/megdnn/dtype/half_common_prologue.h
  24. +134
    -137
      dnn/include/megdnn/handle.h
  25. +3
    -2
      dnn/include/megdnn/heuristic_cache.h
  26. +5
    -6
      dnn/include/megdnn/internal/defs.h
  27. +20
    -19
      dnn/include/megdnn/internal/opr_header_prologue.h
  28. +0
    -1
      dnn/include/megdnn/internal/visibility_epilogue.h
  29. +16
    -20
      dnn/include/megdnn/opr_result_defs.h
  30. +2
    -4
      dnn/include/megdnn/oprs.h
  31. +91
    -124
      dnn/include/megdnn/oprs/base.h
  32. +123
    -99
      dnn/include/megdnn/oprs/cv.h
  33. +847
    -822
      dnn/include/megdnn/oprs/general.h
  34. +164
    -180
      dnn/include/megdnn/oprs/imgproc.h
  35. +55
    -54
      dnn/include/megdnn/oprs/linalg.h
  36. +629
    -618
      dnn/include/megdnn/oprs/nn.h
  37. +5
    -8
      dnn/include/megdnn/oprs/nn_int.h
  38. +103
    -87
      dnn/include/megdnn/oprs/utils.h
  39. +36
    -47
      dnn/include/megdnn/tensor_format.h
  40. +3
    -5
      dnn/include/megdnn/tensor_iter.h
  41. +5
    -5
      dnn/include/megdnn/thin/function.h
  42. +42
    -76
      dnn/include/megdnn/thin/small_vector.h
  43. +6
    -6
      dnn/include/megdnn/version.h
  44. +26
    -24
      dnn/src/aarch64/conv_bias/fp16/algos.cpp
  45. +5
    -5
      dnn/src/aarch64/conv_bias/fp16/algos.h
  46. +48
    -67
      dnn/src/aarch64/conv_bias/fp16/stride2_kern.h
  47. +30
    -31
      dnn/src/aarch64/conv_bias/fp32/algos.cpp
  48. +5
    -5
      dnn/src/aarch64/conv_bias/fp32/algos.h
  49. +33
    -38
      dnn/src/aarch64/conv_bias/fp32/stride2_kern.h
  50. +36
    -37
      dnn/src/aarch64/conv_bias/int8/algos.cpp
  51. +6
    -8
      dnn/src/aarch64/conv_bias/int8/algos.h
  52. +61
    -70
      dnn/src/aarch64/conv_bias/int8/strategy.cpp
  53. +25
    -26
      dnn/src/aarch64/conv_bias/int8/strategy.h
  54. +12
    -13
      dnn/src/aarch64/conv_bias/opr_impl.cpp
  55. +1
    -1
      dnn/src/aarch64/conv_bias/opr_impl.h
  56. +36
    -37
      dnn/src/aarch64/conv_bias/quint8/algos.cpp
  57. +6
    -8
      dnn/src/aarch64/conv_bias/quint8/algos.h
  58. +93
    -97
      dnn/src/aarch64/conv_bias/quint8/strategy.cpp
  59. +26
    -28
      dnn/src/aarch64/conv_bias/quint8/strategy.h
  60. +4
    -4
      dnn/src/aarch64/handle.cpp
  61. +10
    -12
      dnn/src/aarch64/handle.h
  62. +327
    -438
      dnn/src/aarch64/matrix_mul/algos.cpp
  63. +28
    -76
      dnn/src/aarch64/matrix_mul/algos.h
  64. +393
    -458
      dnn/src/aarch64/matrix_mul/asm/common.h
  65. +534
    -555
      dnn/src/aarch64/matrix_mul/fp16/strategy.cpp
  66. +4
    -4
      dnn/src/aarch64/matrix_mul/fp16/strategy.h
  67. +21
    -20
      dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp
  68. +13
    -11
      dnn/src/aarch64/matrix_mul/fp32/common.h
  69. +124
    -118
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h
  70. +54
    -60
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
  71. +34
    -37
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h
  72. +34
    -37
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h
  73. +28
    -27
      dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
  74. +23
    -22
      dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h
  75. +23
    -22
      dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h
  76. +83
    -86
      dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
  77. +5
    -8
      dnn/src/aarch64/matrix_mul/fp32/strategy.h
  78. +19
    -23
      dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp
  79. +99
    -96
      dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h
  80. +35
    -36
      dnn/src/aarch64/matrix_mul/int16/strategy.cpp
  81. +4
    -4
      dnn/src/aarch64/matrix_mul/int16/strategy.h
  82. +23
    -21
      dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp
  83. +121
    -103
      dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
  84. +33
    -34
      dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp
  85. +2
    -2
      dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h
  86. +57
    -39
      dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h
  87. +163
    -113
      dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
  88. +47
    -41
      dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
  89. +77
    -82
      dnn/src/aarch64/matrix_mul/int8/strategy.cpp
  90. +6
    -6
      dnn/src/aarch64/matrix_mul/int8/strategy.h
  91. +62
    -59
      dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
  92. +30
    -32
      dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h
  93. +59
    -59
      dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
  94. +5
    -5
      dnn/src/aarch64/matrix_mul/int8_dot/strategy.h
  95. +54
    -38
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h
  96. +159
    -111
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h
  97. +34
    -41
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h
  98. +18
    -22
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h
  99. +48
    -55
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h
  100. +128
    -142
      dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp

+ 0
- 0
dnn/atlas-stub/include/acl/acl.h View File


+ 0
- 0
dnn/atlas-stub/include/acl/acl_base.h View File


+ 0
- 0
dnn/atlas-stub/include/acl/acl_mdl.h View File


+ 0
- 0
dnn/atlas-stub/include/acl/acl_op.h View File


+ 0
- 0
dnn/atlas-stub/include/acl/acl_rt.h View File


+ 0
- 0
dnn/atlas-stub/include/acl/ops/acl_cblas.h View File


+ 0
- 0
dnn/atlas-stub/include/acl/ops/acl_dvpp.h View File


+ 0
- 0
dnn/atlas-stub/include/acl/ops/acl_fv.h View File


+ 2
- 2
dnn/include/hip_header.h View File

@@ -23,9 +23,9 @@
#pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wdeprecated-declarations" #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#pragma GCC diagnostic ignored "-Wsign-compare" #pragma GCC diagnostic ignored "-Wsign-compare"
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#pragma GCC diagnostic pop #pragma GCC diagnostic pop


#if !defined(__HIP_PLATFORM_HCC__) #if !defined(__HIP_PLATFORM_HCC__)


+ 52
- 66
dnn/include/megcore.h View File

@@ -11,10 +11,10 @@


#pragma once #pragma once


#include "megdnn/thin/function.h"
#include "megcore_cdefs.h"
#include <cstddef> #include <cstddef>
#include <memory> #include <memory>
#include "megcore_cdefs.h"
#include "megdnn/thin/function.h"


#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"


@@ -26,36 +26,35 @@ namespace megcore {
* the caller thread immediately. * the caller thread immediately.
*/ */
class CPUDispatcher { class CPUDispatcher {
public:
using Task = megdnn::thin_function<void()>;
using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>;
virtual ~CPUDispatcher() noexcept;

/*!
* \brief dispatch a task on the computing thread
* \param task the task that would be moved away
*/
virtual void dispatch(Task&& task) = 0;

/*!
* \brief dispatch a multithreading task on the computing thread
* \param task the task would be moved away
* \param parallelism the parallelism of the task.
*/
virtual void dispatch(MultiThreadingTask&& task,
size_t parallelism) = 0;

/*!
* \brief synchronize the calling thread with the computing thread
*/
virtual void sync() = 0;

/*!
* \brief the computing thread number.
*/
virtual size_t nr_threads() = 0;
public:
using Task = megdnn::thin_function<void()>;
using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>;
virtual ~CPUDispatcher() noexcept;

/*!
* \brief dispatch a task on the computing thread
* \param task the task that would be moved away
*/
virtual void dispatch(Task&& task) = 0;

/*!
* \brief dispatch a multithreading task on the computing thread
* \param task the task would be moved away
* \param parallelism the parallelism of the task.
*/
virtual void dispatch(MultiThreadingTask&& task, size_t parallelism) = 0;

/*!
* \brief synchronize the calling thread with the computing thread
*/
virtual void sync() = 0;

/*!
* \brief the computing thread number.
*/
virtual size_t nr_threads() = 0;
}; };
} // namespace megcore
} // namespace megcore


using MegcoreCPUDispatcher = megcore::CPUDispatcher; using MegcoreCPUDispatcher = megcore::CPUDispatcher;


@@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher;
* \brief Layer 1: device handle * \brief Layer 1: device handle
*/ */
struct megcoreDeviceContext; struct megcoreDeviceContext;
typedef struct megcoreDeviceContext *megcoreDeviceHandle_t;
typedef struct megcoreDeviceContext* megcoreDeviceHandle_t;


megcoreStatus_t megcoreCreateDeviceHandle( megcoreStatus_t megcoreCreateDeviceHandle(
megcoreDeviceHandle_t *handle,
megcorePlatform_t platform,
int deviceID = -1,
megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID = -1,
unsigned int flags = 0); unsigned int flags = 0);
megcoreStatus_t megcoreDestroyDeviceHandle(
megcoreDeviceHandle_t handle);

megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle,
megcorePlatform_t *platform);
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle,
int *deviceID);
megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle,
size_t *memAlignmentInBytes);
megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle);

megcoreStatus_t megcoreGetPlatform(
megcoreDeviceHandle_t handle, megcorePlatform_t* platform);
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID);
megcoreStatus_t megcoreGetMemAlignment(
megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes);
megcoreStatus_t megcoreGetDeviceFlags( megcoreStatus_t megcoreGetDeviceFlags(
megcoreDeviceHandle_t handle,
unsigned int *flags);
megcoreDeviceHandle_t handle, unsigned int* flags);


megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle); megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle,
void **devPtr, size_t sizeInBytes);
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle,
void *devPtr);
megcoreStatus_t megcoreMalloc(
megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes);
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr);


/** /**
* \brief Layer 2: computing handle * \brief Layer 2: computing handle
*/ */
struct megcoreComputingContext; struct megcoreComputingContext;
typedef struct megcoreComputingContext *megcoreComputingHandle_t;
typedef struct megcoreComputingContext* megcoreComputingHandle_t;


megcoreStatus_t megcoreCreateComputingHandle( megcoreStatus_t megcoreCreateComputingHandle(
megcoreComputingHandle_t *compHandle,
megcoreDeviceHandle_t devHandle,
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags = 0); unsigned int flags = 0);


megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher(
megcoreComputingHandle_t *compHandle,
megcoreDeviceHandle_t devHandle,
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher, const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher,
unsigned int flags = 0); unsigned int flags = 0);


megcoreStatus_t megcoreDestroyComputingHandle(
megcoreComputingHandle_t handle);
megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle);


megcoreStatus_t megcoreGetDeviceHandle( megcoreStatus_t megcoreGetDeviceHandle(
megcoreComputingHandle_t compHandle,
megcoreDeviceHandle_t *devHandle);
megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle);
megcoreStatus_t megcoreGetComputingFlags( megcoreStatus_t megcoreGetComputingFlags(
megcoreComputingHandle_t handle,
unsigned int *flags);
megcoreComputingHandle_t handle, unsigned int* flags);


MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle);


megcoreStatus_t megcoreMemcpy( megcoreStatus_t megcoreMemcpy(
megcoreComputingHandle_t handle,
void *dst, const void *src, size_t sizeInBytes,
megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes,
megcoreMemcpyKind_t kind); megcoreMemcpyKind_t kind);
megcoreStatus_t megcoreMemset( megcoreStatus_t megcoreMemset(
megcoreComputingHandle_t handle,
void *dst, int value, size_t sizeInBytes);
megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes);
megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle);


/** /**
* \brief Miscellaneous * \brief Miscellaneous
*/ */
const char *megcoreGetErrorName(megcoreStatus_t status);
const char* megcoreGetErrorName(megcoreStatus_t status);


#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"




+ 5
- 6
dnn/include/megcore_atlas.h View File

@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const AtlasContext& ctx); unsigned int flags, const AtlasContext& ctx);


megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle,
AtlasContext* ctx);
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, AtlasContext* ctx);


namespace atlas { namespace atlas {
//! convert acl error code to error string //! convert acl error code to error string
@@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, aclrtStream stream) { unsigned int flags, aclrtStream stream) {
megcore::AtlasContext ctx{stream}; megcore::AtlasContext ctx{stream};
return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle,
flags, ctx);
return megcore::createComputingHandleWithAtlasContext(
compHandle, devHandle, flags, ctx);
} }


inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle,
aclrtStream* stream) {
inline megcoreStatus_t megcoreGetACLStream(
megcoreComputingHandle_t handle, aclrtStream* stream) {
megcore::AtlasContext ctx; megcore::AtlasContext ctx;
auto ret = megcore::getAtlasContext(handle, &ctx); auto ret = megcore::getAtlasContext(handle, &ctx);
*stream = ctx.stream; *stream = ctx.stream;


+ 2
- 3
dnn/include/megcore_cambricon.h View File

@@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CambriconContext& ctx); unsigned int flags, const CambriconContext& ctx);


megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle,
CambriconContext* ctx);
megcoreStatus_t getCambriconContext(
megcoreComputingHandle_t handle, CambriconContext* ctx);


} // namespace megcore } // namespace megcore


@@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue(
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen


+ 1
- 2
dnn/include/megcore_cdefs.h View File

@@ -40,7 +40,6 @@ typedef enum {
megcoreErrorInternalError = 5, megcoreErrorInternalError = 5,
} megcoreStatus_t; } megcoreStatus_t;



/** /**
* \brief Memcpy kind * \brief Memcpy kind
*/ */
@@ -70,6 +69,6 @@ struct AsyncErrorInfo {
char msg[228]; char msg[228];
int msg_args[4]; int msg_args[4];
}; };
} // namespace megcore
} // namespace megcore


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 3
- 4
dnn/include/megcore_cuda.h View File

@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CudaContext& ctx); unsigned int flags, const CudaContext& ctx);


megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle,
CudaContext* ctx);
megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx);


} // namespace megcore } // namespace megcore


@@ -43,8 +42,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream(
unsigned int flags, cudaStream_t stream) { unsigned int flags, cudaStream_t stream) {
megcore::CudaContext ctx; megcore::CudaContext ctx;
ctx.stream = stream; ctx.stream = stream;
return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle,
flags, ctx);
return megcore::createComputingHandleWithCUDAContext(
compHandle, devHandle, flags, ctx);
} }


static inline megcoreStatus_t megcoreGetCUDAStream( static inline megcoreStatus_t megcoreGetCUDAStream(


+ 6
- 5
dnn/include/megcore_rocm.h View File

@@ -23,7 +23,9 @@ struct ROCMContext {
hipStream_t stream = nullptr; hipStream_t stream = nullptr;


static std::atomic_bool sm_miopen_algo_search; static std::atomic_bool sm_miopen_algo_search;
static inline bool enable_miopen_algo_search() { return sm_miopen_algo_search.load(); }
static inline bool enable_miopen_algo_search() {
return sm_miopen_algo_search.load();
}
static inline void enable_miopen_algo_search(bool enable_algo_search) { static inline void enable_miopen_algo_search(bool enable_algo_search) {
sm_miopen_algo_search.store(enable_algo_search); sm_miopen_algo_search.store(enable_algo_search);
} }
@@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const ROCMContext& ctx); unsigned int flags, const ROCMContext& ctx);


megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle,
ROCMContext* ctx);
megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, ROCMContext* ctx);


// Set MIOpen algo search enabled or disabled // Set MIOpen algo search enabled or disabled
megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true); megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true);
@@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream(
unsigned int flags, hipStream_t stream) { unsigned int flags, hipStream_t stream) {
megcore::ROCMContext ctx; megcore::ROCMContext ctx;
ctx.stream = stream; ctx.stream = stream;
return megcore::createComputingHandleWithROCMContext(compHandle, devHandle,
flags, ctx);
return megcore::createComputingHandleWithROCMContext(
compHandle, devHandle, flags, ctx);
} }


static inline megcoreStatus_t megcoreGetROCMStream( static inline megcoreStatus_t megcoreGetROCMStream(


+ 1
- 1
dnn/include/megdnn.h View File

@@ -10,7 +10,7 @@
*/ */
#pragma once #pragma once


#include "megdnn/version.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "megdnn/version.h"


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 73
- 74
dnn/include/megdnn/arch.h View File

@@ -14,20 +14,20 @@
#include "megdnn/config/config.h" #include "megdnn/config/config.h"


#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#if !defined (__clang__)
// gcc specific
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#if GCC_VERSION < 40800
#error "GCC version should be at least 4.8.0."
#endif // GCC_VERSION < 40800
#endif // !defined(__clang__)
#if !defined(__clang__)
// gcc specific
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#if GCC_VERSION < 40800
#error "GCC version should be at least 4.8.0."
#endif // GCC_VERSION < 40800
#endif // !defined(__clang__)


#ifndef megdnn_trap
#define megdnn_trap() __builtin_trap()
#endif
#ifndef megdnn_trap
#define megdnn_trap() __builtin_trap()
#endif


#define megdnn_likely(v) __builtin_expect(bool(v), 1)
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0)
#define megdnn_likely(v) __builtin_expect(bool(v), 1)
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0)


#if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG) #if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG)
//! Thumb2 limit code length //! Thumb2 limit code length
@@ -36,123 +36,122 @@
#define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__)) #define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__))
#endif #endif


#define MEGDNN_DEPRECATED __attribute__((deprecated))
#define MEGDNN_PACKED __attribute__((packed))
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_NORETURN __attribute__((noreturn))
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#if defined(__clang_major__) && (__clang_major__ >= 7)
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#else
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]]
#endif
#define MEGDNN_NOINLINE __attribute__((noinline))
#define megdnn_isatty(x) isatty(x)
#define MEGDNN_DEPRECATED __attribute__((deprecated))
#define MEGDNN_PACKED __attribute__((packed))
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_NORETURN __attribute__((noreturn))
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#if defined(__clang_major__) && (__clang_major__ >= 7)
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#else
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]]
#endif
#define MEGDNN_NOINLINE __attribute__((noinline))
#define megdnn_isatty(x) isatty(x)
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) #elif defined(__INTEL_COMPILER) || defined(_MSC_VER)


#ifndef megdnn_trap #ifndef megdnn_trap
#define megdnn_trap() __debugbreak() #define megdnn_trap() __debugbreak()
#endif #endif


#define megdnn_likely(v) (bool(v))
#define megdnn_likely(v) (bool(v))
#define megdnn_unlikely(v) (bool(v)) #define megdnn_unlikely(v) (bool(v))


#define MEGDNN_DEPRECATED #define MEGDNN_DEPRECATED
#define MEGDNN_PACKED #define MEGDNN_PACKED
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert #define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_FINAL final


#if defined(_MSC_VER) #if defined(_MSC_VER)
#define MEGDNN_NORETURN __declspec(noreturn)
#define MEGDNN_NOINLINE __declspec(noinline)
#define MEGDNN_NORETURN __declspec(noreturn)
#define MEGDNN_NOINLINE __declspec(noinline)
#else #else
#define MEGDNN_NORETURN
#define MEGDNN_FORCE_NOINLINE
#endif // _MSC_VER
#define MEGDNN_NORETURN
#define MEGDNN_FORCE_NOINLINE
#endif // _MSC_VER


#define MEGDNN_WARN_UNUSED_RESULT #define MEGDNN_WARN_UNUSED_RESULT


#define megdnn_isatty(x) _isatty(x) #define megdnn_isatty(x) _isatty(x)


#else #else
#error "unknown compiler"
#endif // __GNUC__
#error "unknown compiler"
#endif // __GNUC__


// __cpp_exceptions and __cpp_rtti is referred from // __cpp_exceptions and __cpp_rtti is referred from
// https://isocpp.org/std/standing-documentssd-6-sg10-feature-test-recommendations // https://isocpp.org/std/standing-documentssd-6-sg10-feature-test-recommendations
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// similar for __GXX_RTTI // similar for __GXX_RTTI
// _CPPUNWIND and _CPPRTTI is used by MSVC, see // _CPPUNWIND and _CPPRTTI is used by MSVC, see
// https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macrosview=vs-2019 // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macrosview=vs-2019
#ifndef MEGDNN_ENABLE_EXCEPTIONS #ifndef MEGDNN_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS || \
(defined(_MSC_VER) && defined(_CPPUNWIND))
#define MEGDNN_ENABLE_EXCEPTIONS 1
#else
#define MEGDNN_ENABLE_EXCEPTIONS 0
#endif
#if __cpp_exceptions || __EXCEPTIONS || (defined(_MSC_VER) && defined(_CPPUNWIND))
#define MEGDNN_ENABLE_EXCEPTIONS 1
#else
#define MEGDNN_ENABLE_EXCEPTIONS 0
#endif
#endif #endif
#ifndef MEGDNN_ENABLE_RTTI #ifndef MEGDNN_ENABLE_RTTI
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI))
#define MEGDNN_ENABLE_RTTI 1
#else
#define MEGDNN_ENABLE_RTTI 0
#endif
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI))
#define MEGDNN_ENABLE_RTTI 1
#else
#define MEGDNN_ENABLE_RTTI 0
#endif
#endif #endif


#ifdef __CUDACC__ #ifdef __CUDACC__
#define MEGDNN_CC_CUDA 1
#undef MEGDNN_CONSTEXPR
#define MEGDNN_CONSTEXPR const
#define MEGDNN_CC_CUDA 1
#undef MEGDNN_CONSTEXPR
#define MEGDNN_CONSTEXPR const


#if defined(__CUDACC_VER_MAJOR__) #if defined(__CUDACC_VER_MAJOR__)
#if __CUDACC_VER_MAJOR__ >= 9 #if __CUDACC_VER_MAJOR__ >= 9
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg);
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg);
#else #else
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg)
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg)
#endif #endif
#endif #endif


#define nullptr NULL
#undef MEGDNN_FINAL
#define MEGDNN_FINAL
#define nullptr NULL
#undef MEGDNN_FINAL
#define MEGDNN_FINAL
#elif defined(__HIPCC__) #elif defined(__HIPCC__)
#define MEGDNN_CC_CUDA 1
#define MEGDNN_CC_CUDA 1
#else #else
#define MEGDNN_CC_HOST 1
#endif // __CUDACC__
#define MEGDNN_CC_HOST 1
#endif // __CUDACC__


// MEGDNN_HOST and MEGDNN_DEVICE // MEGDNN_HOST and MEGDNN_DEVICE
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
#define MEGDNN_HOST __host__
#define MEGDNN_DEVICE __device__
#define MEGDNN_HOST __host__
#define MEGDNN_DEVICE __device__
#else #else
#define MEGDNN_HOST
#define MEGDNN_DEVICE
#define MEGDNN_HOST
#define MEGDNN_DEVICE
#endif #endif


#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
#define MEGDNN_FORCE_INLINE __forceinline__
#define MEGDNN_FORCE_INLINE __forceinline__
#else #else
#if __GNUC__ || __has_attribute(always_inline) #if __GNUC__ || __has_attribute(always_inline)
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#else #else
#define MEGDNN_FORCE_INLINE inline
#define MEGDNN_FORCE_INLINE inline
#endif #endif
#endif #endif


#if defined(_MSC_VER) || defined(WIN32) #if defined(_MSC_VER) || defined(WIN32)
#define ATTR_ALIGNED(v) __declspec(align(v))
#define ATTR_ALIGNED(v) __declspec(align(v))
#else #else
#define ATTR_ALIGNED(v) __attribute__((aligned(v)))
#define ATTR_ALIGNED(v) __attribute__((aligned(v)))
#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 33
- 42
dnn/include/megdnn/basic_types.h View File

@@ -16,10 +16,10 @@
#include "megdnn/internal/defs.h" #include "megdnn/internal/defs.h"


#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
#include <cstdarg>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include <cstdarg>
#include "megdnn/thin/small_vector.h" #include "megdnn/thin/small_vector.h"
#endif // MEGDNN_CC_HOST #endif // MEGDNN_CC_HOST


@@ -35,8 +35,7 @@ class ErrorHandler {
protected: protected:
MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0;


MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(
const std::string& msg) {
MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(const std::string& msg) {
on_megdnn_error(msg); on_megdnn_error(msg);
} }


@@ -70,8 +69,9 @@ public:
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
enum class LogLevel { DEBUG, INFO, WARN, ERROR }; enum class LogLevel { DEBUG, INFO, WARN, ERROR };


typedef void (*LogHandler)(LogLevel level, const char* file, const char* func,
int line, const char* fmt, va_list ap);
typedef void (*LogHandler)(
LogLevel level, const char* file, const char* func, int line, const char* fmt,
va_list ap);


/*! /*!
* \brief set the callback to receive all log messages * \brief set the callback to receive all log messages
@@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape {
ptrdiff_t low_elem, low_byte; ptrdiff_t low_elem, low_byte;
size_t high_elem, high_byte; size_t high_elem, high_byte;


Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem,
size_t high_byte)
Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte)
: low_elem(low_elem), : low_elem(low_elem),
low_byte(low_byte), low_byte(low_byte),
high_elem(high_elem), high_elem(high_elem),
@@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape {
TensorLayout(const TensorShape& shape, DType dtype, Format format); TensorLayout(const TensorShape& shape, DType dtype, Format format);


//! creating layout with user-specified shape and stride. //! creating layout with user-specified shape and stride.
TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
DType dtype);
TensorLayout(
const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
DType dtype);


TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
DType dtype, Format format);
TensorLayout(
const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype,
Format format);


/* =================== inplace modifiers =================== */ /* =================== inplace modifiers =================== */


@@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape {
* *
* \throw TensorReshapeError if no stride exists for target shape. * \throw TensorReshapeError if no stride exists for target shape.
*/ */
TensorLayout reshape(const TensorShape& shape) const
MEGDNN_WARN_UNUSED_RESULT;
TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;


/*! /*!
* \brief try to reshape to another view; return whether these two shapes * \brief try to reshape to another view; return whether these two shapes
@@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape {
* \return true iff there exists target stride so this layout can be * \return true iff there exists target stride so this layout can be
* converted to target shape and the elements can match. * converted to target shape and the elements can match.
*/ */
bool try_reshape(TensorLayout& output,
const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
bool try_reshape(TensorLayout& output, const TensorShape& shape) const
MEGDNN_WARN_UNUSED_RESULT;


/*! /*!
* \brief Broadcast on dims with shape == 1 to match target *shape*. * \brief Broadcast on dims with shape == 1 to match target *shape*.
* \throw TensorReshapeError if could not be satisfied * \throw TensorReshapeError if could not be satisfied
*/ */
TensorLayout broadcast(const TensorShape& shape) const
MEGDNN_WARN_UNUSED_RESULT;
TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;


/*! /*!
* \brief Collapse consecutive axes with contiguous layout together * \brief Collapse consecutive axes with contiguous layout together
@@ -441,8 +440,7 @@ struct Workspace {


Workspace() : raw_ptr(NULL), size(0) {} Workspace() : raw_ptr(NULL), size(0) {}


Workspace(dt_byte* raw_ptr_, size_t size_)
: raw_ptr(raw_ptr_), size(size_) {}
Workspace(dt_byte* raw_ptr_, size_t size_) : raw_ptr(raw_ptr_), size(size_) {}


template <typename T> template <typename T>
T* ptr(size_t offset_in_bytes = 0) const { T* ptr(size_t offset_in_bytes = 0) const {
@@ -467,9 +465,8 @@ public:
* \param shape requested output shape * \param shape requested output shape
* \param user_data extra user data passed in DynOutMallocPolicyCall * \param user_data extra user data passed in DynOutMallocPolicyCall
*/ */
virtual TensorND alloc_output(size_t id, DType dtype,
const TensorShape& shape,
void* user_data) = 0;
virtual TensorND alloc_output(
size_t id, DType dtype, const TensorShape& shape, void* user_data) = 0;


/*! /*!
* \brief allocate workspace memory * \brief allocate workspace memory
@@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall {
*/ */
template <typename T = void, typename elem = T> template <typename T = void, typename elem = T>
T* alloc_workspace(size_t nr_elem) { T* alloc_workspace(size_t nr_elem) {
using real_elem =
typename std::conditional<std::is_same<elem, void>::value,
uint8_t, elem>::type;
return static_cast<T*>(policy->alloc_workspace(
nr_elem * sizeof(real_elem), user_data));
using real_elem = typename std::conditional<
std::is_same<elem, void>::value, uint8_t, elem>::type;
return static_cast<T*>(
policy->alloc_workspace(nr_elem * sizeof(real_elem), user_data));
} }


void free_workspace(void* ptr) {
return policy->free_workspace(ptr, user_data);
}
void free_workspace(void* ptr) { return policy->free_workspace(ptr, user_data); }
}; };



template <typename T> template <typename T>
class EnumClassBit { class EnumClassBit {
std::underlying_type_t<T> m_val; std::underlying_type_t<T> m_val;
@@ -528,8 +521,7 @@ class EnumClassBit {
constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {}


public: public:
constexpr EnumClassBit(T v)
: m_val(static_cast<std::underlying_type_t<T>>(v)) {}
constexpr EnumClassBit(T v) : m_val(static_cast<std::underlying_type_t<T>>(v)) {}


constexpr operator T() const { return static_cast<T>(m_val); } constexpr operator T() const { return static_cast<T>(m_val); }


@@ -542,7 +534,7 @@ public:


DEF_OPR(&) DEF_OPR(&)
DEF_OPR(|) DEF_OPR(|)
DEF_OPR (^)
DEF_OPR(^)


constexpr EnumClassBit operator~() const { return ~m_val; } constexpr EnumClassBit operator~() const { return ~m_val; }


@@ -553,14 +545,13 @@ public:


} // namespace megdnn } // namespace megdnn


#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) \
op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
} }


#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ #define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \


+ 6
- 7
dnn/include/megdnn/common.h View File

@@ -14,14 +14,14 @@
#include "megbrain_build_config.h" #include "megbrain_build_config.h"


#if MGB_ENABLE_GETENV #if MGB_ENABLE_GETENV
#define MGB_GETENV ::std::getenv
#define MGB_GETENV ::std::getenv
#else #else
#define MGB_GETENV(_name) static_cast<char*>(nullptr)
#define MGB_GETENV(_name) static_cast<char*>(nullptr)
#endif #endif


#ifdef WIN32 #ifdef WIN32
#define unsetenv(_name) _putenv_s(_name, "");
#define setenv(name,value,overwrite) _putenv_s(name,value)
#define unsetenv(_name) _putenv_s(_name, "");
#define setenv(name, value, overwrite) _putenv_s(name, value)
#endif #endif


namespace megdnn { namespace megdnn {
@@ -32,8 +32,7 @@ namespace megdnn {
*/ */
template <class Opr, typename... Args> template <class Opr, typename... Args>
bool has_available_algo(Opr* opr, Args&&... args) { bool has_available_algo(Opr* opr, Args&&... args) {
const typename Opr::AlgoBase::SizeArgs size_args(
opr, std::forward<Args>(args)...);
const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
for (auto i : Opr::algo_pack().all_algos) { for (auto i : Opr::algo_pack().all_algos) {
if (i->is_available(size_args)) { if (i->is_available(size_args)) {
return true; return true;
@@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... args) {
return false; return false;
} }


}
} // namespace megdnn


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 4
- 4
dnn/include/megdnn/cuda.h View File

@@ -17,11 +17,11 @@
#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"
namespace megdnn { namespace megdnn {


std::unique_ptr<Handle> make_cuda_handle_with_stream(cudaStream_t stream,
int device_id = -1);
cudaStream_t get_cuda_stream(Handle *handle);
std::unique_ptr<Handle> make_cuda_handle_with_stream(
cudaStream_t stream, int device_id = -1);
cudaStream_t get_cuda_stream(Handle* handle);


} // namespace megdnn
} // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 353
- 463
dnn/include/megdnn/dtype.h
File diff suppressed because it is too large
View File


+ 18
- 13
dnn/include/megdnn/dtype/half_common_epilogue.h View File

@@ -3,17 +3,22 @@
* *
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* Permission is hereby granted, free of charge, to any person obtaining a copy of this
* software and associated documentation files (the "Software"), to deal in the Software
* without restriction, including without limitation the rights to use, copy, modify,
* merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be included in all copies
* or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
* CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
* OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* *
* Version 1.11.0 * Version 1.11.0
* \file * \file
@@ -41,8 +46,8 @@
#undef HALF_NOEXCEPT #undef HALF_NOEXCEPT
#undef HALF_NOTHROW #undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS #ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif #endif


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 147
- 144
dnn/include/megdnn/dtype/half_common_prologue.h View File

@@ -3,17 +3,22 @@
* *
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
* Permission is hereby granted, free of charge, to any person obtaining a copy of this
* software and associated documentation files (the "Software"), to deal in the Software
* without restriction, including without limitation the rights to use, copy, modify,
* merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to the following
* conditions:
* *
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
* The above copyright notice and this permission notice shall be included in all copies
* or substantial portions of the Software.
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
* CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
* OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* *
* Version 1.11.0 * Version 1.11.0
* \file * \file
@@ -39,166 +44,164 @@
#include "megdnn/arch.h" #include "megdnn/arch.h"


/// Combined gcc version number. /// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
#define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)


//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 is multiple copy/assigment operator specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
// check C++11 language features
#if defined(__clang__) // clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \
!defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER)
//Intel C++ #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
???????? #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >=
1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define
HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 &&
!defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define
HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 &&
!defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define
HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/
#elif defined(__GNUC__) // gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) // Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 is multiple copy/assigment operator specified
#pragma warning(disable : 4099 4127 4146 4521 4522) // struct vs class, constant in if,
// negative unsigned
#endif #endif


//check C++11 library features
// check C++11 library features
#include <utility> #include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#if defined(_LIBCPP_VERSION) // libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) // libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif #endif
#undef HALF_GNUC_VERSION #undef HALF_GNUC_VERSION


//support constexpr
// support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR #if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else #else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif #endif


//support noexcept
// support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT #if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else #else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif #endif


#include <algorithm> #include <algorithm>
#include <limits>
#include <climits> #include <climits>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <ostream>
#include <istream> #include <istream>
#include <limits>
#include <ostream>
#if HALF_ENABLE_CPP11_TYPE_TRAITS #if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#include <type_traits>
#endif #endif
#if HALF_ENABLE_CPP11_CSTDINT #if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#include <cstdint>
#endif #endif
#if HALF_ENABLE_CPP11_HASH #if HALF_ENABLE_CPP11_HASH
#include <functional>
#include <functional>
#endif #endif


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 134
- 137
dnn/include/megdnn/handle.h View File

@@ -12,8 +12,8 @@
#pragma once #pragma once


#include "megcore.h" #include "megcore.h"
#include "megdnn/config/config.h"
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "megdnn/config/config.h"


#include <functional> #include <functional>
#include <memory> #include <memory>
@@ -24,150 +24,147 @@ namespace megdnn {
class OperatorBase; class OperatorBase;


class Handle { class Handle {
public:
enum class HandleType {
NAIVE = 0,
FALLBACK = 1,
X86 = 2,
ARM_COMMON = 3,
ARMV7 = 4,
AARCH64 = 5,
CUDA = 6,
ROCM = 11,
ATLAS = 13,
CAMBRICON = 12,
};

//! Device vendor
enum class HandleVendorType : uint32_t {
NOT_SPEC = 0,
MALI = 1,
ADRENO = 2,
CUDA = 3,
INTEL = 4,
POWERVR = 5,
AMD = 6,
};

protected:
Handle(megcoreComputingHandle_t computing_handle, HandleType type);

public:
/**
* \brief Create a MegDNN handle from a MegCore Computing handle.
*
* \param[in] computing_handle MegCore computing handle. Please note
* that computing_handle would not be released when this Handle is
* destructed
* \param[in] debug_level
* Applicable for CPU computing handle.
* 0 means taking the fastest possible code path; it may contains
* platform-specific instructions such as SSE for x86_64 or NEON for
* armv7v7.
* 1 means taking the fastest possible code path without
* platform-specific instructions in C++ code. Note that the compiled
* binary file still contains platform-specific codes.
* 2 means taking the naive code path. Performance is severely
* hampered, but it is less error-prone since the internal
* implementation is rather straightforward.
*
* **Debug level 1 and 2 should not be used in productions.**
*/
static std::unique_ptr<Handle> make(
megcoreComputingHandle_t computing_handle,
int debug_level = 0);
public:
enum class HandleType {
NAIVE = 0,
FALLBACK = 1,
X86 = 2,
ARM_COMMON = 3,
ARMV7 = 4,
AARCH64 = 5,
CUDA = 6,
ROCM = 11,
ATLAS = 13,
CAMBRICON = 12,
};

//! Device vendor
enum class HandleVendorType : uint32_t {
NOT_SPEC = 0,
MALI = 1,
ADRENO = 2,
CUDA = 3,
INTEL = 4,
POWERVR = 5,
AMD = 6,
};

protected:
Handle(megcoreComputingHandle_t computing_handle, HandleType type);

public:
/**
* \brief Create a MegDNN handle from a MegCore Computing handle.
*
* \param[in] computing_handle MegCore computing handle. Please note
* that computing_handle would not be released when this Handle is
* destructed
* \param[in] debug_level
* Applicable for CPU computing handle.
* 0 means taking the fastest possible code path; it may contains
* platform-specific instructions such as SSE for x86_64 or NEON for
* armv7v7.
* 1 means taking the fastest possible code path without
* platform-specific instructions in C++ code. Note that the compiled
* binary file still contains platform-specific codes.
* 2 means taking the naive code path. Performance is severely
* hampered, but it is less error-prone since the internal
* implementation is rather straightforward.
*
* **Debug level 1 and 2 should not be used in productions.**
*/
static std::unique_ptr<Handle> make(
megcoreComputingHandle_t computing_handle, int debug_level = 0);


#if MEGDNN_WITH_CUDA #if MEGDNN_WITH_CUDA
static std::unique_ptr<Handle> make_cuda_handle(
megcoreComputingHandle_t computing_handle);
template <typename opr>
std::unique_ptr<opr> create_cuda_operator();
static std::unique_ptr<Handle> make_cuda_handle(
megcoreComputingHandle_t computing_handle);
template <typename opr>
std::unique_ptr<opr> create_cuda_operator();
#endif #endif
#if MEGDNN_WITH_ROCM #if MEGDNN_WITH_ROCM
static std::unique_ptr<Handle> make_rocm_handle(
megcoreComputingHandle_t computing_handle);
template <typename opr>
std::unique_ptr<opr> create_rocm_operator();
static std::unique_ptr<Handle> make_rocm_handle(
megcoreComputingHandle_t computing_handle);
template <typename opr>
std::unique_ptr<opr> create_rocm_operator();
#endif #endif


virtual ~Handle();

/*!
* \brief Get the underlying megcore computing handle.
*/
megcoreComputingHandle_t megcore_computing_handle() const {
return m_computing_handle;
}

/*!
* \brief set a callback function to be invoked when this handle is
* destructed, so associated resources can be released (e.g.
* computing handle)
*
* This function can be called at most once.
*/
void set_destructor(const thin_function<void()> &d);

/*!
* \brief set a callback to be invoked when an operator is destructed
* \param[in,out] cb the callback function; it would be set to the
* previous callback function
*/
void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) {
cb.swap(m_on_opr_destructed);
}

void on_opr_destructed(OperatorBase* opr);

/**
* \brief Create operator of Opr type.
*/
template <typename Opr>
std::unique_ptr<Opr> create_operator();

/*
* =============================================================
* Users should call functions below to query memory requirement.
* =============================================================
*/

/**
* \brief The internal data pointer of TensorND should be aligned to
* alignment_requirement() in bytes.
*/
virtual size_t alignment_requirement() const;

//! get alignment in bytes for rows of image 2D tensor format
virtual size_t image2d_pitch_alignment() const;

//! get vendor type
virtual HandleVendorType vendor_type() const;

HandleType type() const {
return m_handle_type;
}

/**
* \brief Check is the layout satisfy cross device copy constraint.
* 1. The handle of the src and the dst is the same kind
* 2. The dst is continguous.
*/
virtual bool check_cross_dev_copy_constraint(const TensorLayout &src);

private:
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;
volatile uint32_t m_alive_magic = ALIVE_MAGIC;
megcoreComputingHandle_t m_computing_handle;
const HandleType m_handle_type;
thin_function<void()> m_destructor;
thin_function<void(OperatorBase*)> m_on_opr_destructed;

Handle() = delete;
Handle(const Handle &rhs) = delete;
Handle &operator=(const Handle &rhs) = delete;
virtual ~Handle();

/*!
* \brief Get the underlying megcore computing handle.
*/
megcoreComputingHandle_t megcore_computing_handle() const {
return m_computing_handle;
}

/*!
* \brief set a callback function to be invoked when this handle is
* destructed, so associated resources can be released (e.g.
* computing handle)
*
* This function can be called at most once.
*/
void set_destructor(const thin_function<void()>& d);

/*!
* \brief set a callback to be invoked when an operator is destructed
* \param[in,out] cb the callback function; it would be set to the
* previous callback function
*/
void set_opr_destruct_callback(thin_function<void(OperatorBase*)>& cb) {
cb.swap(m_on_opr_destructed);
}

void on_opr_destructed(OperatorBase* opr);

/**
* \brief Create operator of Opr type.
*/
template <typename Opr>
std::unique_ptr<Opr> create_operator();

/*
* =============================================================
* Users should call functions below to query memory requirement.
* =============================================================
*/

/**
* \brief The internal data pointer of TensorND should be aligned to
* alignment_requirement() in bytes.
*/
virtual size_t alignment_requirement() const;

//! get alignment in bytes for rows of image 2D tensor format
virtual size_t image2d_pitch_alignment() const;

//! get vendor type
virtual HandleVendorType vendor_type() const;

HandleType type() const { return m_handle_type; }

/**
* \brief Check is the layout satisfy cross device copy constraint.
* 1. The handle of the src and the dst is the same kind
* 2. The dst is continguous.
*/
virtual bool check_cross_dev_copy_constraint(const TensorLayout& src);

private:
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;
volatile uint32_t m_alive_magic = ALIVE_MAGIC;
megcoreComputingHandle_t m_computing_handle;
const HandleType m_handle_type;
thin_function<void()> m_destructor;
thin_function<void(OperatorBase*)> m_on_opr_destructed;

Handle() = delete;
Handle(const Handle& rhs) = delete;
Handle& operator=(const Handle& rhs) = delete;
}; };


} // namespace megdnn
} // namespace megdnn


#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"




+ 3
- 2
dnn/include/megdnn/heuristic_cache.h View File

@@ -49,8 +49,9 @@ public:
mutable std::string m_input; mutable std::string m_input;


public: public:
Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr,
size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0)
Key(Handle* opr_handle, Algorithm::OprType opr_type,
const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size,
const void* param_ptr = nullptr, size_t param_size = 0)
: m_handle{opr_handle}, : m_handle{opr_handle},
m_opr_type{static_cast<uint32_t>(opr_type)}, m_opr_type{static_cast<uint32_t>(opr_type)},
m_inp_layouts_ptr{inp_layouts_ptr}, m_inp_layouts_ptr{inp_layouts_ptr},


+ 5
- 6
dnn/include/megdnn/internal/defs.h View File

@@ -16,20 +16,19 @@
* \brief iterate through small (usually used) ndim values * \brief iterate through small (usually used) ndim values
*/ */
#define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ #define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \
cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__)
cb(1, ##__VA_ARGS__) cb(2, ##__VA_ARGS__) cb(3, ##__VA_ARGS__)


/*! /*!
* \brief iterate through large (rarely used) ndim values * \brief iterate through large (rarely used) ndim values
*/ */
#define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ #define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \
cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \
cb(7, ##__VA_ARGS__)
cb(4, ##__VA_ARGS__) cb(5, ##__VA_ARGS__) cb(6, ##__VA_ARGS__) cb(7, ##__VA_ARGS__)


/*! /*!
* \brief iterate through all ndim values * \brief iterate through all ndim values
*/ */
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \
MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \
MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__)
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \
MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ##__VA_ARGS__) \
MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ##__VA_ARGS__)


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 20
- 19
dnn/include/megdnn/internal/opr_header_prologue.h View File

@@ -11,14 +11,14 @@
// intentional no header guard here // intentional no header guard here


#include "megdnn/handle.h" #include "megdnn/handle.h"
#include "megdnn/oprs/base.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megdnn/opr_result_defs.h" #include "megdnn/opr_result_defs.h"
#include "megdnn/oprs/base.h"


#include "./visibility_prologue.h" #include "./visibility_prologue.h"


#include <limits>
#include <array> #include <array>
#include <limits>


#ifndef _megdnn_in #ifndef _megdnn_in
#define _megdnn_in #define _megdnn_in
@@ -29,36 +29,37 @@
#endif #endif


#ifndef _megdnn_tensor_in #ifndef _megdnn_tensor_in
#define _megdnn_tensor_in const TensorND &
#define _megdnn_tensor_in const TensorND&
#endif #endif


#ifndef _megdnn_tensor_out #ifndef _megdnn_tensor_out
#define _megdnn_tensor_out const TensorND &
#define _megdnn_tensor_out const TensorND&
#endif #endif


#ifndef _megdnn_tensor_inout #ifndef _megdnn_tensor_inout
#define _megdnn_tensor_inout const TensorND &
#define _megdnn_tensor_inout const TensorND&
#endif #endif


#ifndef _megdnn_workspace #ifndef _megdnn_workspace
#define _megdnn_workspace const Workspace &
#define _megdnn_workspace const Workspace&
#endif #endif


#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
public: \
_opr_name(Handle *handle): _base_name(handle) {} \
#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
public: \
_opr_name(Handle* handle) : _base_name(handle) {}


#define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ #define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs;


#define DEF_OPR_PARAM(_pname) \
public: \
using Param = param::_pname; \
Param& param() { return m_param; } \
const Param& param() const { return m_param; } \
protected: \
Param m_param
#define DEF_OPR_PARAM(_pname) \
public: \
using Param = param::_pname; \
Param& param() { return m_param; } \
const Param& param() const { return m_param; } \
\
protected: \
Param m_param


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 0
- 1
dnn/include/megdnn/internal/visibility_epilogue.h View File

@@ -20,4 +20,3 @@
#endif #endif


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen


+ 16
- 20
dnn/include/megdnn/opr_result_defs.h View File

@@ -16,25 +16,21 @@
namespace megdnn { namespace megdnn {
namespace opr_result { namespace opr_result {


struct Checksum {
uint32_t checksum;
union {
int32_t iv;
float fv;
} last_val;

bool operator == (const Checksum &rhs) const {
return checksum == rhs.checksum &&
last_val.iv == rhs.last_val.iv;
}

bool operator != (const Checksum &rhs) const {
return !operator==(rhs);
}
};

} // namespace opr_result
} // namespace megdnn

struct Checksum {
uint32_t checksum;
union {
int32_t iv;
float fv;
} last_val;

bool operator==(const Checksum& rhs) const {
return checksum == rhs.checksum && last_val.iv == rhs.last_val.iv;
}

bool operator!=(const Checksum& rhs) const { return !operator==(rhs); }
};

} // namespace opr_result
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 2
- 4
dnn/include/megdnn/oprs.h View File

@@ -12,11 +12,11 @@


#include "megdnn/oprs/cv.h" #include "megdnn/oprs/cv.h"
#include "megdnn/oprs/general.h" #include "megdnn/oprs/general.h"
#include "megdnn/oprs/imgproc.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"
#include "megdnn/oprs/nn_int.h" #include "megdnn/oprs/nn_int.h"
#include "megdnn/oprs/imgproc.h"
#include "megdnn/oprs/utils.h" #include "megdnn/oprs/utils.h"
#include "megdnn/oprs/linalg.h"


template <typename Opr> template <typename Opr>
struct OprArityTrait; struct OprArityTrait;
@@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1);


#undef INST_ARITY #undef INST_ARITY




// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 91
- 124
dnn/include/megdnn/oprs/base.h View File

@@ -90,7 +90,7 @@ enum class AlgoDataType : uint32_t {
INT8X8X16 = 1 << 4, INT8X8X16 = 1 << 4,
INT16X16X32 = 1 << 5, INT16X16X32 = 1 << 5,
INT4X4X16 = 1 << 6, INT4X4X16 = 1 << 6,
QINT4x4x32 = 1 << 7,
QINT4x4x32 = 1 << 7,
}; };


/*! /*!
@@ -195,16 +195,16 @@ public:
Handle::HandleType handle_type() const { return m_handle_type; } Handle::HandleType handle_type() const { return m_handle_type; }


Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } Info::Desc desc() const { return {handle_type(), type(), param(), name()}; }
Info info() const {
return {desc(), attribute()};
}
Info info() const { return {desc(), attribute()}; }


template <typename T> template <typename T>
static void serialize_write_pod(const T& val, std::string& result) { static void serialize_write_pod(const T& val, std::string& result) {
static_assert(std::is_trivially_copyable<T>::value,
"type should be trivially copyable");
static_assert(!std::is_pointer<T>::value,
"serialize pointer is unsafe in eager execution mode");
static_assert(
std::is_trivially_copyable<T>::value,
"type should be trivially copyable");
static_assert(
!std::is_pointer<T>::value,
"serialize pointer is unsafe in eager execution mode");
result.append(reinterpret_cast<const char*>(&val), sizeof(T)); result.append(reinterpret_cast<const char*>(&val), sizeof(T));
} }


@@ -231,9 +231,8 @@ public:
return ret; return ret;
} }


static std::string deserialize_read_pod(const std::string& data,
size_t offset = 0,
size_t size = 0) {
static std::string deserialize_read_pod(
const std::string& data, size_t offset = 0, size_t size = 0) {
return std::string(data.data() + offset, size); return std::string(data.data() + offset, size);
} }


@@ -286,8 +285,8 @@ public:
* \param layouts origin layouts of the parent opr * \param layouts origin layouts of the parent opr
* \param opr parent opr * \param opr parent opr
*/ */
virtual std::vector<SearchItem> get_subopr_list(const TensorLayoutArray&,
const OperatorBase*) const {
virtual std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray&, const OperatorBase*) const {
return {}; return {};
} }


@@ -333,9 +332,7 @@ public:


ExecutionPolicy& execution_policy() { return m_execution_policy; } ExecutionPolicy& execution_policy() { return m_execution_policy; }


const ExecutionPolicy& execution_policy() const {
return m_execution_policy;
}
const ExecutionPolicy& execution_policy() const { return m_execution_policy; }


virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0;


@@ -355,8 +352,8 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute; using AlgoAttribute = detail::Algorithm::Attribute;


//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1) {
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1)) { for (auto&& algo : get_all_algorithms(p0, p1)) {
ret.emplace_back(algo->info()); ret.emplace_back(algo->info());
@@ -364,8 +361,8 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1) {
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1)) { for (auto&& algo : get_all_algorithms_safe(p0, p1)) {
ret.emplace_back(algo->info()); ret.emplace_back(algo->info());
@@ -382,12 +379,11 @@ public:
*/ */
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p0, const TensorLayout& p1,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, workspace_limit_in_bytes,
positive_attr, negative_attr)
return get_algorithm_heuristic(
p0, p1, workspace_limit_in_bytes, positive_attr, negative_attr)
->info(); ->info();
} }


@@ -408,8 +404,7 @@ protected:
*/ */
virtual Algorithm* get_algorithm_heuristic( virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p0, const TensorLayout& p1,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
}; };
@@ -423,9 +418,8 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute; using AlgoAttribute = detail::Algorithm::Attribute;


//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2) {
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2)) { for (auto&& algo : get_all_algorithms(p0, p1, p2)) {
ret.emplace_back(algo->info()); ret.emplace_back(algo->info());
@@ -433,9 +427,8 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2) {
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) {
ret.emplace_back(algo->info()); ret.emplace_back(algo->info());
@@ -451,14 +444,13 @@ public:
* \p workspace_limit_in_bytes. * \p workspace_limit_in_bytes.
*/ */
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes,
positive_attr, negative_attr)
return get_algorithm_heuristic(
p0, p1, p2, workspace_limit_in_bytes, positive_attr,
negative_attr)
->info(); ->info();
} }


@@ -467,11 +459,9 @@ protected:


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe( virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0;


/** /**
* \brief Returns the best algorithm by heuristic. * \brief Returns the best algorithm by heuristic.
@@ -480,10 +470,8 @@ protected:
* \p workspace_limit_in_bytes. * \p workspace_limit_in_bytes.
*/ */
virtual Algorithm* get_algorithm_heuristic( virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
}; };
@@ -497,10 +485,9 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute; using AlgoAttribute = detail::Algorithm::Attribute;


//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) {
ret.emplace_back(algo->info()); ret.emplace_back(algo->info());
@@ -508,10 +495,9 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) {
ret.emplace_back(algo->info()); ret.emplace_back(algo->info());
@@ -527,14 +513,14 @@ public:
* \p workspace_limit_in_bytes. * \p workspace_limit_in_bytes.
*/ */
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes,
positive_attr, negative_attr)
return get_algorithm_heuristic(
p0, p1, p2, p3, workspace_limit_in_bytes, positive_attr,
negative_attr)
->info(); ->info();
} }


@@ -543,11 +529,11 @@ protected:


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe( virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3) = 0;


/** /**
* \brief Returns the best algorithm by heuristic. * \brief Returns the best algorithm by heuristic.
@@ -556,10 +542,9 @@ protected:
* \p workspace_limit_in_bytes. * \p workspace_limit_in_bytes.
*/ */
virtual Algorithm* get_algorithm_heuristic( virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
}; };
@@ -573,11 +558,9 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute; using AlgoAttribute = detail::Algorithm::Attribute;


//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3,
const TensorLayout& p4) {
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) {
ret.emplace_back(algo->info()); ret.emplace_back(algo->info());
@@ -585,11 +568,9 @@ public:
return ret; return ret;
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3,
const TensorLayout& p4) {
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) {
ret.emplace_back(algo->info()); ret.emplace_back(algo->info());
@@ -605,16 +586,14 @@ public:
* \p workspace_limit_in_bytes. * \p workspace_limit_in_bytes.
*/ */
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4,
workspace_limit_in_bytes, positive_attr,
negative_attr)
return get_algorithm_heuristic(
p0, p1, p2, p3, p4, workspace_limit_in_bytes, positive_attr,
negative_attr)
->info(); ->info();
} }


@@ -622,14 +601,12 @@ protected:
~MultiAlgoOpr() = default; ~MultiAlgoOpr() = default;


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4) = 0;
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe( virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4) = 0;


/** /**
* \brief Returns the best algorithm by heuristic. * \brief Returns the best algorithm by heuristic.
@@ -638,11 +615,9 @@ protected:
* \p workspace_limit_in_bytes. * \p workspace_limit_in_bytes.
*/ */
virtual Algorithm* get_algorithm_heuristic( virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
}; };
@@ -657,9 +632,8 @@ public:


//! get all possible algorithm decriptions for the specified layouts //! get all possible algorithm decriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info( std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) { const TensorLayout& p6, const TensorLayout& p7) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) {
@@ -669,9 +643,8 @@ public:
} }


std::vector<AlgorithmInfo> get_all_algorithms_info_safe( std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) { const TensorLayout& p6, const TensorLayout& p7) {
std::vector<AlgorithmInfo> ret; std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) {
@@ -687,17 +660,15 @@ public:
* The selected algorithm should not use workspace more than * The selected algorithm should not use workspace more than
*/ */
AlgorithmInfo get_algorithm_info_heuristic( AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7, const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7,
workspace_limit_in_bytes, positive_attr,
negative_attr)
return get_algorithm_heuristic(
p0, p1, p2, p3, p4, p5, p6, p7, workspace_limit_in_bytes,
positive_attr, negative_attr)
->info(); ->info();
} }


@@ -705,15 +676,13 @@ protected:
~MultiAlgoOpr() = default; ~MultiAlgoOpr() = default;


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) = 0; const TensorLayout& p6, const TensorLayout& p7) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe( virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) = 0; const TensorLayout& p6, const TensorLayout& p7) = 0;


/** /**
@@ -723,12 +692,10 @@ protected:
* \p workspace_limit_in_bytes. * \p workspace_limit_in_bytes.
*/ */
virtual Algorithm* get_algorithm_heuristic( virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7, const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
}; };


+ 123
- 99
dnn/include/megdnn/oprs/cv.h View File

@@ -31,15 +31,17 @@ class FlipForward : public FlipBase {
DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using Flip = FlipForward; using Flip = FlipForward;


@@ -56,15 +58,17 @@ class RotateForward : public RotateBase {
DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using Rotate = RotateForward; using Rotate = RotateForward;


@@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase {
DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using ROICopy = ROICopyForward; using ROICopy = ROICopyForward;


@@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase {
DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using CvtColor = CvtColorForward; using CvtColor = CvtColorForward;


@@ -130,8 +138,9 @@ public:
using BorderMode = Param::BorderMode; using BorderMode = Param::BorderMode;


protected: protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst);
std::string param_msg() const; std::string param_msg() const;
int get_real_coord(int p, int len); int get_real_coord(int p, int len);
}; };
@@ -148,15 +157,17 @@ public:
* \warning src, trans, border_value, dst should be contiguous * \warning src, trans, border_value, dst should be contiguous
* The size of trans is N * 2 * 3 * The size of trans is N * 2 * 3
*/ */
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& trans,
const TensorLayout& dst) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using WarpAffine = WarpAffineForward; using WarpAffine = WarpAffineForward;


@@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase {
DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using GaussianBlur = GaussianBlurForward; using GaussianBlur = GaussianBlurForward;


@@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase {
DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;


virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using Resize = ResizeForward; using Resize = ResizeForward;


@@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase {
DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);


public: public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;


virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& mat) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& mat) = 0;


protected: protected:
void check_exec(const TensorLayout& diff, const TensorLayout& mat,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& diff, const TensorLayout& mat,
size_t workspace_in_bytes);
}; };


/** /**
@@ -251,29 +268,32 @@ public:
using BorderMode = Param::BorderMode; using BorderMode = Param::BorderMode;


protected: protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst);
void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy,
TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst);
void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
}; };


class RemapForward : public RemapBase { class RemapForward : public RemapBase {
DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;


void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy,
TensorLayout& dst);
void deduce_layout(
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);


virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& map_xy,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst, size_t workspace_in_bytes);
}; };
using Remap = RemapForward; using Remap = RemapForward;


@@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase {
DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);


public: public:
virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;


virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad) = 0;


protected: protected:
void check_exec(const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
}; };


class RemapBackwardMat : public RemapBase { class RemapBackwardMat : public RemapBase {
DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;


virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& diff, const TensorLayout& grad) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
}; };


class SeparableFilterBase : public OperatorBase { class SeparableFilterBase : public OperatorBase {
@@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase {
DEF_OPR_PARAM(SeparableFilter); DEF_OPR_PARAM(SeparableFilter);


protected: protected:
void deduce_layout_fwd(const TensorLayout& src,
const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y,
const TensorLayout& dst);
void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst);
}; };


class SeparableFilterForward : public SeparableFilterBase { class SeparableFilterForward : public SeparableFilterBase {
DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter_x,
const TensorLayout& filter_y,
const TensorLayout& dst) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using SeparableFilter = SeparableFilterForward; using SeparableFilter = SeparableFilterForward;




+ 847
- 822
dnn/include/megdnn/oprs/general.h
File diff suppressed because it is too large
View File


+ 164
- 180
dnn/include/megdnn/oprs/imgproc.h View File

@@ -13,173 +13,162 @@


namespace megdnn { namespace megdnn {


class WarpPerspectiveBase: public OperatorBase {
class WarpPerspectiveBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
DEF_OPR_PARAM(WarpPerspective); DEF_OPR_PARAM(WarpPerspective);
public:
using InterpolationMode = Param::InterpolationMode;
using BorderMode = Param::BorderMode;

protected:
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat,
const TensorLayout &dst) {
check_layout_fwd(src, mat, {}, dst);
}

void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat,
const TensorLayout &mat_idx, const TensorLayout &dst);
std::string param_msg() const;
int get_real_coord(int p, int len);

public:
using InterpolationMode = Param::InterpolationMode;
using BorderMode = Param::BorderMode;

protected:
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
check_layout_fwd(src, mat, {}, dst);
}

void check_layout_fwd(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst);
std::string param_msg() const;
int get_real_coord(int p, int len);
}; };


class WarpPerspectiveForward: public WarpPerspectiveBase {
class WarpPerspectiveForward : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
public:
/**
* \param[in] src (n, channel, in_height, in_width)
* \param[in] mat (n, 3, 3)
* \param[out] dst (n, channel, out_height, out_width)
*
* \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
*
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
*
* src and dst can have different shapes, as long as their n and c agree.
* src, mat and dst should be contiguous.
*/
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
exec(src, mat, {}, dst, workspace);
}

/**
* \p src should have batch size m, and \p mat and \p mat_idx should
* both have batch size n. Each item in \p mat_idx must be in the range
* of [0, m-1].
*
* \param mat_idx the indices of input image that each matrix in \p mat
* should act on. It can also be empty and in such case \p mat
* should have the same batch size as \p src.
*/
virtual void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &dst) {
return get_workspace_in_bytes(src, mat, {}, dst);
}

virtual size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst) = 0;
protected:
void check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst,
size_t workspace_in_bytes);

void check_exec_allow_nhwc_mat_idx(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst,
size_t workspace_in_bytes);

public:
/**
* \param[in] src (n, channel, in_height, in_width)
* \param[in] mat (n, 3, 3)
* \param[out] dst (n, channel, out_height, out_width)
*
* \see
* http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
*
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
*
* src and dst can have different shapes, as long as their n and c agree.
* src, mat and dst should be contiguous.
*/
void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
exec(src, mat, {}, dst, workspace);
}

/**
* \p src should have batch size m, and \p mat and \p mat_idx should
* both have batch size n. Each item in \p mat_idx must be in the range
* of [0, m-1].
*
* \param mat_idx the indices of input image that each matrix in \p mat
* should act on. It can also be empty and in such case \p mat
* should have the same batch size as \p src.
*/
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
return get_workspace_in_bytes(src, mat, {}, dst);
}

virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst,
size_t workspace_in_bytes);

void check_exec_allow_nhwc_mat_idx(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };
using WarpPerspective = WarpPerspectiveForward; using WarpPerspective = WarpPerspectiveForward;


class WarpPerspectiveBackwardData: public WarpPerspectiveBase {
class WarpPerspectiveBackwardData : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
public:
/**
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
* \param[out] workspace temporary workspace to perform backward
*/
void exec(_megdnn_tensor_in mat,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
exec(mat, {}, diff, grad, workspace);
}

virtual void exec(_megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad) {
return get_workspace_in_bytes(mat, {}, diff, grad);
}

virtual size_t get_workspace_in_bytes(const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad) = 0;
protected:
void check_exec(const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes);

public:
/**
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
* \param[out] workspace temporary workspace to perform backward
*/
void exec(
_megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) {
exec(mat, {}, diff, grad, workspace);
}

virtual void exec(
_megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(
const TensorLayout& mat, const TensorLayout& diff,
const TensorLayout& grad) {
return get_workspace_in_bytes(mat, {}, diff, grad);
}

virtual size_t get_workspace_in_bytes(
const TensorLayout& mat, const TensorLayout& mat_idx,
const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
void check_exec(
const TensorLayout& mat, const TensorLayout& mat_idx,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
}; };


class WarpPerspectiveBackwardMat: public WarpPerspectiveBase {
class WarpPerspectiveBackwardMat : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
public:
/**
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. mat
* \param[out] workspace temporary workspace to perform backward
*/
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
exec(src, mat, {}, diff, grad, workspace);
}

virtual void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad) {
return get_workspace_in_bytes(src, mat, {}, diff, grad);
}

virtual size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad) = 0;
protected:
void check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes);

public:
/**
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. mat
* \param[out] workspace temporary workspace to perform backward
*/
void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) {
exec(src, mat, {}, diff, grad, workspace);
}

virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff,
const TensorLayout& grad) {
return get_workspace_in_bytes(src, mat, {}, diff, grad);
}

virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& diff,
const TensorLayout& grad) = 0;

protected:
void check_exec(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
}; };


class DctChannelSelectForward : public OperatorBase { class DctChannelSelectForward : public OperatorBase {
@@ -194,37 +183,32 @@ public:
* \param[dst] DctChannelSelectForward output, default fp32 nchw tensor * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor
* \param[out] workspace temporary workspace to perform forward * \param[out] workspace temporary workspace to perform forward
*/ */
virtual void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;

void deduce_layout(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
const TensorLayout& dst) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;

void deduce_layout(
const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, TensorLayout& dst);

virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, const TensorLayout& dst) = 0;


protected: protected:
void check_layout_fwd(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
const TensorLayout& dst);
void deduce_layout_fwd(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, const TensorLayout& dst);

void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, TensorLayout& dst);


std::string param_msg() const; std::string param_msg() const;
}; };


} // namespace megdnn
} // namespace megdnn


#include "megdnn/internal/opr_header_epilogue.h" #include "megdnn/internal/opr_header_epilogue.h"




+ 55
- 54
dnn/include/megdnn/oprs/linalg.h View File

@@ -33,22 +33,22 @@ public:
* op(A) = A if transposeA is false, otherwise op(A) = A^t. * op(A) = A if transposeA is false, otherwise op(A) = A^t.
* op(B) = B if transposeB is false, otherwise op(B) = B^t. * op(B) = B if transposeB is false, otherwise op(B) = B^t.
*/ */
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType &C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B,
TensorLayout& C);
virtual size_t get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) = 0;
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType& C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;


static Algorithm::OprType get_opr_type() { static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
} }


protected: protected:
void check_exec(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
}; };
using BatchedMatrixMul = BatchedMatrixMulForward; using BatchedMatrixMul = BatchedMatrixMulForward;


@@ -70,24 +70,24 @@ public:
* op(A) = A if transposeA is false, otherwise op(A) = A^t. * op(A) = A if transposeA is false, otherwise op(A) = A^t.
* op(B) = B if transposeB is false, otherwise op(B) = B^t. * op(B) = B if transposeB is false, otherwise op(B) = B^t.
*/ */
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType& C); void deduce_dtype(DType A, DType B, DType& C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B,
TensorLayout& C);
virtual size_t get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;


static size_t pack_size (const Param::Format format);
static size_t pack_size(const Param::Format format);


static Algorithm::OprType get_opr_type() { static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::MATRIX_MUL_FORWARD; return Algorithm::OprType::MATRIX_MUL_FORWARD;
} }


protected: protected:
void check_exec(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
}; };
using MatrixMul = MatrixMulForward; using MatrixMul = MatrixMulForward;


@@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase {
DEF_OPR_PARAM(Empty); DEF_OPR_PARAM(Empty);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_layout(const TensorLayout& src, TensorLayout& dst);
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst);
size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);


protected: protected:
/*! /*!
@@ -116,8 +116,7 @@ protected:
* *
* Note that \p batch and \p n can be null * Note that \p batch and \p n can be null
*/ */
static void canonize_params(const TensorLayout& layout, size_t* batch,
size_t* n);
static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);


/*! /*!
* \brief canonize and validate input params for exec() impls * \brief canonize and validate input params for exec() impls
@@ -125,11 +124,12 @@ protected:
* Since get_workspace_in_bytes() would be called, \p batch and \p n can not * Since get_workspace_in_bytes() would be called, \p batch and \p n can not
* be null * be null
*/ */
void check_exec(const TensorLayout& src, const TensorLayout& dst,
_megdnn_workspace workspace, size_t* batch, size_t* n);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
_megdnn_workspace workspace, size_t* batch, size_t* n);


virtual size_t get_workspace_in_bytes(size_t batch, size_t n,
size_t dtype_size) = 0;
virtual size_t get_workspace_in_bytes(
size_t batch, size_t n, size_t dtype_size) = 0;
}; };


//! inter-product of two vectors //! inter-product of two vectors
@@ -147,17 +147,17 @@ public:
* A, B, C must be contiguous. A and B must have the same 1-dimensional * A, B, C must be contiguous. A and B must have the same 1-dimensional
* shape and non-negative strides. C must be scalar. * shape and non-negative strides. C must be scalar.
*/ */
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B,
TensorLayout& C);
virtual size_t get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) = 0;
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;


protected: protected:
void check_exec(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
}; };
using Dot = DotForward; using Dot = DotForward;


@@ -193,23 +193,24 @@ public:
* if compute_uv is false (default to true). * if compute_uv is false (default to true).
* *
*/ */
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u,
_megdnn_tensor_out s, _megdnn_tensor_out vt,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& u,
TensorLayout& s, TensorLayout& vt);
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& u, const TensorLayout& s,
const TensorLayout& vt);
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
_megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& src, TensorLayout& u, TensorLayout& s,
TensorLayout& vt);
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
const TensorLayout& vt);


protected: protected:
static void canonize_params(const TensorLayout& layout, size_t* batch,
size_t* m, size_t* n);
virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n,
size_t dtype_size) = 0;
void check_exec(const TensorLayout& src, const TensorLayout& u,
const TensorLayout& s, const TensorLayout& vt,
size_t workspace_in_bytes);
static void canonize_params(
const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
virtual size_t get_workspace_in_bytes(
size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
void check_exec(
const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
const TensorLayout& vt, size_t workspace_in_bytes);
}; };


using SVD = SVDForward; using SVD = SVDForward;


+ 629
- 618
dnn/include/megdnn/oprs/nn.h
File diff suppressed because it is too large
View File


+ 5
- 8
dnn/include/megdnn/oprs/nn_int.h View File

@@ -36,7 +36,7 @@ public:
struct ModeTrait { struct ModeTrait {
uint32_t arity = 0; //!< number of inputs needed uint32_t arity = 0; //!< number of inputs needed
CheckDtypeFunc check_inp[MAX_ARITY]; CheckDtypeFunc check_inp[MAX_ARITY];
SetOrCheckDtypeFunc check_out; //!< dtype of output var
SetOrCheckDtypeFunc check_out; //!< dtype of output var
bool need_specify_out_dtype = bool need_specify_out_dtype =
false; //!< the dtype should be setup externally, otherwise false; //!< the dtype should be setup externally, otherwise
//!< would be inferred by check_out(dtype, false) //!< would be inferred by check_out(dtype, false)
@@ -46,13 +46,10 @@ public:
static const ModeTrait& from_mode(Mode mode); static const ModeTrait& from_mode(Mode mode);
}; };


virtual void exec(_megdnn_in const TensorNDArray& src,
_megdnn_tensor_out dst) = 0;
virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0;


//! get trait of current mode //! get trait of current mode
const ModeTrait& mode_trait() const {
return ModeTrait::from_mode(m_param.mode);
}
const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); }


//! deduce output layout //! deduce output layout
void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst);
@@ -60,8 +57,8 @@ public:
protected: protected:
//! throw exception if incorrect layout; broadcast input shape to //! throw exception if incorrect layout; broadcast input shape to
//! output shape //! output shape
void check_layout_and_broadcast(const TensorLayoutPtrArray& src,
const TensorLayout& dst);
void check_layout_and_broadcast(
const TensorLayoutPtrArray& src, const TensorLayout& dst);
}; };


} // namespace megdnn } // namespace megdnn


+ 103
- 87
dnn/include/megdnn/oprs/utils.h View File

@@ -15,84 +15,97 @@
namespace megdnn { namespace megdnn {


//! base class for random number generators //! base class for random number generators
class RNGBase: public OperatorBase {
class RNGBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
public:
virtual void exec(_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;
protected:
virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0;

public:
virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0;
}; };


//! sample from poisson distribution //! sample from poisson distribution
class PoissonRNG: public OperatorBase {
class PoissonRNG : public OperatorBase {
DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
DEF_OPR_PARAM(PoissonRNG); DEF_OPR_PARAM(PoissonRNG);
public:
virtual void exec(_megdnn_tensor_in lam,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &lam,
const TensorLayout &dst) = 0;
protected:
void check_exec(const TensorLayout &lam, const TensorLayout &dst,
size_t workspace_in_bytes);

public:
virtual void exec(
_megdnn_tensor_in lam, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& lam, const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& lam, const TensorLayout& dst,
size_t workspace_in_bytes);
}; };


//! sample from beta distribution //! sample from beta distribution
class BetaRNG: public OperatorBase {
class BetaRNG : public OperatorBase {
DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
DEF_OPR_PARAM(BetaRNG); DEF_OPR_PARAM(BetaRNG);
public:
virtual void exec(_megdnn_tensor_in alpha,
_megdnn_tensor_in beta,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &alpha,
const TensorLayout &beta, const TensorLayout &dst) = 0;
protected:
void check_exec(const TensorLayout &alpha, const TensorLayout &beta,
const TensorLayout &dst, size_t workspace_in_bytes);

public:
virtual void exec(
_megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& alpha, const TensorLayout& beta,
const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& alpha, const TensorLayout& beta,
const TensorLayout& dst, size_t workspace_in_bytes);
}; };


//! sample from gamma distribution //! sample from gamma distribution
class GammaRNG: public OperatorBase {
class GammaRNG : public OperatorBase {
DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
DEF_OPR_PARAM(GammaRNG); DEF_OPR_PARAM(GammaRNG);
public:
virtual void exec(_megdnn_tensor_in shape,
_megdnn_tensor_in scale,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &shape,
const TensorLayout &scale, const TensorLayout &dst) = 0;
protected:
void check_exec(const TensorLayout &shape, const TensorLayout &scale,
const TensorLayout &dst, size_t workspace_in_bytes);

public:
virtual void exec(
_megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& shape, const TensorLayout& scale,
const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& shape, const TensorLayout& scale,
const TensorLayout& dst, size_t workspace_in_bytes);
}; };


//! sample from uniform distribution on the interval (0, 1] //! sample from uniform distribution on the interval (0, 1]
class UniformRNG: public RNGBase {
class UniformRNG : public RNGBase {
DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(UniformRNG); DEF_OPR_PARAM(UniformRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);

protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
}; };


//! sample from gaussian distribution //! sample from gaussian distribution
class GaussianRNG: public RNGBase {
class GaussianRNG : public RNGBase {
DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(GaussianRNG); DEF_OPR_PARAM(GaussianRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);

protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
}; };


class PermutationRNG: public RNGBase {
class PermutationRNG : public RNGBase {
DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(PermutationRNG); DEF_OPR_PARAM(PermutationRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);

protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
}; };


class ShuffleRNGForward : public OperatorBase { class ShuffleRNGForward : public OperatorBase {
@@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase {
DEF_OPR_PARAM(ShuffleRNG); DEF_OPR_PARAM(ShuffleRNG);


public: public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst,
TensorLayout& indices);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& indices) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& src, TensorLayout& dst, TensorLayout& indices);
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices) = 0;


protected: protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices, size_t workspace_in_bytes);
}; };
using ShuffleRNG = ShuffleRNGForward; using ShuffleRNG = ShuffleRNGForward;


@@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase {
DEF_OPR_PARAM(ShuffleRNG); DEF_OPR_PARAM(ShuffleRNG);


public: public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& indices,
const TensorLayout& grad) = 0;
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad) = 0;


protected: protected:
void check_exec(const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad, size_t workspace_in_bytes);
}; };


/*! /*!
* \brief sleep for specific time on the computing device; useful for testing * \brief sleep for specific time on the computing device; useful for testing
* async problems * async problems
*/ */
class SleepForward: public OperatorBase {
class SleepForward : public OperatorBase {
DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
DEF_OPR_PARAM(Sleep); DEF_OPR_PARAM(Sleep);


public:
virtual void exec() = 0;
public:
virtual void exec() = 0;
}; };
using Sleep = SleepForward; using Sleep = SleepForward;


@@ -149,20 +165,19 @@ using Sleep = SleepForward;
* *
* data must be a one-dimensional contiguous tensor with dtype byte * data must be a one-dimensional contiguous tensor with dtype byte
*/ */
class ChecksumForward: public OperatorBase {
class ChecksumForward : public OperatorBase {
DEF_OPR_PARAM(Empty); DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);


public:
using Result = opr_result::Checksum;
public:
using Result = opr_result::Checksum;


virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;


virtual Result exec(_megdnn_tensor_in data,
_megdnn_workspace workspace) = 0;
virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0;


protected:
void check_exec(const TensorLayout &layout, size_t workspace_in_bytes);
protected:
void check_exec(const TensorLayout& layout, size_t workspace_in_bytes);
}; };
using Checksum = ChecksumForward; using Checksum = ChecksumForward;


@@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase {
DEF_OPR_PARAM(Empty); DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);


public:
virtual size_t get_workspace_in_bytes(const TensorLayout& layout1,
const TensorLayout& layout2) = 0;
public:
virtual size_t get_workspace_in_bytes(
const TensorLayout& layout1, const TensorLayout& layout2) = 0;


virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2,
_megdnn_workspace workspace) = 0;
virtual float exec(
_megdnn_tensor_in src1, _megdnn_tensor_in src2,
_megdnn_workspace workspace) = 0;


protected:
void check_exec(const TensorLayout& layout1,
const TensorLayout& layout2, size_t workspace_in_bytes);
protected:
void check_exec(
const TensorLayout& layout1, const TensorLayout& layout2,
size_t workspace_in_bytes);
}; };



bool check_bias_share_in_channel(const TensorLayout& bias,
const param::ConvBias::Format format);
bool check_bias_share_in_channel(
const TensorLayout& bias, const param::ConvBias::Format format);


} // namespace megdnn } // namespace megdnn




+ 36
- 47
dnn/include/megdnn/tensor_format.h View File

@@ -18,9 +18,9 @@
namespace megdnn { namespace megdnn {


enum class TensorFormat::Type { enum class TensorFormat::Type {
DEFAULT = 0, //!< see DefaultTensorFormat
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
LOWBITS_ALIGNED_TO_BYTE = 2, //!<
DEFAULT = 0, //!< see DefaultTensorFormat
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
LOWBITS_ALIGNED_TO_BYTE = 2, //!<
}; };


class TensorFormat::ImplBase { class TensorFormat::ImplBase {
@@ -33,8 +33,7 @@ public:


virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;


virtual TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const = 0;
virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0;


virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0;


@@ -79,8 +78,7 @@ public:
*/ */
bool is_contiguous_spec(const TensorLayout& layout) const override; bool is_contiguous_spec(const TensorLayout& layout) const override;


TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;


TensorLayout::Span span_spec(const TensorLayout& layout) const override; TensorLayout::Span span_spec(const TensorLayout& layout) const override;


@@ -88,8 +86,7 @@ public:
void serialize_append(std::string& result) const override; void serialize_append(std::string& result) const override;


static TensorFormat make(); static TensorFormat make();
static TensorFormat deserialize(const Handle* handle, const void* buf,
size_t size);
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
}; };


namespace detail { namespace detail {
@@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase {
size_t m_align_axis, m_align_size_in_elements_log2; size_t m_align_axis, m_align_size_in_elements_log2;


protected: protected:
Image2DTensorFormatBase(Type type, size_t align_axis,
size_t align_size_in_elements);
Image2DTensorFormatBase(
Type type, size_t align_axis, size_t align_size_in_elements);
virtual ~Image2DTensorFormatBase() = default; virtual ~Image2DTensorFormatBase() = default;


public: public:
@@ -129,9 +126,7 @@ public:


size_t align_axis() const { return m_align_axis; } size_t align_axis() const { return m_align_axis; }


size_t align_size_in_elements_log2() const {
return m_align_size_in_elements_log2;
}
size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; }


std::string to_string() const override; std::string to_string() const override;


@@ -145,6 +140,7 @@ public:
size_t image_height(const TensorLayout& layout) const; size_t image_height(const TensorLayout& layout) const;


void serialize_append(std::string& result) const override; void serialize_append(std::string& result) const override;

protected: protected:
struct SerializePack { struct SerializePack {
uint8_t align_axis; uint8_t align_axis;
@@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
* align COUNT, but mdl needs align size in byte, which equal to * align COUNT, but mdl needs align size in byte, which equal to
* (image_width algin count) * sizeof(data_type) * pixel_size * (image_width algin count) * sizeof(data_type) * pixel_size
*/ */
size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements,
const TensorLayout& layout) const;
size_t image_pitch_alignment_in_bytes(
size_t align_size_in_elements, const TensorLayout& layout) const;


protected: protected:
Image2DPackedTensorFormatBase(Type type, size_t align_axis,
size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DTensorFormatBase(type, align_axis,
align_size_in_elements),
Image2DPackedTensorFormatBase(
Type type, size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements),
m_vendor_type(vendor_type) {} m_vendor_type(vendor_type) {}


virtual ~Image2DPackedTensorFormatBase() = default; virtual ~Image2DPackedTensorFormatBase() = default;
@@ -197,13 +192,12 @@ public:


bool is_contiguous_spec(const TensorLayout& layout) const override; bool is_contiguous_spec(const TensorLayout& layout) const override;


TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
}; };
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;


/*! /*!
* \brief used for tensors storing lowbit data
* \brief used for tensors storing lowbit data
* *
* \param m_size_nbits size in bits of elements in the tensor * \param m_size_nbits size in bits of elements in the tensor
* \param m_align_size_in_bits aligned size in bits * \param m_align_size_in_bits aligned size in bits
@@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase {
size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements;


protected: //? protected: //?
LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits,
size_t align_size_in_bits);
LowbitsAlignedTensorFormatBase(
Type type, size_t size_nbits, size_t align_size_in_bits);


virtual ~LowbitsAlignedTensorFormatBase() = default; virtual ~LowbitsAlignedTensorFormatBase() = default;


public: public:
size_t align_size_in_bits() const { return m_align_size_in_bits; } size_t align_size_in_bits() const { return m_align_size_in_bits; }
size_t size_nbits() const { return m_size_nbits; } size_t size_nbits() const { return m_size_nbits; }


std::string to_string() const override; std::string to_string() const override;
@@ -238,8 +232,8 @@ public:


bool is_contiguous_spec(const TensorLayout& layout) const override; bool is_contiguous_spec(const TensorLayout& layout) const override;


TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
protected: protected:
struct SerializePack { struct SerializePack {
uint8_t size_nbits; uint8_t size_nbits;
@@ -254,16 +248,14 @@ protected:
* *
* This is used for OpenCL. * This is used for OpenCL.
*/ */
class Image2DPack4TensorFormat final
: public detail::Image2DPack4TensorFormatBase {
class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase {
public: public:
static constexpr Type TYPE = Type::IMAGE2D_PACK4; static constexpr Type TYPE = Type::IMAGE2D_PACK4;


//! for internal usage or test purposes //! for internal usage or test purposes
static TensorFormat make_raw(size_t align_axis,
size_t align_size_in_elements,
Handle::HandleVendorType vendor_type =
Handle::HandleVendorType::NOT_SPEC);
static TensorFormat make_raw(
size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC);


static TensorFormat make(size_t align_axis, const Handle* handle); static TensorFormat make(size_t align_axis, const Handle* handle);


@@ -273,13 +265,11 @@ public:
* Note that the alignment may be different if deserialized on another * Note that the alignment may be different if deserialized on another
* handle * handle
*/ */
static TensorFormat deserialize(const Handle* handle, const void* buf,
size_t size);
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);


static bool is_valid_image(const TensorLayout& layout) { static bool is_valid_image(const TensorLayout& layout) {
if (layout.format.type() == TYPE) { if (layout.format.type() == TYPE) {
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(
layout);
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout);
return true; return true;
} }
return false; return false;
@@ -288,8 +278,9 @@ public:
TensorFormat change_axis(size_t axis) const override; TensorFormat change_axis(size_t axis) const override;


private: private:
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
Image2DPack4TensorFormat(
size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DPack4TensorFormatBase( : detail::Image2DPack4TensorFormatBase(
TYPE, align_axis, align_size_in_elements, vendor_type) {} TYPE, align_axis, align_size_in_elements, vendor_type) {}
}; };
@@ -306,13 +297,12 @@ public:


static TensorFormat make(size_t size_nbits); static TensorFormat make(size_t size_nbits);


static TensorFormat deserialize(const Handle* handle, const void* buf,
size_t size);
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);


static bool is_valid_layout(const TensorLayout& layout) { static bool is_valid_layout(const TensorLayout& layout) {
if (layout.format.type() == TYPE) { if (layout.format.type() == TYPE) {
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>()
.assert_valid(layout);
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid(
layout);
return true; return true;
} }
return false; return false;
@@ -320,8 +310,7 @@ public:


private: private:
LowbitsAlignedToBytesTensorFormat(size_t size_nbits) LowbitsAlignedToBytesTensorFormat(size_t size_nbits)
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits,
BYTE_IN_BITS) {}
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {}
}; };
} // namespace megdnn } // namespace megdnn




+ 3
- 5
dnn/include/megdnn/tensor_iter.h View File

@@ -167,13 +167,11 @@ public:


TensorIter(const TensorND& tensor) : m_tensor(tensor) {} TensorIter(const TensorND& tensor) : m_tensor(tensor) {}


Iter begin() const {
return Iter::make(const_cast<TensorND&>(m_tensor), 0);
}
Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); }


Iter end() const { Iter end() const {
return Iter::make(const_cast<TensorND&>(m_tensor),
m_tensor.layout.total_nr_elems());
return Iter::make(
const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems());
} }
}; };
/*! /*!


+ 5
- 5
dnn/include/megdnn/thin/function.h View File

@@ -11,19 +11,19 @@


#pragma once #pragma once


#include <type_traits>
#include <cstdlib>
#include <functional> #include <functional>
#include <utility>
#include <memory> #include <memory>
#include <cstdlib>
#include <type_traits>
#include <utility>


#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"


namespace megdnn { namespace megdnn {
template<typename Signature>
template <typename Signature>
using thin_function = ::std::function<Signature>; using thin_function = ::std::function<Signature>;


} // namespace megdnn
} // namespace megdnn


#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"




+ 42
- 76
dnn/include/megdnn/thin/small_vector.h View File

@@ -58,18 +58,16 @@ protected:
m_end_ptr(first_elm), m_end_ptr(first_elm),
m_capacity_ptr(static_cast<char*>(first_elm) + size) {} m_capacity_ptr(static_cast<char*>(first_elm) + size) {}


void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes,
size_t type_size);
void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size);


public: public:
size_t size_in_bytes() const { size_t size_in_bytes() const {
return size_t(static_cast<char*>(m_end_ptr) -
static_cast<char*>(m_begin_ptr));
return size_t(static_cast<char*>(m_end_ptr) - static_cast<char*>(m_begin_ptr));
} }


size_t capacity_in_bytes() const { size_t capacity_in_bytes() const {
return size_t(static_cast<char*>(m_capacity_ptr) -
static_cast<char*>(m_begin_ptr));
return size_t(
static_cast<char*>(m_capacity_ptr) - static_cast<char*>(m_begin_ptr));
} }


bool empty() const { return m_begin_ptr == m_end_ptr; } bool empty() const { return m_begin_ptr == m_end_ptr; }
@@ -85,20 +83,15 @@ private:
U m_first_elm; U m_first_elm;


protected: protected:
SmallVectorTemplateCommon(size_t size)
: SmallVectorBase(&m_first_elm, size) {}
SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {}


void grow_pod(size_t min_sz_in_bytes, size_t type_size) { void grow_pod(size_t min_sz_in_bytes, size_t type_size) {
SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size);
} }


bool is_small() {
return m_begin_ptr == static_cast<const void*>(&m_first_elm);
}
bool is_small() { return m_begin_ptr == static_cast<const void*>(&m_first_elm); }


void reset_to_small() {
m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm;
}
void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; }


void set_end(T* p) { m_end_ptr = p; } void set_end(T* p) { m_end_ptr = p; }


@@ -128,20 +121,12 @@ protected:
public: public:
// forwarding iterator creation // forwarding iterator creation
iterator begin() { return static_cast<iterator>(m_begin_ptr); } iterator begin() { return static_cast<iterator>(m_begin_ptr); }
const_iterator begin() const {
return static_cast<const_iterator>(m_begin_ptr);
}
const_iterator cbegin() const {
return static_cast<const_iterator>(m_begin_ptr);
}
const_iterator begin() const { return static_cast<const_iterator>(m_begin_ptr); }
const_iterator cbegin() const { return static_cast<const_iterator>(m_begin_ptr); }


iterator end() { return static_cast<iterator>(m_end_ptr); } iterator end() { return static_cast<iterator>(m_end_ptr); }
const_iterator end() const {
return static_cast<const_iterator>(m_end_ptr);
}
const_iterator cend() const {
return static_cast<const_iterator>(m_end_ptr);
}
const_iterator end() const { return static_cast<const_iterator>(m_end_ptr); }
const_iterator cend() const { return static_cast<const_iterator>(m_end_ptr); }


reference at(size_type idx) { reference at(size_type idx) {
if (idx >= size()) { if (idx >= size()) {
@@ -167,13 +152,9 @@ public:


// reverse iterator creation method. // reverse iterator creation method.
reverse_iterator rbegin() { return reverse_iterator(end()); } reverse_iterator rbegin() { return reverse_iterator(end()); }
const_reverse_iterator rbegin() const {
return const_reverse_iterator(end());
}
const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); }
reverse_iterator rend() { return reverse_iterator(begin()); } reverse_iterator rend() { return reverse_iterator(begin()); }
const_reverse_iterator rend() const {
return const_reverse_iterator(begin());
}
const_reverse_iterator rend() const { return const_reverse_iterator(begin()); }


pointer data() { return pointer(begin()); } pointer data() { return pointer(begin()); }
const_pointer data() const { return const_pointer(begin()); } const_pointer data() const { return const_pointer(begin()); }
@@ -207,8 +188,8 @@ protected:


template <typename It1, typename It2> template <typename It1, typename It2>
static void uninitialized_move(It1 first, It1 last, It2 dest) { static void uninitialized_move(It1 first, It1 last, It2 dest) {
std::uninitialized_copy(std::make_move_iterator(first),
std::make_move_iterator(last), dest);
std::uninitialized_copy(
std::make_move_iterator(first), std::make_move_iterator(last), dest);
} }


template <typename It1, typename It2> template <typename It1, typename It2>
@@ -293,9 +274,7 @@ protected:
memcpy(dest, first, (last - first) * sizeof(T)); memcpy(dest, first, (last - first) * sizeof(T));
} }


void grow(size_t min_sz = 0) {
this->grow_pod(min_sz * sizeof(T), sizeof(T));
}
void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); }


public: public:
void push_back(const T& _elm) { void push_back(const T& _elm) {
@@ -318,8 +297,7 @@ public:
* SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N * SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N
*/ */
template <typename T> template <typename T>
class SmallVectorImpl
: public SmallVectorTemplateBase<T, std::is_pod<T>::value> {
class SmallVectorImpl : public SmallVectorTemplateBase<T, std::is_pod<T>::value> {
using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>;


public: public:
@@ -329,8 +307,7 @@ public:


protected: protected:
explicit SmallVectorImpl(unsigned n) explicit SmallVectorImpl(unsigned n)
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {
}
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {}


public: public:
SmallVectorImpl(const SmallVectorImpl&) = delete; SmallVectorImpl(const SmallVectorImpl&) = delete;
@@ -354,8 +331,7 @@ public:
} else if (n > this->size()) { } else if (n > this->size()) {
if (this->capacity() < n) if (this->capacity() < n)
this->grow(n); this->grow(n);
for (auto it = this->end(), end = this->begin() + n; it != end;
++it)
for (auto it = this->end(), end = this->begin() + n; it != end; ++it)
new (&*it) T(); new (&*it) T();
this->set_end(this->begin() + n); this->set_end(this->begin() + n);
} }
@@ -389,10 +365,11 @@ public:
void swap(SmallVectorImpl<T>& rhs); void swap(SmallVectorImpl<T>& rhs);


/// Add the specified range to the end of the SmallVector. /// Add the specified range to the end of the SmallVector.
template <typename in_iter,
typename = typename std::enable_if<std::is_convertible<
typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
template <
typename in_iter,
typename = typename std::enable_if<std::is_convertible<
typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
void append(in_iter in_start, in_iter in_end) { void append(in_iter in_start, in_iter in_end) {
size_type num_inputs = std::distance(in_start, in_end); size_type num_inputs = std::distance(in_start, in_end);
// Grow allocated space if needed. // Grow allocated space if needed.
@@ -432,10 +409,11 @@ public:
std::uninitialized_fill(this->begin(), this->end(), elm); std::uninitialized_fill(this->begin(), this->end(), elm);
} }


template <typename in_iter,
typename = typename std::enable_if<std::is_convertible<
typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
template <
typename in_iter,
typename = typename std::enable_if<std::is_convertible<
typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
void assign(in_iter in_start, in_iter in_end) { void assign(in_iter in_start, in_iter in_end) {
clear(); clear();
append(in_start, in_end); append(in_start, in_end);
@@ -571,8 +549,7 @@ public:
std::fill_n(it, num_overwritten, elm); std::fill_n(it, num_overwritten, elm);


// Insert the non-overwritten middle part. // Insert the non-overwritten middle part.
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten,
elm);
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm);
return it; return it;
} }


@@ -646,8 +623,7 @@ public:
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) {
this->grow(); this->grow();
} }
new (static_cast<void*>(this->end()))
T(std::forward<ArgTypes>(args)...);
new (static_cast<void*>(this->end())) T(std::forward<ArgTypes>(args)...);
this->set_end(this->end() + 1); this->set_end(this->end() + 1);
} }


@@ -661,13 +637,11 @@ public:
return std::equal(this->begin(), this->end(), rhs.begin()); return std::equal(this->begin(), this->end(), rhs.begin());
} }


bool operator!=(const SmallVectorImpl<T>& rhs) const {
return !(*this == rhs);
}
bool operator!=(const SmallVectorImpl<T>& rhs) const { return !(*this == rhs); }


bool operator<(const SmallVectorImpl<T>& rhs) const { bool operator<(const SmallVectorImpl<T>& rhs) const {
return std::lexicographical_compare(this->begin(), this->end(),
rhs.begin(), rhs.end());
return std::lexicographical_compare(
this->begin(), this->end(), rhs.begin(), rhs.end());
} }
}; };


@@ -698,15 +672,13 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) {
// Copy over the extra elms. // Copy over the extra elms.
if (this->size() > rhs.size()) { if (this->size() > rhs.size()) {
size_t elm_diff = this->size() - rhs.size(); size_t elm_diff = this->size() - rhs.size();
this->uninitialized_move(this->begin() + num_shared, this->end(),
rhs.end());
this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end());
rhs.set_end(rhs.end() + elm_diff); rhs.set_end(rhs.end() + elm_diff);
this->destroy_range(this->begin() + num_shared, this->end()); this->destroy_range(this->begin() + num_shared, this->end());
this->set_end(this->begin() + num_shared); this->set_end(this->begin() + num_shared);
} else if (rhs.size() > this->size()) { } else if (rhs.size() > this->size()) {
size_t elm_diff = rhs.size() - this->size(); size_t elm_diff = rhs.size() - this->size();
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(),
this->end());
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end());
this->set_end(this->end() + elm_diff); this->set_end(this->end() + elm_diff);
this->destroy_range(rhs.begin() + num_shared, rhs.end()); this->destroy_range(rhs.begin() + num_shared, rhs.end());
rhs.set_end(rhs.begin() + num_shared); rhs.set_end(rhs.begin() + num_shared);
@@ -714,8 +686,7 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) {
} }


template <typename T> template <typename T>
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(
const SmallVectorImpl<T>& rhs) {
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(const SmallVectorImpl<T>& rhs) {
if (this == &rhs) if (this == &rhs)
return *this; return *this;
size_t rhs_sz = rhs.size(); size_t rhs_sz = rhs.size();
@@ -740,8 +711,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(
} else if (cur_sz) { } else if (cur_sz) {
std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin());
} }
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(),
this->begin() + cur_sz);
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz);
this->set_end(this->begin() + rhs_sz); this->set_end(this->begin() + rhs_sz);
return *this; return *this;
} }
@@ -785,8 +755,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) {
std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin());
} }


this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(),
this->begin() + cur_sz);
this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz);


this->set_end(this->begin() + rhs_sz); this->set_end(this->begin() + rhs_sz);


@@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl<T> {
public: public:
SmallVector() : SmallVectorImpl<T>(N) {} SmallVector() : SmallVectorImpl<T>(N) {}


explicit SmallVector(size_t size, const T& value = T())
: SmallVectorImpl<T>(N) {
explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl<T>(N) {
this->assign(size, value); this->assign(size, value);
} }


@@ -901,15 +869,13 @@ namespace std {


/// Implement std::swap in terms of SmallVector swap. /// Implement std::swap in terms of SmallVector swap.
template <typename T> template <typename T>
inline void swap(megdnn::SmallVectorImpl<T>& lhs,
megdnn::SmallVectorImpl<T>& rhs) {
inline void swap(megdnn::SmallVectorImpl<T>& lhs, megdnn::SmallVectorImpl<T>& rhs) {
lhs.swap(rhs); lhs.swap(rhs);
} }


/// Implement std::swap in terms of SmallVector swap. /// Implement std::swap in terms of SmallVector swap.
template <typename T, unsigned N> template <typename T, unsigned N>
inline void swap(megdnn::SmallVector<T, N>& lhs,
megdnn::SmallVector<T, N>& rhs) {
inline void swap(megdnn::SmallVector<T, N>& lhs, megdnn::SmallVector<T, N>& rhs) {
lhs.swap(rhs); lhs.swap(rhs);
} }
} // end namespace std } // end namespace std


+ 6
- 6
dnn/include/megdnn/version.h View File

@@ -17,13 +17,13 @@
#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"


namespace megdnn { namespace megdnn {
struct Version {
int major, minor, patch;
};
struct Version {
int major, minor, patch;
};


//! get megdnn version of the binary
Version get_version();
}
//! get megdnn version of the binary
Version get_version();
} // namespace megdnn


#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"




+ 26
- 24
dnn/src/aarch64/conv_bias/fp16/algos.cpp View File

@@ -22,18 +22,17 @@ using namespace aarch64;
/* ===================== stride-2 algo ===================== */ /* ===================== stride-2 algo ===================== */
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16)


bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
bool ConvBiasImpl::AlgoF16DirectStride2::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
auto&& fm = param.filter_meta; auto&& fm = param.filter_meta;
auto FH = fm.spatial[0]; auto FH = fm.spatial[0];
return param.filter_meta.format == param::Convolution::Format::NCHW && return param.filter_meta.format == param::Convolution::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 && param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 && param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] &&
param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7); (FH == 2 || FH == 3 || FH == 5 || FH == 7);
} }
MIDOUT_END(); MIDOUT_END();
@@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
return 0; return 0;
} }


SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const { const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
return get_kimpls(param); return get_kimpls(param);
@@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
return {}; return {};
} }


SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const { const NCBKernSizeParam& param) const {
auto fm = param.filter_meta; auto fm = param.filter_meta;
auto FH = fm.spatial[0]; auto FH = fm.spatial[0];
@@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
size_t OC = param.filter_meta.ocpg; size_t OC = param.filter_meta.ocpg;
size_t group = fm.group; size_t group = fm.group;
bool large_group = group >= param.nr_threads; bool large_group = group >= param.nr_threads;
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
size_t, size_t, size_t, size_t, size_t)>;
using Func = std::function<void(
const __fp16*, const __fp16*, __fp16*, size_t, size_t, size_t, size_t,
size_t)>;
Func conv = nullptr; Func conv = nullptr;
if (FH == 2) { if (FH == 2) {
conv = fp16::conv_stride2::do_conv_2x2_stride2; conv = fp16::conv_stride2::do_conv_2x2_stride2;
@@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
for (size_t ic = 0; ic < IC; ic++) { for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
copy_padding_kern_stride(
bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
} }
for (size_t oc = 0; oc < OC; oc++) { for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
do_conv_kern_stride(
bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
} }
}; };
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else { } else {
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto copy_padding = [bundle](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
copy_padding_kern_stride(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
}; };
ret_kerns.push_back({copy_padding, {group, N, IC}}); ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto do_conv = [bundle, conv](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
ncb_index.ndrange_id);
do_conv_kern_stride(
bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id);
}; };
ret_kerns.push_back({do_conv, {group, N, OC}}); ret_kerns.push_back({do_conv, {group, N, OC}});
} }


+ 5
- 5
dnn/src/aarch64/conv_bias/fp16/algos.h View File

@@ -18,13 +18,13 @@ namespace aarch64 {
/* ===================== stride-2 algo ===================== */ /* ===================== stride-2 algo ===================== */
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;

public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "ARMV8F16STRD2"; } const char* name() const override { return "ARMV8F16STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;


size_t get_workspace(const NCBKernSizeParam& param) const override; size_t get_workspace(const NCBKernSizeParam& param) const override;




+ 48
- 67
dnn/src/aarch64/conv_bias/fp16/stride2_kern.h View File

@@ -20,9 +20,9 @@ namespace aarch64 {
namespace fp16 { namespace fp16 {
namespace conv_stride2 { namespace conv_stride2 {


static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter,
__fp16* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_2x2_stride2(
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW; const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 3; size_t width = OW >> 3;
size_t mod4_left = width & 3; size_t mod4_left = width & 3;
@@ -162,10 +162,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter,
"5: \n" "5: \n"
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1)
: "r"(mod4_left), "w"(_k0123) : "r"(mod4_left), "w"(_k0123)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30",
"v31");
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v28", "v29", "v30", "v31");


r0 += tail_step; r0 += tail_step;
r1 += tail_step; r1 += tail_step;
@@ -175,9 +174,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter,
} }
} }


static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter,
__fp16* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_3x3_stride2(
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW; const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 3; size_t width = OW >> 3;
size_t mod3_left = width % 3; size_t mod3_left = width % 3;
@@ -352,10 +351,10 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter,
"3: \n" "3: \n"
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2)
: "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29");
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29");


r0 += tail_step; r0 += tail_step;
r1 += tail_step; r1 += tail_step;
@@ -366,9 +365,9 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter,
} }
} }


static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter,
__fp16* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_5x5_stride2(
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW; const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 3; size_t width = OW >> 3;
size_t mod2_left = width & 1; size_t mod2_left = width & 1;
@@ -384,18 +383,12 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter,
const __fp16* r4 = src_ptr + IW * 4; const __fp16* r4 = src_ptr + IW * 4;


register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter);
register MEGDNN_SIMD_TYPE _k4567 asm("v1") =
MEGDNN_SIMD_LOADU(filter + 4);
register MEGDNN_SIMD_TYPE _k891011 asm("v2") =
MEGDNN_SIMD_LOADU(filter + 8);
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") =
MEGDNN_SIMD_LOADU(filter + 12);
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") =
MEGDNN_SIMD_LOADU(filter + 16);
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") =
MEGDNN_SIMD_LOADU(filter + 20);
register MEGDNN_SIMD_TYPE _k24242424 asm("v6") =
MEGDNN_SIMD_SET1(filter[24]);
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4);
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8);
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12);
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16);
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20);
register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = MEGDNN_SIMD_SET1(filter[24]);


for (size_t i = 0; i < OH; i++) { for (size_t i = 0; i < OH; i++) {
asm volatile( asm volatile(
@@ -592,15 +585,14 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter,
"bne 2b \n" "bne 2b \n"
"3: \n" "3: \n"


: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2),
"+r"(r3), "+r"(r4)
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3),
"+r"(r4)
: "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415),
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424),
"r"(mod2_left)
: "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31");
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left)
: "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31");


r0 += tail_step; r0 += tail_step;
r1 += tail_step; r1 += tail_step;
@@ -613,9 +605,9 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter,
} }
} }


static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter,
__fp16* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_7x7_stride2(
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW; const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 3; size_t width = OW >> 3;


@@ -632,30 +624,20 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter,
const __fp16* r6 = src_ptr + IW * 6; const __fp16* r6 = src_ptr + IW * 6;


register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter);
register MEGDNN_SIMD_TYPE _k4567 asm("v1") =
MEGDNN_SIMD_LOADU(filter + 4);
register MEGDNN_SIMD_TYPE _k891011 asm("v2") =
MEGDNN_SIMD_LOADU(filter + 8);
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") =
MEGDNN_SIMD_LOADU(filter + 12);
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") =
MEGDNN_SIMD_LOADU(filter + 16);
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") =
MEGDNN_SIMD_LOADU(filter + 20);
register MEGDNN_SIMD_TYPE _k24252627 asm("v6") =
MEGDNN_SIMD_LOADU(filter + 24);
register MEGDNN_SIMD_TYPE _k28293031 asm("v7") =
MEGDNN_SIMD_LOADU(filter + 28);
register MEGDNN_SIMD_TYPE _k32333435 asm("v8") =
MEGDNN_SIMD_LOADU(filter + 32);
register MEGDNN_SIMD_TYPE _k36373839 asm("v9") =
MEGDNN_SIMD_LOADU(filter + 36);
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4);
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8);
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12);
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16);
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20);
register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = MEGDNN_SIMD_LOADU(filter + 24);
register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = MEGDNN_SIMD_LOADU(filter + 28);
register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = MEGDNN_SIMD_LOADU(filter + 32);
register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = MEGDNN_SIMD_LOADU(filter + 36);
register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = register MEGDNN_SIMD_TYPE _k40414243 asm("v10") =
MEGDNN_SIMD_LOADU(filter + 40); MEGDNN_SIMD_LOADU(filter + 40);
register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = register MEGDNN_SIMD_TYPE _k44454647 asm("v11") =
MEGDNN_SIMD_LOADU(filter + 44); MEGDNN_SIMD_LOADU(filter + 44);
register MEGDNN_SIMD_TYPE _k48484848 asm("v12") =
MEGDNN_SIMD_SET1(filter[48]);
register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = MEGDNN_SIMD_SET1(filter[48]);


for (size_t i = 0; i < OH; i++) { for (size_t i = 0; i < OH; i++) {
asm volatile( asm volatile(
@@ -1005,16 +987,15 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter,
"bne 2b \n" "bne 2b \n"
"3: \n" "3: \n"


: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3),
"+r"(r4), "+r"(r5), "+r"(r6)
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4),
"+r"(r5), "+r"(r6)
: "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011),
"w"(_k12131415), "w"(_k16171819), "w"(_k20212223), "w"(_k12131415), "w"(_k16171819), "w"(_k20212223),
"w"(_k24252627), "w"(_k28293031), "w"(_k32333435), "w"(_k24252627), "w"(_k28293031), "w"(_k32333435),
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647),
"w"(_k48484848)
: "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17",
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848)
: "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31");


r0 += tail_step; r0 += tail_step;
r1 += tail_step; r1 += tail_step;


+ 30
- 31
dnn/src/aarch64/conv_bias/fp32/algos.cpp View File

@@ -21,18 +21,17 @@ using namespace megdnn;
using namespace aarch64; using namespace aarch64;


MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32)
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) {
auto&& fm = param.filter_meta; auto&& fm = param.filter_meta;
auto FH = fm.spatial[0]; auto FH = fm.spatial[0];
return param.filter_meta.format == param::ConvBias::Format::NCHW && return param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 && param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 && param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] &&
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7); (FH == 2 || FH == 3 || FH == 5 || FH == 7);
} }
MIDOUT_END(); MIDOUT_END();
@@ -50,8 +49,7 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
MIDOUT_END(); MIDOUT_END();
return 0; return 0;
} }
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const { const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
return get_kimpls(param); return get_kimpls(param);
@@ -60,8 +58,7 @@ ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
return {}; return {};
} }


SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const { const NCBKernSizeParam& param) const {
auto fm = param.filter_meta; auto fm = param.filter_meta;
auto FH = fm.spatial[0]; auto FH = fm.spatial[0];
@@ -70,8 +67,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
size_t OC = param.filter_meta.ocpg; size_t OC = param.filter_meta.ocpg;
size_t group = fm.group; size_t group = fm.group;
bool large_group = group >= param.nr_threads; bool large_group = group >= param.nr_threads;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
using Func = std::function<void(
const float*, const float*, float*, size_t, size_t, size_t, size_t,
size_t)>;
Func conv = nullptr; Func conv = nullptr;
if (FH == 2) { if (FH == 2) {
conv = fp32::conv_stride2::do_conv_2x2_stride2; conv = fp32::conv_stride2::do_conv_2x2_stride2;
@@ -83,8 +81,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
conv = fp32::conv_stride2::do_conv_7x7_stride2; conv = fp32::conv_stride2::do_conv_7x7_stride2;
} }


WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon<
float, float>::get_bundle_stride(param, large_group);
WorkspaceBundle bundle =
arm_common::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, large_group);
SmallVector<NCBKern> ret_kerns; SmallVector<NCBKern> ret_kerns;


//! Dense conv and small group //! Dense conv and small group
@@ -99,34 +98,34 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
for (size_t ic = 0; ic < IC; ic++) { for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<float, float>:: arm_common::MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
copy_padding_kern_stride(
bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
} }
for (size_t oc = 0; oc < OC; oc++) { for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<
float, float>::do_conv_kern_stride(bundle, kern_param,
ncb_index, conv,
{ncb_index.thread_id,
0, oc});
arm_common::MultithreadDirectConvCommon<float, float>::
do_conv_kern_stride(
bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
} }
}; };
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else { } else {
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto copy_padding = [bundle](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<float, float>:: arm_common::MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
copy_padding_kern_stride(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
}; };
ret_kerns.push_back({copy_padding, {group, N, IC}}); ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto do_conv = [bundle, conv](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr); bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<
float, float>::do_conv_kern_stride(bundle, kern_param,
ncb_index, conv,
ncb_index.ndrange_id);
arm_common::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id);
}; };
ret_kerns.push_back({do_conv, {group, N, OC}}); ret_kerns.push_back({do_conv, {group, N, OC}});
} }


+ 5
- 5
dnn/src/aarch64/conv_bias/fp32/algos.h View File

@@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl;


class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;

public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "ARMV8F32STRD2"; } const char* name() const override { return "ARMV8F32STRD2"; }


bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;


size_t get_workspace(const NCBKernSizeParam& param) const override; size_t get_workspace(const NCBKernSizeParam& param) const override;




+ 33
- 38
dnn/src/aarch64/conv_bias/fp32/stride2_kern.h View File

@@ -16,16 +16,15 @@


namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace fp32{
namespace fp32 {
namespace conv_stride2 { namespace conv_stride2 {



//! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp` //! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp`


// refer to function do_conv_2x2_stride2_asm_unroll4 // refer to function do_conv_2x2_stride2_asm_unroll4
static void do_conv_2x2_stride2(const float* src, const float* filter,
float* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_2x2_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW; const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 2; size_t width = OW >> 2;
size_t mod4_left = width & 3; size_t mod4_left = width & 3;
@@ -165,10 +164,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter,
"5: \n" "5: \n"
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1)
: "r"(mod4_left), "w"(_k0123) : "r"(mod4_left), "w"(_k0123)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30",
"v31");
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v28", "v29", "v30", "v31");


r0 += tail_step; r0 += tail_step;
r1 += tail_step; r1 += tail_step;
@@ -179,9 +177,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter,
} }


// refer to function do_conv_3x3_stride2_asm_unroll3 // refer to function do_conv_3x3_stride2_asm_unroll3
static void do_conv_3x3_stride2(const float* src, const float* filter,
float* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_3x3_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW; const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 2; size_t width = OW >> 2;
size_t mod3_left = width % 3; size_t mod3_left = width % 3;
@@ -269,7 +267,7 @@ static void do_conv_3x3_stride2(const float* src, const float* filter,
"ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 "ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6


"ld2 {v5.4s, v6.4s}, [%3], #32 \n" "ld2 {v5.4s, v6.4s}, [%3], #32 \n"
"ld1 {v3.4s}, [%2] \n" // load src 8 12 ...
"ld1 {v3.4s}, [%2] \n" // load src 8 12 ...
"fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] "fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i]
"ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 "ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8
"fmla v0.4s, v2.4s, v22.4s \n" "fmla v0.4s, v2.4s, v22.4s \n"
@@ -356,10 +354,10 @@ static void do_conv_3x3_stride2(const float* src, const float* filter,
"3: \n" "3: \n"
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2)
: "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29");
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29");


r0 += tail_step; r0 += tail_step;
r1 += tail_step; r1 += tail_step;
@@ -371,9 +369,9 @@ static void do_conv_3x3_stride2(const float* src, const float* filter,
} }


// refer to function do_conv_5x5_stride2_asm_unroll2 // refer to function do_conv_5x5_stride2_asm_unroll2
static void do_conv_5x5_stride2(const float* src, const float* filter,
float* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_5x5_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW; const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 2; size_t width = OW >> 2;
size_t mod2_left = width & 1; size_t mod2_left = width & 1;
@@ -591,15 +589,13 @@ static void do_conv_5x5_stride2(const float* src, const float* filter,
"bne 2b \n" "bne 2b \n"
"3: \n" "3: \n"


: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2),
"+r"(r3), "+r"(r4)
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3),
"+r"(r4)
: "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415),
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424),
"r"(mod2_left)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24");
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24");


r0 += tail_step; r0 += tail_step;
r1 += tail_step; r1 += tail_step;
@@ -613,9 +609,9 @@ static void do_conv_5x5_stride2(const float* src, const float* filter,
} }


// refer to function do_conv_7x7_stride2_asm_unroll2 // refer to function do_conv_7x7_stride2_asm_unroll2
static void do_conv_7x7_stride2(const float* src, const float* filter,
float* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_7x7_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW; const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 2; size_t width = OW >> 2;


@@ -993,16 +989,15 @@ static void do_conv_7x7_stride2(const float* src, const float* filter,
"bne 2b \n" "bne 2b \n"
"3: \n" "3: \n"


: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3),
"+r"(r4), "+r"(r5), "+r"(r6)
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4),
"+r"(r5), "+r"(r6)
: "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011),
"w"(_k12131415), "w"(_k16171819), "w"(_k20212223), "w"(_k12131415), "w"(_k16171819), "w"(_k20212223),
"w"(_k24252627), "w"(_k28293031), "w"(_k32333435), "w"(_k24252627), "w"(_k28293031), "w"(_k32333435),
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647),
"w"(_k48484848)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18");
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18");


r0 += tail_step; r0 += tail_step;
r1 += tail_step; r1 += tail_step;


+ 36
- 37
dnn/src/aarch64/conv_bias/int8/algos.cpp View File

@@ -68,9 +68,9 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
size_t N = OH * OW; size_t N = OH * OW;


#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \ M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \ part2 = megdnn::matmul::GemmInterleaved< \
@@ -84,11 +84,12 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
DISPATCH_GEMM_BIAS(s8_4x4, 0) DISPATCH_GEMM_BIAS(s8_4x4, 0)
} }
#else #else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN( \
megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \ M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \ part2 = megdnn::matmul::GemmInterleaved< \
@@ -104,8 +105,8 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
return {nullptr, {part0, part1, part2}}; return {nullptr, {part0, part1, part2}};
} }


void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
void ConvBiasImpl::AlgoS8MatrixMul::kimpl(
const NCBKernParam& param, const NCBKernIndex& ncb_index) {
auto is_xcorr = !param.filter_meta.should_flip; auto is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_NCB_KERN_SIZES(param); UNPACK_CONV_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param); auto bundle = get_bundle(param);
@@ -157,29 +158,28 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
} else { } else {
if (is_xcorr) if (is_xcorr)
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
img2col_stride<true>(
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
else else
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
img2col_stride<false>(
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
} }
} }
{ {
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
bundle.get_size(2));
Workspace workspace(
static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2));
size_t M = OC; size_t M = OC;
size_t K = IC * FH * FW; size_t K = IC * FH * FW;
size_t N = OH * OW; size_t N = OH * OW;


#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);


if (cpuinfo_has_arm_neon_dot()) { if (cpuinfo_has_arm_neon_dot()) {
@@ -188,19 +188,18 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
DISPATCH_GEMM_BIAS(s8_4x4, 0) DISPATCH_GEMM_BIAS(s8_4x4, 0)
} }
#else #else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
bias); \
} \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN( \
megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \
} \
MIDOUT_END() MIDOUT_END()
DISPATCH_GEMM_BIAS(s8_4x4, 0) DISPATCH_GEMM_BIAS(s8_4x4, 0)
#endif #endif


+ 6
- 8
dnn/src/aarch64/conv_bias/int8/algos.h View File

@@ -12,8 +12,8 @@
#pragma once #pragma once


#include "src/aarch64/conv_bias/opr_impl.h" #include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/opr_impl.h"


namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {
@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index);


public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "S8MATMUL"; } const char* name() const override { return "S8MATMUL"; }


bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override { size_t get_workspace(const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes(); return get_bundle(param).total_size_in_bytes();
} }
SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override {
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group; size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}}; return {{kimpl, {group, 1_z, 1_z}}};
} }


+ 61
- 70
dnn/src/aarch64/conv_bias/int8/strategy.cpp View File

@@ -29,9 +29,10 @@ struct KernCaller;
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
template <BiasMode bmode, typename Op> template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 12> { struct KernCaller<bmode, Op, 8, 12> {
static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
Op op, const dt_int32* bias, dt_int32* workspace) {
static void run(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace) {
megdnn_assert(is_first_k); megdnn_assert(is_first_k);


constexpr size_t A_INTERLEAVE = 8; constexpr size_t A_INTERLEAVE = 8;
@@ -49,19 +50,19 @@ struct KernCaller<bmode, Op, 8, 12> {
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12,
is_first_k);
matmul_8x12x4::kern_8x12(
packA, cur_packB, K, workspace, 12, is_first_k);


arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8,
12>::postprocess(bias, workspace,
output, LDC, op);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, 12>::
postprocess(bias, workspace, output, LDC, op);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K12; cur_packB += K12;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4));
matmul_8x12x4::kern_8x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(N - n, 4));


#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \
@@ -83,9 +84,9 @@ struct KernCaller<bmode, Op, 8, 12> {
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12,
is_first_k,
std::min<size_t>(M - m, 4));
matmul_8x12x4::kern_4x12(
packA, cur_packB, K, workspace, 12, is_first_k,
std::min<size_t>(M - m, 4));
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
@@ -97,14 +98,13 @@ struct KernCaller<bmode, Op, 8, 12> {
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_8x12x4::kern_4x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#undef cb #undef cb


output += 4; output += 4;
@@ -122,9 +122,10 @@ struct KernCaller<bmode, Op, 8, 12> {


template <BiasMode bmode, typename Op> template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 4, 4> { struct KernCaller<bmode, Op, 4, 4> {
static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
Op op, const dt_int32* bias, dt_int32* workspace) {
static void run(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace) {
megdnn_assert(is_first_k); megdnn_assert(is_first_k);


constexpr size_t A_INTERLEAVE = 4; constexpr size_t A_INTERLEAVE = 4;
@@ -140,20 +141,18 @@ struct KernCaller<bmode, Op, 4, 4> {
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4,
4>::postprocess(bias, workspace,
output, LDC, op);
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, is_first_k);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, 4>::postprocess(
bias, workspace, output, LDC, op);


output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K4; cur_packB += K4;
} }


for (; n < N; n += B_INTERLEAVE) { for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace,
4, is_first_k, 4,
std::min<size_t>(N - n, 4));
matmul_4x4x16::kern_4x4_remain(
packA, cur_packB, K, workspace, 4, is_first_k, 4,
std::min<size_t>(N - n, 4));
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
@@ -182,8 +181,7 @@ struct KernCaller<bmode, Op, 4, 4> {
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#undef cb #undef cb
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K4; cur_packB += K4;
@@ -200,21 +198,19 @@ struct KernCaller<bmode, Op, 4, 4> {


MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity)


void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
void gemm_s8_4x4_nobias_identity::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) { if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
void gemm_s8_4x4_nobias_identity::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
@@ -229,23 +225,21 @@ size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const {
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity)


void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
void gemm_s8_8x12_nobias_identity::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t);
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t);
if (transpose) { if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
void gemm_s8_8x12_nobias_identity::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
@@ -259,18 +253,17 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const {


#endif #endif


#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \
size_t K, dt_int8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace); \
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, \
dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \
dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace); \
} }


#define DEFINE_OP(_Op) \ #define DEFINE_OP(_Op) \
@@ -286,18 +279,16 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#endif #endif
#undef DEFINE_OP #undef DEFINE_OP


#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C);
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op( \
scale_A* scale_B, scale_A* scale_B, scale_C);
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
#endif #endif
#undef DEFINE_OP #undef DEFINE_OP




+ 25
- 26
dnn/src/aarch64/conv_bias/int8/strategy.h View File

@@ -20,43 +20,42 @@ namespace matmul {
* *
* \name gemm_<type>_<block>_biasmode_nolinemode * \name gemm_<type>_<block>_biasmode_nolinemode
*/ */
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16,
false, true,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(
dt_int8, dt_int8, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_nobias_relu, gemm_s8_4x4_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_nobias_hswish, gemm_s8_4x4_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_bias_channel_identity, gemm_s8_4x4_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_bias_channel_relu, gemm_s8_4x4_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_bias_channel_hswish, gemm_s8_4x4_nobias_identity);
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4,
false, true,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(
dt_int8, dt_int8, dt_int32, 8, 12, 4, false, true,
gemm_s8_8x12_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_nobias_relu, gemm_s8_8x12_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_nobias_hswish, gemm_s8_8x12_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_bias_channel_identity, gemm_s8_8x12_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_bias_channel_relu, gemm_s8_8x12_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_bias_channel_hswish, gemm_s8_8x12_nobias_identity);
#endif #endif


} // namespace matmul } // namespace matmul


+ 12
- 13
dnn/src/aarch64/conv_bias/opr_impl.cpp View File

@@ -13,13 +13,13 @@
#include "src/aarch64/conv_bias/int8/algos.h" #include "src/aarch64/conv_bias/int8/algos.h"
#include "src/aarch64/conv_bias/quint8/algos.h" #include "src/aarch64/conv_bias/quint8/algos.h"


#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "src/common/metahelper.h" #include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"


#include "src/fallback/convolution/opr_impl.h"
#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/aarch64/conv_bias/fp16/algos.h" #include "src/aarch64/conv_bias/fp16/algos.h"
#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/fallback/convolution/opr_impl.h"


using namespace megdnn; using namespace megdnn;
using namespace aarch64; using namespace aarch64;
@@ -56,12 +56,10 @@ public:
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const {
return m_direct_algos; return m_direct_algos;
} }
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos()
const {
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() const {
return m_matmul_algos; return m_matmul_algos;
} }
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }

}; };


const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
@@ -71,15 +69,16 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {


MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)


SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() {
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().direct_algos().begin(),
algo_pack().direct_algos().end());
algos.insert(
algos.begin(), algo_pack().direct_algos().begin(),
algo_pack().direct_algos().end());
//! We put matmul algos at the begin. Because matmul will get privilege when //! We put matmul algos at the begin. Because matmul will get privilege when
//! prefer return true. See //! prefer return true. See
algos.insert(algos.begin(), algo_pack().matmul_algos().begin(),
algo_pack().matmul_algos().end());
algos.insert(
algos.begin(), algo_pack().matmul_algos().begin(),
algo_pack().matmul_algos().end());
return std::move(algos); return std::move(algos);
} }




+ 1
- 1
dnn/src/aarch64/conv_bias/opr_impl.h View File

@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#pragma once #pragma once
#include "src/common/utils.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/common/utils.h"


namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {


+ 36
- 37
dnn/src/aarch64/conv_bias/quint8/algos.cpp View File

@@ -70,9 +70,9 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
size_t N = OH * OW; size_t N = OH * OW;


#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \ M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \ part2 = megdnn::matmul::GemmInterleaved< \
@@ -86,11 +86,12 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0);
} }
#else #else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN( \
megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \ M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \ part2 = megdnn::matmul::GemmInterleaved< \
@@ -106,8 +107,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
return {nullptr, {part0, part1, part2}}; return {nullptr, {part0, part1, part2}};
} }


void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(
const NCBKernParam& param, const NCBKernIndex& ncb_index) {
auto is_xcorr = !param.filter_meta.should_flip; auto is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_NCB_KERN_SIZES(param); UNPACK_CONV_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param); auto bundle = get_bundle(param);
@@ -160,29 +161,28 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
} else { } else {
if (is_xcorr) if (is_xcorr)
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
img2col_stride<true>(
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
else else
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
img2col_stride<false>(
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
} }
} }
{ {
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
bundle.get_size(2));
Workspace workspace(
static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2));
size_t M = OC; size_t M = OC;
size_t K = IC * FH * FW; size_t K = IC * FH * FW;
size_t N = OH * OW; size_t N = OH * OW;


#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);


if (cpuinfo_has_arm_neon_dot()) { if (cpuinfo_has_arm_neon_dot()) {
@@ -191,19 +191,18 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
} }
#else #else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
bias); \
} \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN( \
megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \
} \
MIDOUT_END() MIDOUT_END()


DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)


+ 6
- 8
dnn/src/aarch64/conv_bias/quint8/algos.h View File

@@ -12,8 +12,8 @@
#pragma once #pragma once


#include "src/aarch64/conv_bias/opr_impl.h" #include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/opr_impl.h"


namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {
@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex&); static void kimpl(const NCBKernParam& param, const NCBKernIndex&);


public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "QU8MATMUL"; } const char* name() const override { return "QU8MATMUL"; }


bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override { size_t get_workspace(const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes(); return get_bundle(param).total_size_in_bytes();
} }
SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override {
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group; size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}}; return {{kimpl, {group, 1_z, 1_z}}};
} }


+ 93
- 97
dnn/src/aarch64/conv_bias/quint8/strategy.cpp View File

@@ -14,8 +14,8 @@
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h"
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h"
#include "src/arm_common/conv_bias/matmul_postprocess.h" #include "src/arm_common/conv_bias/matmul_postprocess.h"


using namespace megdnn; using namespace megdnn;
@@ -29,10 +29,10 @@ struct KernCaller;
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
template <BiasMode bmode, typename Op> template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8, true> { struct KernCaller<bmode, Op, 8, 8, true> {
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_uint8* C, size_t LDC,
bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
static void run(
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K,
dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
megdnn_assert(is_first_k); megdnn_assert(is_first_k);
constexpr size_t A_INTERLEAVE = 8; constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8; constexpr size_t B_INTERLEAVE = 8;
@@ -50,20 +50,19 @@ struct KernCaller<bmode, Op, 8, 8, true> {
size_t n = 0; size_t n = 0;
const dt_uint8* cur_packB = packB; const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8,
is_first_k, zp_A, zp_B, zAB);
matmul_8x8x4::kern_8x8(
packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B, zAB);


arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
8>::postprocess(bias, workspace,
output, LDC, op);
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>::
postprocess(bias, workspace, output, LDC, op);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4),
zp_A, zp_B, zAB);
matmul_8x8x4::kern_8x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(N - n, 4), zp_A, zp_B, zAB);
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
@@ -84,9 +83,9 @@ struct KernCaller<bmode, Op, 8, 8, true> {
const dt_uint8* cur_packB = packB; const dt_uint8* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8,
is_first_k, std::min<size_t>(M - m, 4),
zp_A, zp_B, zAB);
matmul_8x8x4::kern_4x8(
packA, cur_packB, K, workspace, 8, is_first_k,
std::min<size_t>(M - m, 4), zp_A, zp_B, zAB);
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
@@ -98,15 +97,14 @@ struct KernCaller<bmode, Op, 8, 8, true> {
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zp_A, zp_B,
zAB);
matmul_8x8x4::kern_4x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A,
zp_B, zAB);
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#undef cb #undef cb


output += 4; output += 4;
@@ -124,10 +122,10 @@ struct KernCaller<bmode, Op, 8, 8, true> {


template <BiasMode bmode, typename Op> template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8, false> { struct KernCaller<bmode, Op, 8, 8, false> {
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_uint8* C, size_t LDC,
bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
static void run(
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K,
dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
megdnn_assert(is_first_k); megdnn_assert(is_first_k);


constexpr size_t A_INTERLEAVE = 8; constexpr size_t A_INTERLEAVE = 8;
@@ -144,27 +142,25 @@ struct KernCaller<bmode, Op, 8, 8, false> {
size_t n = 0; size_t n = 0;
const dt_uint8* cur_packB = packB; const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8,
is_first_k, zp_A, zp_B);
matmul_8x8x8::kern_8x8(
packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B);


arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
8>::postprocess(bias, workspace,
output, LDC, op);
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>::
postprocess(bias, workspace, output, LDC, op);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4),
zp_A, zp_B);
matmul_8x8x8::kern_8x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(N - n, 4), zp_A, zp_B);
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb #undef cb



output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -179,9 +175,9 @@ struct KernCaller<bmode, Op, 8, 8, false> {
const dt_uint8* cur_packB = packB; const dt_uint8* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8,
is_first_k, std::min<size_t>(M - m, 4),
zp_A, zp_B);
matmul_8x8x8::kern_4x8(
packA, cur_packB, K, workspace, 8, is_first_k,
std::min<size_t>(M - m, 4), zp_A, zp_B);
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
@@ -193,17 +189,16 @@ struct KernCaller<bmode, Op, 8, 8, false> {
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zp_A, zp_B);
matmul_8x8x8::kern_4x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A,
zp_B);
#define cb(m, n) \ #define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op); bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#undef cb #undef cb



output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -219,27 +214,27 @@ struct KernCaller<bmode, Op, 8, 8, false> {
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity)


void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
void gemm_u8_8x8_dot_nobias_identity::pack_A(
uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0,
ymax, k0, kmax);
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(
outptr, inptr, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin,
y0, ymax, k0, kmax);
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(
outptr, inptr, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
void gemm_u8_8x8_dot_nobias_identity::pack_B(
uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0,
xmax, k0, kmax);
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(
out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax,
k0, kmax);
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(
out, in, ldin, x0, xmax, k0, kmax);
} }
} }


@@ -249,30 +244,27 @@ size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const {


#endif #endif
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity)
void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr,
const dt_uint8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_u8_8x8_nodot_nobias_identity::pack_A(
dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) { if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
ymax, k0, kmax, zA);
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(
outptr, inptr, ldin, y0, ymax, k0, kmax, zA);
} else { } else {
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax, zA);
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA);
} }
} }


void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
void gemm_u8_8x8_nodot_nobias_identity::pack_B(
dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) { if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
k0, kmax, zB);
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(
out, in, ldin, x0, xmax, k0, kmax, zB);
} else { } else {
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax,
zB);
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB);
} }
} }


@@ -280,22 +272,21 @@ size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const {
return 8 * 8 * sizeof(dt_int32); return 8 * 8 * sizeof(dt_int32);
} }


#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \
_OP) \
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \
kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \
size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace, zp_A, zp_B); \
#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, _OP) \
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline::kern( \
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \
size_t K, dt_uint8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \
dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace, zp_A, \
zp_B); \
} }


#define DEFINE_OP(_Op) \ #define DEFINE_OP(_Op) \
@@ -311,17 +302,22 @@ KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#undef DEFINE_OP #undef DEFINE_OP


#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C, zp_C);
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op( \
scale_A* scale_B, scale_A* scale_B, scale_C, zp_C);
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu,
FuseAddReluOp)
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#endif #endif
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity,
AddOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu,
FuseAddReluOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#undef DEFINE_OP #undef DEFINE_OP


#undef KERN #undef KERN


+ 26
- 28
dnn/src/aarch64/conv_bias/quint8/strategy.h View File

@@ -16,46 +16,44 @@ namespace aarch64 {
namespace matmul { namespace matmul {


#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4,
false, true,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(
dt_uint8, dt_uint8, dt_int32, 8, 8, 4, false, true,
gemm_u8_8x8_dot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_nobias_relu, gemm_u8_8x8_dot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_nobias_hswish, gemm_u8_8x8_dot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_bias_channel_identity, gemm_u8_8x8_dot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu,
gemm_u8_8x8_dot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_bias_channel_relu, gemm_u8_8x8_dot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_bias_channel_hswish, gemm_u8_8x8_dot_nobias_identity);


#endif #endif
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8,
false, true,
gemm_u8_8x8_nodot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(
dt_uint8, dt_uint8, dt_int32, 8, 8, 8, false, true,
gemm_u8_8x8_nodot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_nobias_relu, gemm_u8_8x8_nodot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_nobias_hswish, gemm_u8_8x8_nodot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_bias_channel_identity, gemm_u8_8x8_nodot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_bias_channel_relu, gemm_u8_8x8_nodot_nobias_identity);


MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_bias_channel_hswish, gemm_u8_8x8_nodot_nobias_identity);


} // namespace matmul } // namespace matmul
} // namespace aarch64 } // namespace aarch64


+ 4
- 4
dnn/src/aarch64/handle.cpp View File

@@ -11,11 +11,11 @@


#include "src/common/handle_impl.h" #include "src/common/handle_impl.h"


#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/aarch64/handle.h" #include "src/aarch64/handle.h"
#include "src/aarch64/matrix_mul/opr_impl.h" #include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/aarch64/rotate/opr_impl.h"
#include "src/aarch64/relayout/opr_impl.h" #include "src/aarch64/relayout/opr_impl.h"
#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/aarch64/rotate/opr_impl.h"
#include "src/aarch64/warp_perspective/opr_impl.h" #include "src/aarch64/warp_perspective/opr_impl.h"


namespace megdnn { namespace megdnn {
@@ -38,7 +38,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective)
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop


} // namespace aarch64
} // namespace megdnn
} // namespace aarch64
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 10
- 12
dnn/src/aarch64/handle.h View File

@@ -14,20 +14,18 @@
namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {


class HandleImpl: public arm_common::HandleImpl {
public:
HandleImpl(megcoreComputingHandle_t computing_handle,
HandleType type = HandleType::AARCH64):
arm_common::HandleImpl::HandleImpl(computing_handle, type)
{}
class HandleImpl : public arm_common::HandleImpl {
public:
HandleImpl(
megcoreComputingHandle_t computing_handle,
HandleType type = HandleType::AARCH64)
: arm_common::HandleImpl::HandleImpl(computing_handle, type) {}


template <typename Opr>
std::unique_ptr<Opr> create_operator();
template <typename Opr>
std::unique_ptr<Opr> create_operator();
}; };


} // namespace aarch64
} // namespace megdnn
} // namespace aarch64
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen



+ 327
- 438
dnn/src/aarch64/matrix_mul/algos.cpp
File diff suppressed because it is too large
View File


+ 28
- 76
dnn/src/aarch64/matrix_mul/algos.h View File

@@ -21,9 +21,7 @@ namespace aarch64 {


class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F32K8X12X1"; } const char* name() const override { return "AARCH64_F32K8X12X1"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
@@ -35,8 +33,7 @@ public:
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
} }
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
@@ -48,9 +45,7 @@ public:


class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F32K4X16X1"; } const char* name() const override { return "AARCH64_F32K4X16X1"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
@@ -61,9 +56,7 @@ public:


class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F32_MK4_4x16"; } const char* name() const override { return "AARCH64_F32_MK4_4x16"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
@@ -73,8 +66,7 @@ public:
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16)
}; };


class MatrixMulImpl::AlgoF32Gemv final
: public arm_common::MatrixMulImpl::AlgoF32Gemv {
class MatrixMulImpl::AlgoF32Gemv final : public arm_common::MatrixMulImpl::AlgoF32Gemv {
public: public:
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() {
m_handle_type = Handle::HandleType::AARCH64; m_handle_type = Handle::HandleType::AARCH64;
@@ -85,9 +77,7 @@ public:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F16_K8X24X1"; } const char* name() const override { return "AARCH64_F16_K8X24X1"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
@@ -98,9 +88,7 @@ public:


class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F16_MK8_8X8"; } const char* name() const override { return "AARCH64_F16_MK8_8X8"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
@@ -115,12 +103,8 @@ public:
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD";
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
@@ -130,12 +114,8 @@ public:


class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD";
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
@@ -147,8 +127,7 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
} }
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
@@ -163,9 +142,7 @@ public:


class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
@@ -178,9 +155,7 @@ public:


class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
@@ -192,9 +167,7 @@ public:


class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
@@ -207,9 +180,7 @@ public:


class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
@@ -222,8 +193,7 @@ public:
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
} }
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
@@ -238,12 +208,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_16X12X4";
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
} }
const char* name() const override { return "AARCH64_INT8X8X16_MK4_16X12X4"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
@@ -257,12 +224,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_K8X8X8";
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
} }
const char* name() const override { return "AARCH64_INT8X8X16_MK4_K8X8X8"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
@@ -276,8 +240,7 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
} }
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
@@ -292,9 +255,7 @@ public:


class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
@@ -306,9 +267,7 @@ public:


class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
@@ -321,12 +280,8 @@ public:
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_QUINT8_K8X8X4_DOTPROD";
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_QUINT8_K8X8X4_DOTPROD"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
@@ -336,8 +291,7 @@ public:
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
} }
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
@@ -352,9 +306,7 @@ public:
#endif #endif
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase {
public: public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; }
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;


+ 393
- 458
dnn/src/aarch64/matrix_mul/asm/common.h
File diff suppressed because it is too large
View File


+ 534
- 555
dnn/src/aarch64/matrix_mul/fp16/strategy.cpp
File diff suppressed because it is too large
View File


+ 4
- 4
dnn/src/aarch64/matrix_mul/fp16/strategy.h View File

@@ -16,11 +16,11 @@ namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul { namespace matmul {


MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false,
true, hgemm_8x24);
MEGDNN_REG_GEMM_STRATEGY(
dt_float16, dt_float16, dt_float16, 8, 24, 1, false, true, hgemm_8x24);


MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1,
false, true, gemm_nopack_f16_8x8);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true, gemm_nopack_f16_8x8);


} // namespace matmul } // namespace matmul
} // namespace aarch64 } // namespace aarch64


+ 21
- 20
dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp View File

@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#include "src/aarch64/matrix_mul/fp16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/fp16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"


@@ -21,8 +21,9 @@ using namespace aarch64::matmul;


namespace { namespace {


void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
void kern_8x1(
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
LDB *= sizeof(dt_float16); LDB *= sizeof(dt_float16);
asm volatile( asm volatile(
".arch armv8.2-a+fp16\n" ".arch armv8.2-a+fp16\n"
@@ -86,9 +87,8 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -115,8 +115,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
// |v23[0-7]| |v27[0-7]| // |v23[0-7]| |v27[0-7]|
// +--------+ +--------+ // +--------+ +--------+
// Accumulator // Accumulator
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
void kern_8x4(
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
//! LDB means number of elements in one block in B. we will read 24 numbers //! LDB means number of elements in one block in B. we will read 24 numbers
//! first. so minus 24 * 2 bytes here. //! first. so minus 24 * 2 bytes here.
LDB = (LDB - 24) * sizeof(dt_float16); LDB = (LDB - 24) * sizeof(dt_float16);
@@ -263,8 +264,8 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory");
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -295,8 +296,9 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
// | v7[0-7]| |v31[0-7]| // | v7[0-7]| |v31[0-7]|
// +--------+ +--------+ // +--------+ +--------+
// Accumulator // Accumulator
void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
void kern_8x8(
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
//! As each load 128 number from B, but the pos add 112 * 2, so we minus 112 //! As each load 128 number from B, but the pos add 112 * 2, so we minus 112
//! here. //! here.
LDB = (LDB - 32) * sizeof(dt_float16); LDB = (LDB - 32) * sizeof(dt_float16);
@@ -467,20 +469,19 @@ void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31", "cc", "memory");
} }


} // anonymous namespace } // anonymous namespace


MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8);


void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA,
const dt_float16* B, size_t LDB, dt_float16* C,
size_t LDC, size_t M, size_t K, size_t N,
const dt_float16*, void*, bool trA,
bool trB) const {
void gemm_nopack_f16_8x8::kern(
const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C,
size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA,
bool trB) const {
constexpr static size_t MB = 8; constexpr static size_t MB = 8;
constexpr static size_t KB = 8; constexpr static size_t KB = 8;
constexpr static size_t NB = 8; constexpr static size_t NB = 8;


+ 13
- 11
dnn/src/aarch64/matrix_mul/fp32/common.h View File

@@ -17,21 +17,23 @@
namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {


MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M,
size_t K, size_t LDA, const float* alpha);
MEGDNN_NOINLINE void sgemm_packA_n(
const float* A, float* Apacked, size_t M, size_t K, size_t LDA,
const float* alpha);


MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M,
size_t K, size_t LDA, const float* alpha);
MEGDNN_NOINLINE void sgemm_packA_t(
const float* A, float* Apacked, size_t M, size_t K, size_t LDA,
const float* alpha);


MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K,
size_t N, size_t LDB);
MEGDNN_NOINLINE void sgemm_packB_n(
const float* B, float* Bpacked, size_t K, size_t N, size_t LDB);


MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K,
size_t N, size_t LDB);
MEGDNN_NOINLINE void sgemm_packB_t(
const float* B, float* Bpacked, size_t K, size_t N, size_t LDB);


MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C,
size_t LDC, size_t M, size_t N, size_t K,
int type, const float* beta);
MEGDNN_NOINLINE void sgemm_kernel12x8(
const float* A, const float* B, float* C, size_t LDC, size_t M, size_t N,
size_t K, int type, const float* beta);


} // namespace aarch64 } // namespace aarch64
} // namespace megdnn } // namespace megdnn


+ 124
- 118
dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h View File

@@ -12,7 +12,6 @@
#include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"



namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul_general_4x16 { namespace matmul_general_4x16 {
@@ -39,8 +38,9 @@ namespace matmul_general_4x16 {
// +--+ - - - - +--------+--------+--------+--------+ // +--+ - - - - +--------+--------+--------+--------+
// //
// Accumulator // Accumulator
void kern_4x16(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k, int m_remain) {
void kern_4x16(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -224,14 +224,14 @@ void kern_4x16(const float* packA, const float* packB, int K,


"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[m_remain] "+r"(m_remain), [outptr] "+r"(outptr) [m_remain] "+r"(m_remain), [outptr] "+r"(outptr)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9",
"x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", "x10", "cc",
"memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -263,8 +263,9 @@ void kern_4x16(const float* packA, const float* packB, int K,
// +--+--+ - - - - +--------+ // +--+--+ - - - - +--------+
// //
// Accumulator // Accumulator
void kern_4x4(const float* packA, const float* packB, int K, float* output,
int LDC, bool is_first_k, int m_remain, int n_remain) {
void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -330,99 +331,100 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
STORE_LINE("6", "2") \ STORE_LINE("6", "2") \
STORE_LINE("7", "3") \ STORE_LINE("7", "3") \
"105:\n" "105:\n"
// clang-format on
asm volatile(
// load accumulator C
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"b 2f\n"
"1:\n"
"eor v4.16b, v4.16b, v4.16b\n"
"eor v5.16b, v5.16b, v5.16b\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"
"2: \n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3:\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"
"subs %w[K], %w[K], #1\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"
"b 6f\n"
// odd tail
"5:\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"6:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1",
"x2", "x3", "x10", "cc", "memory");
// clang-format on
asm volatile(
// load accumulator C
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"b 2f\n"
"1:\n"
"eor v4.16b, v4.16b, v4.16b\n"
"eor v5.16b, v5.16b, v5.16b\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"
"2: \n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3:\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"
"subs %w[K], %w[K], #1\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"
"b 6f\n"
// odd tail
"5:\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"6:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", "x2", "x3", "x10",
"cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE
#undef STORE_C #undef STORE_C
} }


void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
int ymax, int k0, int kmax) {
void sgemm_4x16_pack_A_n(
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
float zerobuff[4]; float zerobuff[4];
std::memset(zerobuff, 0, sizeof(float) * 4); std::memset(zerobuff, 0, sizeof(float) * 4);
constexpr int PACK_SIZE = 4*4;
constexpr int PACK_SIZE = 4 * 4;


int y = y0; int y = y0;
for (; y + 3 < ymax; y += 4) { for (; y + 3 < ymax; y += 4) {
// printf("main loop pack_a_n %p \n",outptr);
// printf("main loop pack_a_n %p \n",outptr);
const float* inptr0 = inptr + y * ldin + k0; const float* inptr0 = inptr + y * ldin + k0;
const float* inptr1 = inptr0 + ldin; const float* inptr1 = inptr0 + ldin;
const float* inptr2 = inptr1 + ldin; const float* inptr2 = inptr1 + ldin;
@@ -459,9 +461,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
switch ((y + 3) - ymax) { switch ((y + 3) - ymax) {
/* Everything falls through in here */ /* Everything falls through in here */
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -478,9 +482,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -493,8 +499,8 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
} }
} }


void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0,
int xmax, int k0, int kmax) {
void sgemm_4x16_pack_A_t(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0; int ksize = kmax - k0;
int ksize4 = (ksize << 2); int ksize4 = (ksize << 2);
float* outptr_base = out; float* outptr_base = out;
@@ -515,8 +521,7 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0,
auto outptr = outptr_base; auto outptr = outptr_base;
for (; x + 4 <= xmax; x += 4) { for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr; auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize4; outptr += ksize4;
} }


@@ -546,8 +551,8 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0,
} }
} }


void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin,
int x0, int xmax, int k0, int kmax) {
void sgemm_4x16_pack_B_n(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0; int ksize = kmax - k0;
int ksize16 = ksize * 16; int ksize16 = ksize * 16;
int ksize4 = (ksize << 2); int ksize4 = (ksize << 2);
@@ -570,15 +575,13 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin,
auto outptr = outptr_base; auto outptr = outptr_base;
for (; x + 16 <= xmax; x += 16) { for (; x + 16 <= xmax; x += 16) {
auto outptr_interleave = outptr; auto outptr_interleave = outptr;
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize16; outptr += ksize16;
} }
outptr = outptr_base4; outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) { for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr; auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize4; outptr += ksize4;
} }


@@ -616,8 +619,8 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin,
} }
} }


void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
int y0, int ymax, int k0, int kmax) {
void sgemm_4x16_pack_B_t(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) {
float* outptr = out; float* outptr = out;
const float* inptr = in; const float* inptr = in;
float zerobuff[4]; float zerobuff[4];
@@ -642,8 +645,7 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,


int x = (kmax - k0); int x = (kmax - k0);
for (; x > 3; x -= 4) { for (; x > 3; x -= 4) {
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner,
64);
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 64);
outptr_inner += 64; outptr_inner += 64;
} }
for (; x > 0; x--) { for (; x > 0; x--) {
@@ -676,9 +678,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
switch ((y + 3) - ymax) { switch ((y + 3) - ymax) {
/* Everything falls through in here */ /* Everything falls through in here */
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -696,9 +700,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
switch ((y + 3) - ymax) { switch ((y + 3) - ymax) {
/* Everything falls through in here */ /* Everything falls through in here */
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -711,8 +717,8 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
} }
} }


} // matmul_general_4x16
} // aarch64
} // megdnn
} // namespace matmul_general_4x16
} // namespace aarch64
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 54
- 60
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h View File

@@ -43,8 +43,9 @@ struct matmul_general_8x12 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_8x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -306,14 +307,13 @@ struct matmul_general_8x12 {
"6:\n" "6:\n"


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -348,9 +348,9 @@ struct matmul_general_8x12 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_8x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -520,13 +520,12 @@ struct matmul_general_8x12 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[n_remain] "+r"(n_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20",
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23",
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc",
"memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -557,9 +556,9 @@ struct matmul_general_8x12 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_4x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int m_remain) {
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -717,13 +716,12 @@ struct matmul_general_8x12 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x2", "x3", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x2", "x3", "x10", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -754,9 +752,9 @@ struct matmul_general_8x12 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_4x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -895,20 +893,21 @@ struct matmul_general_8x12 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain),
[m_remain] "+r"(m_remain)
: :
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2",
"x3", "x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3",
"x10", "cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE
#undef STORE_C #undef STORE_C
} }


static void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin,
int y0, int ymax, int k0, int kmax) {
static void sgemm_8x12_pack_A_n(
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
float zerobuff[8]; float zerobuff[8];
std::memset(zerobuff, 0, sizeof(float) * 8); std::memset(zerobuff, 0, sizeof(float) * 8);
constexpr int PACK_SIZE_32 = 4 * 8; constexpr int PACK_SIZE_32 = 4 * 8;
@@ -933,8 +932,9 @@ struct matmul_general_8x12 {
prefetch_2x(inptr7); prefetch_2x(inptr7);
int x = (kmax - k0); int x = (kmax - k0);
for (; x > 3; x -= 4) { for (; x > 3; x -= 4) {
transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_8x4_1_s(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += PACK_SIZE_32; outptr += PACK_SIZE_32;
} }
for (; x > 0; x--) { for (; x > 0; x--) {
@@ -1004,8 +1004,8 @@ struct matmul_general_8x12 {
} }
} }


static void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void sgemm_8x12_pack_A_t(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0; int ksize = kmax - k0;
int ksize8 = (ksize << 3); int ksize8 = (ksize << 3);
int ksize4 = (ksize << 2); int ksize4 = (ksize << 2);
@@ -1028,20 +1028,17 @@ struct matmul_general_8x12 {
auto outptr = outptr_base; auto outptr = outptr_base;
for (; x + 8 <= xmax; x += 8) { for (; x + 8 <= xmax; x += 8) {
auto outptr_interleave = outptr; auto outptr_interleave = outptr;
interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize8; outptr += ksize8;
} }
outptr = outptr_base4; outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) { for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr; auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize4; outptr += ksize4;
} }
if (x < xmax) { if (x < xmax) {
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4,
xmax - x);
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
} }
outptr_base += 4 * 8; outptr_base += 4 * 8;
outptr_base4 += 4 * 4; outptr_base4 += 4 * 4;
@@ -1071,8 +1068,8 @@ struct matmul_general_8x12 {
} }
} }


static void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void sgemm_8x12_pack_B_n(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0; int ksize = kmax - k0;
int ksize12 = ksize * 12; int ksize12 = ksize * 12;
int ksize4 = (ksize << 2); int ksize4 = (ksize << 2);
@@ -1095,20 +1092,17 @@ struct matmul_general_8x12 {
auto outptr = outptr_base; auto outptr = outptr_base;
for (; x + 12 <= xmax; x += 12) { for (; x + 12 <= xmax; x += 12) {
auto outptr_interleave = outptr; auto outptr_interleave = outptr;
interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize12; outptr += ksize12;
} }
outptr = outptr_base4; outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) { for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr; auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3,
outptr_interleave);
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize4; outptr += ksize4;
} }
if (x < xmax) { if (x < xmax) {
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4,
xmax - x);
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
} }
outptr_base += 12 * 4; outptr_base += 12 * 4;
outptr_base4 += 4 * 4; outptr_base4 += 4 * 4;
@@ -1138,8 +1132,8 @@ struct matmul_general_8x12 {
} }
} }


static void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin,
int y0, int ymax, int k0, int kmax) {
static void sgemm_8x12_pack_B_t(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) {
float* outptr = out; float* outptr = out;
const float* inptr = in; const float* inptr = in;
float zerobuff[12]; float zerobuff[12];
@@ -1172,9 +1166,9 @@ struct matmul_general_8x12 {
prefetch_2x(inptr11); prefetch_2x(inptr11);
int x = (kmax - k0); int x = (kmax - k0);
for (; x > 3; x -= 4) { for (; x > 3; x -= 4) {
transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, inptr8, inptr9,
inptr10, inptr11, outptr);
transpose_12x4_1_s(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr);
outptr += 48; outptr += 48;
} }
for (; x > 0; x--) { for (; x > 0; x--) {


+ 34
- 37
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h View File

@@ -43,8 +43,9 @@ struct matmul_general_8x12_a53 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_8x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -575,15 +576,14 @@ struct matmul_general_8x12_a53 {
"6:\n" "6:\n"


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10",
"x11", "x12", "x13", "cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
} }
@@ -615,9 +615,9 @@ struct matmul_general_8x12_a53 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_8x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -856,13 +856,12 @@ struct matmul_general_8x12_a53 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[n_remain] "+r"(n_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20",
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23",
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9",
"x10", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -893,9 +892,9 @@ struct matmul_general_8x12_a53 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_4x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int m_remain) {
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -1133,14 +1132,12 @@ struct matmul_general_8x12_a53 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x2", "x3", "x8", "x9", "x10", "x20", "x21",
"x22", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x2", "x3", "x8", "x9", "x10", "x20", "x21", "x22", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -1171,9 +1168,9 @@ struct matmul_general_8x12_a53 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_4x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -1312,12 +1309,12 @@ struct matmul_general_8x12_a53 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain),
[m_remain] "+r"(m_remain)
: :
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2",
"x3", "x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3",
"x10", "cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE


+ 34
- 37
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h View File

@@ -43,8 +43,9 @@ struct matmul_general_8x12_a55 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_8x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -525,15 +526,14 @@ struct matmul_general_8x12_a55 {
"6:\n" "6:\n"


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10",
"x11", "x12", "x13", "cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
} }
@@ -565,9 +565,9 @@ struct matmul_general_8x12_a55 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_8x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -742,13 +742,12 @@ struct matmul_general_8x12_a55 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[n_remain] "+r"(n_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20",
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23",
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x10", "cc",
"memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -779,9 +778,9 @@ struct matmul_general_8x12_a55 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_4x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int m_remain) {
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -972,14 +971,12 @@ struct matmul_general_8x12_a55 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x2", "x3", "x10", "x20", "x21", "x22", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x2", "x3", "x10", "x20", "x21", "x22", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -1010,9 +1007,9 @@ struct matmul_general_8x12_a55 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_4x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
int oddk = (K & 1); int oddk = (K & 1);
@@ -1151,12 +1148,12 @@ struct matmul_general_8x12_a55 {
"6:\n" STORE_C "6:\n" STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [outptr] "+r"(outptr),
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain)
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain),
[m_remain] "+r"(m_remain)
: :
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2",
"x3", "x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3",
"x10", "cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE


+ 28
- 27
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h View File

@@ -44,8 +44,9 @@ struct matmul_mk4_8x12 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_8x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
float* output0 = output; float* output0 = output;
@@ -307,10 +308,10 @@ struct matmul_mk4_8x12 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [output1] "+r"(output1) [output0] "+r"(output0), [output1] "+r"(output1)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -340,9 +341,9 @@ struct matmul_mk4_8x12 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_8x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
float* output0 = output; float* output0 = output;
@@ -500,8 +501,8 @@ struct matmul_mk4_8x12 {
[output0] "+r"(output0), [output1] "+r"(output1), [output0] "+r"(output0), [output1] "+r"(output1),
[n_remain] "+r"(n_remain) [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "cc", "memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "cc", "memory");


#undef LOAD_C #undef LOAD_C
#undef STORE_C #undef STORE_C
@@ -531,8 +532,9 @@ struct matmul_mk4_8x12 {
// //
// Accumulator // Accumulator


static void kern_4x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
MEGDNN_MARK_USED_VAR(LDC); MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
@@ -669,9 +671,9 @@ struct matmul_mk4_8x12 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0) [output0] "+r"(output0)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -697,9 +699,9 @@ struct matmul_mk4_8x12 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_4x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC); MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
@@ -818,15 +820,15 @@ struct matmul_mk4_8x12 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [n_remain] "+r"(n_remain) [output0] "+r"(output0), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory");


#undef LOAD_C #undef LOAD_C
#undef STORE_C #undef STORE_C
} }


static void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin,
int y0, int ymax, int k0, int kmax) {
static void sgemm_8x12_pack_A(
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
constexpr int PACK_SIZE_32 = 4 * 8; constexpr int PACK_SIZE_32 = 4 * 8;
@@ -855,8 +857,8 @@ struct matmul_mk4_8x12 {
} }
} }


static void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0,
int xmax, int k0, int kmax) {
static void sgemm_8x12_pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
float tmpbuff[16] = {0.0f}; float tmpbuff[16] = {0.0f};


@@ -886,8 +888,7 @@ struct matmul_mk4_8x12 {
outptr += ksize4; outptr += ksize4;
} }
if (x < xmax) { if (x < xmax) {
std::memcpy(tmpbuff, inptr,
sizeof(float) * (xmax - x) * PACK_C_SIZE);
std::memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE);
auto outptr_interleave = outptr; auto outptr_interleave = outptr;
const float* tmp_ptr = &tmpbuff[0]; const float* tmp_ptr = &tmpbuff[0];
transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave);


+ 23
- 22
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h View File

@@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a53 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_8x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
float* output0 = output; float* output0 = output;
@@ -553,11 +554,11 @@ struct matmul_mk4_8x12_a53 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [output1] "+r"(output1) [output0] "+r"(output0), [output1] "+r"(output1)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10",
"x11", "x12", "x13", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -587,9 +588,9 @@ struct matmul_mk4_8x12_a53 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_8x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
float* output0 = output; float* output0 = output;
@@ -831,8 +832,8 @@ struct matmul_mk4_8x12_a53 {
[output0] "+r"(output0), [output1] "+r"(output1), [output0] "+r"(output0), [output1] "+r"(output1),
[n_remain] "+r"(n_remain) [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "x8", "x9", "x10", "cc", "memory");


#undef LOAD_C #undef LOAD_C
#undef STORE_C #undef STORE_C
@@ -862,8 +863,9 @@ struct matmul_mk4_8x12_a53 {
// //
// Accumulator // Accumulator


static void kern_4x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
MEGDNN_MARK_USED_VAR(LDC); MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
@@ -1098,9 +1100,9 @@ struct matmul_mk4_8x12_a53 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0) [output0] "+r"(output0)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x8", "x9", "x10", "x11", "x12", "cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -1126,9 +1128,9 @@ struct matmul_mk4_8x12_a53 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_4x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC); MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
@@ -1246,8 +1248,7 @@ struct matmul_mk4_8x12_a53 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [n_remain] "+r"(n_remain) [output0] "+r"(output0), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory");


#undef LOAD_C #undef LOAD_C
#undef STORE_C #undef STORE_C


+ 23
- 22
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h View File

@@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a55 {
// +--+ --- - +--------+--------+--------+ // +--+ --- - +--------+--------+--------+
// //
// Accumulator // Accumulator
static void kern_8x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
float* output0 = output; float* output0 = output;
@@ -519,11 +520,11 @@ struct matmul_mk4_8x12_a55 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [output1] "+r"(output1) [output0] "+r"(output0), [output1] "+r"(output1)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10",
"x11", "x12", "x13", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -553,9 +554,9 @@ struct matmul_mk4_8x12_a55 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_8x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
float* output0 = output; float* output0 = output;
@@ -749,8 +750,8 @@ struct matmul_mk4_8x12_a55 {
[output0] "+r"(output0), [output1] "+r"(output1), [output0] "+r"(output0), [output1] "+r"(output1),
[n_remain] "+r"(n_remain) [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "x8", "x9", "x10", "cc", "memory");


#undef LOAD_C #undef LOAD_C
#undef STORE_C #undef STORE_C
@@ -780,8 +781,9 @@ struct matmul_mk4_8x12_a55 {
// //
// Accumulator // Accumulator


static void kern_4x12(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k) {
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
MEGDNN_MARK_USED_VAR(LDC); MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
@@ -997,9 +999,9 @@ struct matmul_mk4_8x12_a55 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0) [output0] "+r"(output0)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x8", "x9", "x10", "x11", "x12", "cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -1025,9 +1027,9 @@ struct matmul_mk4_8x12_a55 {
// +--+ --- - +--------+ // +--+ --- - +--------+
// //
// Accumulator // Accumulator
static void kern_4x4(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k,
int n_remain) {
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC); MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA; const float* a_ptr = packA;
const float* b_ptr = packB; const float* b_ptr = packB;
@@ -1146,8 +1148,7 @@ struct matmul_mk4_8x12_a55 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [n_remain] "+r"(n_remain) [output0] "+r"(output0), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory");


#undef LOAD_C #undef LOAD_C
#undef STORE_C #undef STORE_C


+ 83
- 86
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp View File

@@ -10,6 +10,7 @@
* implied. * implied.
*/ */


#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h"
@@ -17,44 +18,40 @@
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h"
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/common/utils.h" #include "src/common/utils.h"



using namespace megdnn; using namespace megdnn;
using namespace aarch64; using namespace aarch64;
using namespace aarch64::matmul; using namespace aarch64::matmul;


MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16);


void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax,
int k0, int kmax, bool transpose_A) const {
void sgemm_4x16::pack_A(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose_A) const {
if (transpose_A) { if (transpose_A) {
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
} }
} }


void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
int k0, int kmax, bool transpose_B) const {
void sgemm_4x16::pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose_B) const {
if (transpose_B) { if (transpose_B) {
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0,
kmax);
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
} }
} }


void sgemm_4x16::kern(const float* packA, const float* packB, size_t M,
size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
void sgemm_4x16::kern(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k, const float*, float*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
@@ -71,9 +68,9 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M,
size_t n = 0; size_t n = 0;
const float* cur_packB = packB; const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, 4));
matmul_general_4x16::kern_4x16(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K16; cur_packB += K16;
} }
@@ -92,32 +89,30 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M,


MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12);


void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax,
int k0, int kmax, bool transpose_A) const {
void sgemm_8x12::pack_A(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose_A) const {
if (transpose_A) { if (transpose_A) {
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
} }
} }


void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
int k0, int kmax, bool transpose_B) const {
void sgemm_8x12::pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose_B) const {
if (transpose_B) { if (transpose_B) {
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0,
kmax);
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
} }
} }


template <typename gemm_class> template <typename gemm_class>
static inline void sgemm_8x12_helper(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
static inline void sgemm_8x12_helper(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
constexpr size_t A_INTERLEAVE = 8; constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4; constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12; constexpr size_t B_INTERLEAVE = 12;
@@ -138,8 +133,9 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB,
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
gemm_class::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -150,16 +146,17 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB,
size_t n = 0; size_t n = 0;
const float* cur_packB = packB; const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
gemm_class::kern_4x12(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K12; cur_packB += K12;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
gemm_class::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -167,56 +164,55 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB,
} }
} }


void sgemm_8x12::kern(const float* packA, const float* packB, size_t M,
size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
void sgemm_8x12::kern(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k, const float*, float*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
#if !MGB_ENABLE_CPUINFO #if !MGB_ENABLE_CPUINFO
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
is_first_k);
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, is_first_k);
#else #else
auto arch = cpuinfo_get_current_core()->uarch; auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__ #ifdef __IN_TEE_ENV__
arch = cpuinfo_uarch_unknown; arch = cpuinfo_uarch_unknown;
#endif #endif
if (arch == cpuinfo_uarch_cortex_a53) { if (arch == cpuinfo_uarch_cortex_a53) {
sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C,
LDC, is_first_k);
sgemm_8x12_helper<matmul_general_8x12_a53>(
packA, packB, M, N, K, C, LDC, is_first_k);
} else if (arch == cpuinfo_uarch_cortex_a55) { } else if (arch == cpuinfo_uarch_cortex_a55) {
sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C,
LDC, is_first_k);
sgemm_8x12_helper<matmul_general_8x12_a55>(
packA, packB, M, N, K, C, LDC, is_first_k);
} else { } else {
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
is_first_k);
sgemm_8x12_helper<matmul_general_8x12>(
packA, packB, M, N, K, C, LDC, is_first_k);
} }
#endif #endif
} }


MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12);


void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose_A) const {
void sgemm_mk4_8x12::pack_A(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose_A) const {
megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A"); megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A");
matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax);
} }


void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose_B) const {
void sgemm_mk4_8x12::pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose_B) const {
megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B"); megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B");
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
} }


template <typename gemm_name> template <typename gemm_name>
static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
static inline void sgemm_mk4_8x12_helper(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
const int K12 = K * 12; const int K12 = K * 12;
const int K8 = K * 8; const int K8 = K * 8;
const int K4 = K * 4; const int K4 = K * 4;
@@ -237,8 +233,9 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB,
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
gemm_name::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE; output += 4 * PACK_C_SIZE;
cur_packB += K4; cur_packB += K4;
} }
@@ -254,41 +251,41 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB,
cur_packB += K12; cur_packB += K12;
} }
for (; n < N; n += 4) { for (; n < N; n += 4) {
gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
gemm_name::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE; output += 4 * PACK_C_SIZE;
cur_packB += K4; cur_packB += K4;
} }
packA += K4; packA += K4;
} }
} }
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M,
size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
void sgemm_mk4_8x12::kern(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k, const float*, float*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
#if !MGB_ENABLE_CPUINFO #if !MGB_ENABLE_CPUINFO
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
is_first_k);
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, is_first_k);
#else #else
auto arch = cpuinfo_get_current_core()->uarch; auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__ #ifdef __IN_TEE_ENV__
arch = cpuinfo_uarch_unknown; arch = cpuinfo_uarch_unknown;
#endif #endif
if (arch == cpuinfo_uarch_cortex_a53) { if (arch == cpuinfo_uarch_cortex_a53) {
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C,
LDC, is_first_k);
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(
packA, packB, M, N, K, C, LDC, is_first_k);
} else if (arch == cpuinfo_uarch_cortex_a55) { } else if (arch == cpuinfo_uarch_cortex_a55) {
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C,
LDC, is_first_k);
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(
packA, packB, M, N, K, C, LDC, is_first_k);
} else { } else {
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
is_first_k);
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(
packA, packB, M, N, K, C, LDC, is_first_k);
} }
#endif #endif
} }


+ 5
- 8
dnn/src/aarch64/matrix_mul/fp32/strategy.h View File

@@ -15,17 +15,14 @@
namespace megdnn { namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul { namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true,
sgemm_8x12);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12);


MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true,
sgemm_4x16);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, sgemm_4x16);


MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false,
sgemm_mk4_8x12);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, sgemm_mk4_8x12);


MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true,
sgemm_nopack_4x16);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
float, float, float, 4, 16, 1, false, true, sgemm_nopack_4x16);


} // namespace matmul } // namespace matmul
} // namespace aarch64 } // namespace aarch64


+ 19
- 23
dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp View File

@@ -20,8 +20,8 @@ using namespace aarch64::matmul;


namespace { namespace {


void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
float* output) {
void kern_4x1(
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) {
LDB *= sizeof(float); LDB *= sizeof(float);
asm volatile( asm volatile(
"subs %w[K], %w[K], #4\n" "subs %w[K], %w[K], #4\n"
@@ -64,8 +64,7 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc",
"memory");
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -89,8 +88,8 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
// +--------+ - - - - -+--------+ // +--------+ - - - - -+--------+
// Accumulator // Accumulator


void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
float* output) {
void kern_4x4(
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) {
//! As each load 16 number from B, but the pos add 12 * 4, so we minus 12 //! As each load 16 number from B, but the pos add 12 * 4, so we minus 12
//! here. //! here.
LDB = (LDB - 12) * sizeof(float); LDB = (LDB - 12) * sizeof(float);
@@ -165,8 +164,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v18", "v19", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
"v19", "cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -195,8 +194,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
// +--------+ - - - - -+--------+ // +--------+ - - - - -+--------+
// Accumulator // Accumulator


void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
float* output) {
void kern_4x8(
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) {
//! As each load 32 number from B, but the pos add 24 * 4, so we minus 24 //! As each load 32 number from B, but the pos add 24 * 4, so we minus 24
//! here. //! here.
LDB = (LDB - 24) * sizeof(float); LDB = (LDB - 24) * sizeof(float);
@@ -304,9 +303,9 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
"v27", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc",
"memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -342,8 +341,7 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
// +--------+ // +--------+
// Accumulator // Accumulator


void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K,
float* output) {
void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, float* output) {
//! As each load 64 number from B, but the pos add 56 * 4, so we minus 56 //! As each load 64 number from B, but the pos add 56 * 4, so we minus 56
//! here. //! here.
LDB = (LDB - 56) * sizeof(float); LDB = (LDB - 56) * sizeof(float);
@@ -565,20 +563,18 @@ void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
} }


} // namespace } // namespace


MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16);


void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B,
size_t LDB, float* C, size_t LDC, size_t M,
size_t K, size_t N, const float*, void*, bool trA,
bool trB) const {
void sgemm_nopack_4x16::kern(
const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC,
size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const {
constexpr static size_t MB = 4; constexpr static size_t MB = 4;
constexpr static size_t KB = 4; constexpr static size_t KB = 4;
constexpr static size_t NB = 16; constexpr static size_t NB = 16;


+ 99
- 96
dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h View File

@@ -46,8 +46,9 @@ namespace matmul_12x8x1 {
* Accumulator * Accumulator
*/ */


static void kern_12x8(const int16_t* packA, const int16_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_12x8(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB; const int16_t* b_ptr = packB;


@@ -155,15 +156,13 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K,
"stp q25, q26, [x9]\n" "stp q25, q26, [x9]\n"
"stp q27, q28, [x10]\n" "stp q27, q28, [x10]\n"
"stp q29, q30, [x11]\n" "stp q29, q30, [x11]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[output] "+r"(output)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output)
: :
: "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "x1",
"x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
"cc", "memory");
: "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29", "v30", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE
@@ -196,8 +195,9 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_8x8(const int16_t* packA, const int16_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x8(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB; const int16_t* b_ptr = packB;


@@ -276,13 +276,12 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K,
"stp q17, q18, [x5]\n" "stp q17, q18, [x5]\n"
"stp q19, q20, [x6]\n" "stp q19, q20, [x6]\n"
"stp q21, q22, [x7]\n" "stp q21, q22, [x7]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[output] "+r"(output)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output)
: :
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1",
"x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory");
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", "x2", "x3", "x4",
"x5", "x6", "x7", "cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE
@@ -311,9 +310,9 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_4x8(const int16_t* packA, const int16_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
size_t m_remain) {
static void kern_4x8(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t m_remain) {
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB; const int16_t* b_ptr = packB;


@@ -388,14 +387,13 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
[m_remain] "+r"(m_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [m_remain] "+r"(m_remain)
: :
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"cc", "memory");
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "cc",
"memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE
@@ -432,9 +430,9 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_12x4(const int16_t* packA, const int16_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
size_t n_remain) {
static void kern_12x4(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t n_remain) {
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB; const int16_t* b_ptr = packB;


@@ -573,18 +571,16 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7),
[outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9),
[outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11),
[x0] "+r"(x0), [n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7), [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9),
[outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), [x0] "+r"(x0),
[n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "cc", "memory");
: "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -618,9 +614,9 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_8x4(const int16_t* packA, const int16_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
size_t n_remain) {
static void kern_8x4(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t n_remain) {
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB; const int16_t* b_ptr = packB;


@@ -734,16 +730,14 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain)
: :
: "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"cc", "memory");
: "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc",
"memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -773,9 +767,9 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_4x4(const int16_t* packA, const int16_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, size_t m_remain,
size_t n_remain) {
static void kern_4x4(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t m_remain, size_t n_remain) {
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB; const int16_t* b_ptr = packB;


@@ -874,11 +868,10 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
[m_remain] "+r"(m_remain), [x1] "+r"(x1),
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [m_remain] "+r"(m_remain), [x1] "+r"(x1),
[n_remain] "+r"(n_remain) [n_remain] "+r"(n_remain)
: :
: "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); : "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory");
@@ -889,9 +882,9 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K,
#undef STORE_C #undef STORE_C
} }


static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s16_12x8x1_pack_A_n(
int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int16_t zerobuff[4]; int16_t zerobuff[4];
std::memset(zerobuff, 0, sizeof(int16_t) * 4); std::memset(zerobuff, 0, sizeof(int16_t) * 4);


@@ -925,15 +918,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,


int K = kmax - k0; int K = kmax - k0;
for (; K > 3; K -= 4) { for (; K > 3; K -= 4) {
interleave_12x1_4_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, inptr8, inptr9, inptr10,
inptr11, outptr);
interleave_12x1_4_h(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr);
} }


if (K > 0) { if (K > 0) {
interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, inptr8, inptr9, inptr10, inptr11,
outptr, 1, K);
interleave_12(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr, 1, K);
} }
} }


@@ -949,13 +942,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,


int K = kmax - k0; int K = kmax - k0;
for (; K > 7; K -= 8) { for (; K > 7; K -= 8) {
interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x1_8_h(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }


if (K > 0) { if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 1, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 1, K);
} }
} }


@@ -975,9 +970,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -992,9 +989,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -1007,9 +1006,8 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,
} }
} }


static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in,
int ldin, int x0, int xmax,
int k0, int kmax) {
static void gemm_s16_12x8x1_transpose_pack_A_n(
int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
const int ksize = kmax - k0; const int ksize = kmax - k0;
const int ksize4 = ksize * 4; const int ksize4 = ksize * 4;
const int ksize8 = ksize4 * 2; const int ksize8 = ksize4 * 2;
@@ -1054,8 +1052,8 @@ static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in,
} }
} }


static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s16_12x8x1_pack_B_n(
int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
const int ksize = kmax - k0; const int ksize = kmax - k0;
const int ksize4 = ksize * 4; const int ksize4 = ksize * 4;
const int ksize8 = ksize4 * 2; const int ksize8 = ksize4 * 2;
@@ -1090,10 +1088,9 @@ static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin,
} }
} }


static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr,
const int16_t* inptr, int ldin,
int y0, int ymax, int k0,
int kmax) {
static void gemm_s16_12x8x1_transpose_pack_B_n(
int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int16_t zerobuff[4]; int16_t zerobuff[4];
std::memset(zerobuff, 0, sizeof(int16_t) * 4); std::memset(zerobuff, 0, sizeof(int16_t) * 4);


@@ -1110,13 +1107,15 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr,


int K = kmax - k0; int K = kmax - k0;
for (; K > 7; K -= 8) { for (; K > 7; K -= 8) {
interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x1_8_h(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }


if (K > 0) { if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 1, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 1, K);
} }
} }


@@ -1136,9 +1135,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -1153,9 +1154,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;


+ 35
- 36
dnn/src/aarch64/matrix_mul/int16/strategy.cpp View File

@@ -22,39 +22,37 @@ using namespace aarch64::matmul;
///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// ///////////////////////// gemm_s16_12x8x1 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1);


void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s16_12x8x1::pack_A(
dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) { if (transpose) {
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin,
y0, ymax, k0, kmax);
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(
outptr, inptr, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax,
k0, kmax);
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(
outptr, inptr, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
void gemm_s16_12x8x1::pack_B(
dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0,
xmax, k0, kmax);
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(
out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
} }
} }


void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int16 &&
C_dtype.enumv() == DTypeEnum::Int32),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s16_12x8x1::kern(
const dt_int16* packA, const dt_int16* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int16 &&
C_dtype.enumv() == DTypeEnum::Int32),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
@@ -72,15 +70,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
size_t n = 0; size_t n = 0;
const dt_int16* cur_packB = packB; const dt_int16* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_12x8x1::kern_12x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -92,15 +90,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
const dt_int16* cur_packB = packB; const dt_int16* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_12x8x1::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -112,16 +110,17 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
const dt_int16* cur_packB = packB; const dt_int16* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4));
matmul_12x8x1::kern_4x8(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_12x8x1::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }


+ 4
- 4
dnn/src/aarch64/matrix_mul/int16/strategy.h View File

@@ -16,11 +16,11 @@ namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul { namespace matmul {


MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true,
gemm_s16_12x8x1);
MEGDNN_REG_GEMM_STRATEGY(
dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, gemm_s16_12x8x1);


MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false,
true, gemm_nopack_s16_8x8);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
dt_int16, dt_int32, dt_int32, 8, 8, 1, false, true, gemm_nopack_s16_8x8);


} // namespace matmul } // namespace matmul
} // namespace aarch64 } // namespace aarch64


+ 23
- 21
dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp View File

@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"


@@ -20,8 +20,9 @@ using namespace aarch64::matmul;


namespace { namespace {


void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
void kern_8x1(
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24
//! here. //! here.
LDB *= sizeof(dt_int16); LDB *= sizeof(dt_int16);
@@ -91,9 +92,8 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -120,8 +120,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
// | v31[0-7]| |v23[0-3]| // | v31[0-7]| |v23[0-3]|
// +---------+ +--------+ // +---------+ +--------+
// Accumulator // Accumulator
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
void kern_8x4(
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24
//! here. //! here.
LDB = (LDB - 24) * sizeof(dt_int16); LDB = (LDB - 24) * sizeof(dt_int16);
@@ -349,9 +350,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
} }


// Overview of register layout: // Overview of register layout:
@@ -382,8 +383,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
// | v7[0-7]| |v30[0-3]|v31[0-3]| // | v7[0-7]| |v30[0-3]|v31[0-3]|
// +--------+ +--------+--------+ // +--------+ +--------+--------+
// Accumulator // Accumulator
void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
void kern_8x8(
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48
//! here. //! here.
LDB = (LDB - 48) * sizeof(dt_int16); LDB = (LDB - 48) * sizeof(dt_int16);
@@ -693,20 +695,20 @@ void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB) [output] "+r"(output), [LDB] "+r"(LDB)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"cc", "memory");
} }


} // anonymous namespace } // anonymous namespace


MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8);


void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
size_t LDB, dt_int32* C, size_t LDC, size_t M,
size_t K, size_t N, const dt_int32*, void*,
bool trA, bool trB) const {
void gemm_nopack_s16_8x8::kern(
const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C,
size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA,
bool trB) const {
constexpr static size_t MB = 8; constexpr static size_t MB = 8;
constexpr static size_t KB = 8; constexpr static size_t KB = 8;
constexpr static size_t NB = 8; constexpr static size_t NB = 8;


+ 121
- 103
dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h View File

@@ -36,9 +36,9 @@ namespace matmul_s4_4x4x16 {
* Accumulator * Accumulator
*/ */


static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void s4_kern_8x8_remain(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8; K /= 8;
LDC = LDC * sizeof(int16_t); LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
@@ -170,7 +170,7 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
"dup v5.8b,v20.b[5]\n" "dup v5.8b,v20.b[5]\n"
"dup v6.8b,v20.b[6]\n" "dup v6.8b,v20.b[6]\n"
"dup v7.8b,v20.b[7]\n" "dup v7.8b,v20.b[7]\n"
"ld1 {v17.8b}, [%[b_ptr]], 8\n" "ld1 {v17.8b}, [%[b_ptr]], 8\n"


"dup v8.8b,v20.b[8]\n" "dup v8.8b,v20.b[8]\n"
@@ -318,16 +318,16 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
STORE_C STORE_C


: :
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
[a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain),
[n_remain] "+r"(
n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
: :
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0",
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -335,14 +335,14 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C #undef STORE_C
} }


static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void s4_kern_8x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8; K /= 8;
LDC = LDC * sizeof(int16_t); LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
// clang-format off
// clang-format off


#define LOAD_C_8 \ #define LOAD_C_8 \
"ld1 {v24.8h}, [x0], #16\n" \ "ld1 {v24.8h}, [x0], #16\n" \
@@ -363,9 +363,9 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"st1 {v28.8h}, [x4], #16\n" \ "st1 {v28.8h}, [x4], #16\n" \
"st1 {v29.8h}, [x5], #16\n" \ "st1 {v29.8h}, [x5], #16\n" \
"st1 {v30.8h}, [x6], #16\n" \ "st1 {v30.8h}, [x6], #16\n" \
"st1 {v31.8h}, [x7], #16\n" \
"st1 {v31.8h}, [x7], #16\n"


// clang-format on
// clang-format on
register int16_t* outptr asm("x0") = output; register int16_t* outptr asm("x0") = output;
asm volatile( asm volatile(
"add x1, x0, %x[LDC]\n" "add x1, x0, %x[LDC]\n"
@@ -395,8 +395,8 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"PRFM PLDL1KEEP, [%[a_ptr], #512]\n" "PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
"PRFM PLDL1KEEP, [%[b_ptr], #512]\n" "PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
"1:\n" "1:\n"
// "ld1 {v20.16b}, [%[a_ptr]],#16\n"
// "ld1 {v21.16b}, [%[a_ptr]],#16\n"
// "ld1 {v20.16b}, [%[a_ptr]],#16\n"
// "ld1 {v21.16b}, [%[a_ptr]],#16\n"
"dup v0.8b,v20.b[0]\n" "dup v0.8b,v20.b[0]\n"
"ld1 {v22.16b}, [%[a_ptr]],#16\n" "ld1 {v22.16b}, [%[a_ptr]],#16\n"
"dup v1.8b,v20.b[1]\n" "dup v1.8b,v20.b[1]\n"
@@ -409,7 +409,6 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"dup v5.8b,v20.b[5]\n" "dup v5.8b,v20.b[5]\n"
"dup v6.8b,v20.b[6]\n" "dup v6.8b,v20.b[6]\n"
"dup v7.8b,v20.b[7]\n" "dup v7.8b,v20.b[7]\n"


"dup v8.8b,v20.b[8]\n" "dup v8.8b,v20.b[8]\n"
"smlal v24.8h, v0.8b, v16.8b\n" "smlal v24.8h, v0.8b, v16.8b\n"
@@ -560,26 +559,26 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
STORE_C_8 STORE_C_8


: :
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
[a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain),
[n_remain] "+r"(
n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
: :
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0",
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE
#undef STORE_C #undef STORE_C
} }
//packa
static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
// packa
static void gemm_s4x4x16_8x8x8_transpose_pack(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[8]; int8_t zerobuff[8];
int8_t tmpbuff0[8]; int8_t tmpbuff0[8];
int8_t tmpbuff1[8]; int8_t tmpbuff1[8];
@@ -617,22 +616,23 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
prefetch_2x(inptr5); prefetch_2x(inptr5);
prefetch_2x(inptr6); prefetch_2x(inptr6);
prefetch_2x(inptr7); prefetch_2x(inptr7);
int K = (kmax - k0)/2;
int K = (kmax - k0) / 2;
//! read 4 * 16 in each row //! read 4 * 16 in each row
for (; K > 3; K -= 4) { for (; K > 3; K -= 4) {
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_4x8_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }


if (K > 0) { if (K > 0) {
std::memcpy(tmpbuff0,inptr0,K);
std::memcpy(tmpbuff1,inptr1,K);
std::memcpy(tmpbuff2,inptr2,K);
std::memcpy(tmpbuff3,inptr3,K);
std::memcpy(tmpbuff4,inptr4,K);
std::memcpy(tmpbuff5,inptr5,K);
std::memcpy(tmpbuff6,inptr6,K);
std::memcpy(tmpbuff7,inptr7,K);
std::memcpy(tmpbuff0, inptr0, K);
std::memcpy(tmpbuff1, inptr1, K);
std::memcpy(tmpbuff2, inptr2, K);
std::memcpy(tmpbuff3, inptr3, K);
std::memcpy(tmpbuff4, inptr4, K);
std::memcpy(tmpbuff5, inptr5, K);
std::memcpy(tmpbuff6, inptr6, K);
std::memcpy(tmpbuff7, inptr7, K);
inptr0 = tmpbuff0; inptr0 = tmpbuff0;
inptr1 = tmpbuff1; inptr1 = tmpbuff1;
inptr2 = tmpbuff2; inptr2 = tmpbuff2;
@@ -641,8 +641,9 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
inptr5 = tmpbuff5; inptr5 = tmpbuff5;
inptr6 = tmpbuff6; inptr6 = tmpbuff6;
inptr7 = tmpbuff7; inptr7 = tmpbuff7;
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_4x8_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }
} }
for (; y < ymax; y += 8) { for (; y < ymax; y += 8) {
@@ -655,23 +656,29 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
const int8_t* inptr6 = inptr5 + ldin; const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin; const int8_t* inptr7 = inptr6 + ldin;


int K = (kmax - k0)/2;
int K = (kmax - k0) / 2;
//! read 4 * 16 in each row //! read 4 * 16 in each row
for (; K > 3; K -= 4) { for (; K > 3; K -= 4) {
if (y + 7 >= ymax) { if (y + 7 >= ymax) {
switch (y + 7 - ymax) { switch (y + 7 - ymax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -679,24 +686,31 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
megdnn_assert(0); megdnn_assert(0);
} }
} }
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_4x8_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }
if (K > 0) { if (K > 0) {
if (y + 7 >= ymax) { if (y + 7 >= ymax) {
switch (y + 7 - ymax) { switch (y + 7 - ymax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -705,14 +719,14 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
} }
} }


std::memcpy(tmpbuff0,inptr0,K);
std::memcpy(tmpbuff1,inptr1,K);
std::memcpy(tmpbuff2,inptr2,K);
std::memcpy(tmpbuff3,inptr3,K);
std::memcpy(tmpbuff4,inptr4,K);
std::memcpy(tmpbuff5,inptr5,K);
std::memcpy(tmpbuff6,inptr6,K);
std::memcpy(tmpbuff7,inptr7,K);
std::memcpy(tmpbuff0, inptr0, K);
std::memcpy(tmpbuff1, inptr1, K);
std::memcpy(tmpbuff2, inptr2, K);
std::memcpy(tmpbuff3, inptr3, K);
std::memcpy(tmpbuff4, inptr4, K);
std::memcpy(tmpbuff5, inptr5, K);
std::memcpy(tmpbuff6, inptr6, K);
std::memcpy(tmpbuff7, inptr7, K);
inptr0 = tmpbuff0; inptr0 = tmpbuff0;
inptr1 = tmpbuff1; inptr1 = tmpbuff1;
inptr2 = tmpbuff2; inptr2 = tmpbuff2;
@@ -721,14 +735,15 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
inptr5 = tmpbuff5; inptr5 = tmpbuff5;
inptr6 = tmpbuff6; inptr6 = tmpbuff6;
inptr7 = tmpbuff7; inptr7 = tmpbuff7;
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_4x8_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }
} }
} }
//packb
static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
// packb
static void gemm_s4x4x16_8x8x8_interleave_pack(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[8]; int8_t zerobuff[8];
int8_t tmpbuff0[8]; int8_t tmpbuff0[8];
int8_t tmpbuff1[8]; int8_t tmpbuff1[8];
@@ -748,7 +763,7 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); std::memset(tmpbuff6, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); std::memset(tmpbuff7, 0, sizeof(int8_t) * 8);
const int ksize = kmax - k0; const int ksize = kmax - k0;
const int ksize8 = round_up(ksize, 8) * 8; //pack to int8 *8 packto s4 *4
const int ksize8 = round_up(ksize, 8) * 8; // pack to int8 *8 packto s4 *4
int8_t* outptr = out; int8_t* outptr = out;
int8_t* outptr_interleave = nullptr; int8_t* outptr_interleave = nullptr;


@@ -776,21 +791,22 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
int8_t* outptr_inner = outptr; int8_t* outptr_inner = outptr;
for (; x + 3 < xmax; x += 4) { for (; x + 3 < xmax; x += 4) {
outptr_interleave = outptr_inner; outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x4_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr_inner += ksize8; outptr_inner += ksize8;
} }


if (x < xmax) { if (x < xmax) {
int remainx = xmax - x; int remainx = xmax - x;
std::memcpy(tmpbuff0,inptr0,remainx);
std::memcpy(tmpbuff1,inptr1,remainx);
std::memcpy(tmpbuff2,inptr2,remainx);
std::memcpy(tmpbuff3,inptr3,remainx);
std::memcpy(tmpbuff4,inptr4,remainx);
std::memcpy(tmpbuff5,inptr5,remainx);
std::memcpy(tmpbuff6,inptr6,remainx);
std::memcpy(tmpbuff7,inptr7,remainx);
std::memcpy(tmpbuff0, inptr0, remainx);
std::memcpy(tmpbuff1, inptr1, remainx);
std::memcpy(tmpbuff2, inptr2, remainx);
std::memcpy(tmpbuff3, inptr3, remainx);
std::memcpy(tmpbuff4, inptr4, remainx);
std::memcpy(tmpbuff5, inptr5, remainx);
std::memcpy(tmpbuff6, inptr6, remainx);
std::memcpy(tmpbuff7, inptr7, remainx);
inptr0 = tmpbuff0; inptr0 = tmpbuff0;
inptr1 = tmpbuff1; inptr1 = tmpbuff1;
inptr2 = tmpbuff2; inptr2 = tmpbuff2;
@@ -801,8 +817,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
inptr7 = tmpbuff7; inptr7 = tmpbuff7;


outptr_interleave = outptr_inner; outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x4_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr_inner += ksize8; outptr_inner += ksize8;
} }
outptr += 64; outptr += 64;
@@ -847,8 +864,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
break; break;
} }
outptr_interleave = outptr_inner; outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x4_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr_inner += ksize8; outptr_inner += ksize8;
} }
if (x < xmax) { if (x < xmax) {
@@ -880,14 +898,14 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
} }
int remainx = xmax - x; int remainx = xmax - x;
outptr_interleave = outptr_inner; outptr_interleave = outptr_inner;
std::memcpy(tmpbuff0,inptr0,remainx);
std::memcpy(tmpbuff1,inptr1,remainx);
std::memcpy(tmpbuff2,inptr2,remainx);
std::memcpy(tmpbuff3,inptr3,remainx);
std::memcpy(tmpbuff4,inptr4,remainx);
std::memcpy(tmpbuff5,inptr5,remainx);
std::memcpy(tmpbuff6,inptr6,remainx);
std::memcpy(tmpbuff7,inptr7,remainx);
std::memcpy(tmpbuff0, inptr0, remainx);
std::memcpy(tmpbuff1, inptr1, remainx);
std::memcpy(tmpbuff2, inptr2, remainx);
std::memcpy(tmpbuff3, inptr3, remainx);
std::memcpy(tmpbuff4, inptr4, remainx);
std::memcpy(tmpbuff5, inptr5, remainx);
std::memcpy(tmpbuff6, inptr6, remainx);
std::memcpy(tmpbuff7, inptr7, remainx);
inptr0 = tmpbuff0; inptr0 = tmpbuff0;
inptr1 = tmpbuff1; inptr1 = tmpbuff1;
inptr2 = tmpbuff2; inptr2 = tmpbuff2;
@@ -898,16 +916,16 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
inptr7 = tmpbuff7; inptr7 = tmpbuff7;


outptr_interleave = outptr_inner; outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x4_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr_inner += ksize8; outptr_inner += ksize8;
} }
} }
} }


} // namespace matmul_4x4x16
} // namespace matmul_s4_4x4x16
} // namespace aarch64 } // namespace aarch64
} // namespace megdnn } // namespace megdnn



// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 33
- 34
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp View File

@@ -10,9 +10,9 @@
* implied. * implied.
*/ */


#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" #include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h"
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h" #include "src/fallback/matrix_mul/gemm_common.h"
@@ -23,39 +23,38 @@ using namespace aarch64::matmul;


// ===========================gemm_s4x4x16_s4_8x8x8================================== // ===========================gemm_s4x4x16_s4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8);
void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s4x4x16_s4_8x8x8::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0,
kmax);
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(
out, in, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0,
kmax);
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(
out, in, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
void gemm_s4x4x16_s4_8x8x8::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0,
kmax);
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(
out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0,
kmax);
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(
out, in, ldin, x0, xmax, k0, kmax);
} }
} }


void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::QuantizedS4 &&
C_dtype.enumv() == DTypeEnum::QuantizedS16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s4x4x16_s4_8x8x8::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::QuantizedS4 &&
C_dtype.enumv() == DTypeEnum::QuantizedS16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
@@ -72,16 +71,17 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE, B_INTERLEAVE);
matmul_s4_4x4x16::s4_kern_8x8(
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
B_INTERLEAVE);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += B_INTERLEAVE) { for (; n < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
matmul_s4_4x4x16::s4_kern_8x8_remain(
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }
@@ -94,10 +94,10 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) { for (; n < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
matmul_s4_4x4x16::s4_kern_8x8_remain(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }
@@ -105,5 +105,4 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
} }
} }



// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 2
- 2
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h View File

@@ -17,8 +17,8 @@ namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul { namespace matmul {


MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true,
gemm_s4x4x16_s4_8x8x8);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s4x4x16_s4_8x8x8);


} // namespace matmul } // namespace matmul
} // namespace aarch64 } // namespace aarch64


+ 57
- 39
dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h View File

@@ -51,8 +51,9 @@ namespace matmul_4x4x16 {
* Accumulator * Accumulator
*/ */


static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 16; K /= 16;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -472,9 +473,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
); );
} }


static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
int m_remain, int n_remain) {
static void kern_4x4_remain(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
megdnn_assert(K > 0); megdnn_assert(K > 0);
K /= 16; K /= 16;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
@@ -655,16 +656,14 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,


STORE_C STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[output] "+r"(output), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output),
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -672,8 +671,9 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C #undef STORE_C
} }


static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0, int kmax) {
static void gemm_s8_4x4_pack_A_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);


@@ -716,9 +716,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -734,9 +736,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -749,8 +753,8 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
} }
} }


static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8_4x4_pack_B_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0; const int ksize = kmax - k0;
@@ -777,19 +781,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (remain >= 0) { if (remain >= 0) {
switch (remain) { switch (remain) {
case 7: case 7:
inptr0 = zerobuff; MEGDNN_FALLTHRU
inptr0 = zerobuff;
MEGDNN_FALLTHRU
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -798,9 +809,9 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
} }
} }


transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3,
inptr4, inptr5, inptr6, inptr7,
outptr_inner);
transpose_4x16_1_b_helper(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_inner);
outptr_inner += ksize4; outptr_inner += ksize4;
} }


@@ -808,19 +819,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (remain >= 0) { if (remain >= 0) {
switch (remain) { switch (remain) {
case 7: case 7:
inptr0 = zerobuff; MEGDNN_FALLTHRU
inptr0 = zerobuff;
MEGDNN_FALLTHRU
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;


+ 163
- 113
dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h View File

@@ -42,8 +42,9 @@ namespace matmul_8x8x8 {
* Accumulator * Accumulator
*/ */


static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x8(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 8; K /= 8;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -272,14 +273,13 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"stp q18, q19, [x5]\n" "stp q18, q19, [x5]\n"
"stp q20, q21, [x6]\n" "stp q20, q21, [x6]\n"
"stp q22, q23, [x7]\n" "stp q22, q23, [x7]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[output] "+r"(output)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v26", "v27", "x1",
"x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v26", "v27", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"cc", "memory");
} }


/** /**
@@ -309,9 +309,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
size_t n_remain) {
static void kern_8x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t n_remain) {
K /= 8; K /= 8;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -520,16 +520,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -559,9 +557,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
size_t m_remain) {
static void kern_4x8(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t m_remain) {
K /= 8; K /= 8;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -724,14 +722,13 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
[outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [m_remain] "+r"(m_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -762,9 +759,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, size_t m_remain,
size_t n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t m_remain, size_t n_remain) {
K /= 8; K /= 8;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -922,11 +919,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
[x1] "+r"(x1), [m_remain] "+r"(m_remain),
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [x1] "+r"(x1), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain) [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc",
@@ -938,8 +934,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C #undef STORE_C
} }


static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
int y0, int ymax, int k0, int kmax) {
static void gemm_s8_8x8_pack_A_n(
int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);


@@ -965,13 +962,15 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,


int K = kmax - k0; int K = kmax - k0;
for (; K > 15; K -= 16) { for (; K > 15; K -= 16) {
interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x8_2_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }


if (K > 0) { if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 8, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 8, K);
} }
} }


@@ -991,9 +990,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -1009,9 +1010,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -1024,9 +1027,8 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
} }
} }


static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
int ldin, int x0, int xmax, int k0,
int kmax) {
static void gemm_s8_8x8_transpose_pack_A_n(
int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0; const int ksize = kmax - k0;
@@ -1063,17 +1065,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1081,8 +1089,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
megdnn_assert(0); megdnn_assert(0);
} }
} }
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
transpose_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += ksize8; outptr += ksize8;
} }


@@ -1091,17 +1100,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1110,8 +1125,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
} }
} }


transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, 4);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, 4);
outptr += ksize4; outptr += ksize4;
} }


@@ -1119,17 +1135,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1138,8 +1160,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
} }
} }


transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, xmax - x);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, xmax - x);
} }


outptr_base += 8 * 8; outptr_base += 8 * 8;
@@ -1147,8 +1170,8 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
} }
} }


static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8_8x8_pack_B_n(
int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0; const int ksize = kmax - k0;
@@ -1186,17 +1209,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1205,8 +1234,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
} }
} }
outptr_interleave = outptr; outptr_interleave = outptr;
interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr += ksize8; outptr += ksize8;
} }


@@ -1215,17 +1245,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1235,8 +1271,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
} }


outptr_interleave = outptr; outptr_interleave = outptr;
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr_interleave, 4, 4);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave, 4, 4);
outptr += ksize4; outptr += ksize4;
} }


@@ -1244,17 +1281,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1264,8 +1307,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
} }


outptr_interleave = outptr; outptr_interleave = outptr;
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr_interleave, 4, xmax - x);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave, 4, xmax - x);
} }


outptr_base += 8 * 8; outptr_base += 8 * 8;
@@ -1273,9 +1317,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
} }
} }


static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8_8x8_transpose_pack_B_n(
int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
constexpr int interleave4 = 32; constexpr int interleave4 = 32;
@@ -1303,14 +1347,16 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,


int K = kmax - k0; int K = kmax - k0;
for (; K > 7; K -= 8) { for (; K > 7; K -= 8) {
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
transpose_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += interleave8; outptr += interleave8;
} }


if (K > 0) { if (K > 0) {
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 8, K);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 8, K);
outptr += interleave8; outptr += interleave8;
} }
} }
@@ -1331,9 +1377,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -1350,9 +1398,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;


+ 47
- 41
dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h View File

@@ -50,8 +50,9 @@ namespace matmul_mk4_4x4x16 {
* Accumulator * Accumulator
*/ */


static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, bool is_first_k) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output,
bool is_first_k) {
K = div_ceil(K, 16); K = div_ceil(K, 16);
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -366,17 +367,18 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"6:\n" "6:\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n"


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[k] "+r"(K), [output] "+r"(output)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"cc", "memory");
} }


static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, bool is_first_k, size_t remain_n) {
static void kern_4x4_remain(
const int8_t* packA, const int8_t* packB, int K, int32_t* output,
bool is_first_k, size_t remain_n) {
K = div_ceil(K, 16); K = div_ceil(K, 16);
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -718,26 +720,27 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,


"7:\n" "7:\n"


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k),
[k] "+r"(K), [output] "+r"(output)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [remain_n] "+r"(remain_n),
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"cc", "memory");
} }


static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_mk4_s8_4x4_pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
//! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)}
int8_t zerobuff[4][64]; int8_t zerobuff[4][64];
std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4);
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0,
"mk4 matmul with m is not times of 4");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0,
"mk4 matmul with k is not times of 4");
megdnn_assert(
ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0,
"mk4 matmul with m is not times of 4");
megdnn_assert(
kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0,
"mk4 matmul with k is not times of 4");
size_t roundk = round_up(kmax - k0, 16); size_t roundk = round_up(kmax - k0, 16);
size_t out_offset = roundk * 4; size_t out_offset = roundk * 4;
int y = y0; int y = y0;
@@ -754,8 +757,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
prefetch_2x(inptr3); prefetch_2x(inptr3);
int K = kmax - k0; int K = kmax - k0;
for (; K > 15; K -= 16) { for (; K > 15; K -= 16) {
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
out_offset);
transpose_interleave_4x4_4_b(
inptr0, inptr1, inptr2, inptr3, output, out_offset);
output += 64; output += 64;
} }
if (K > 0) { if (K > 0) {
@@ -767,8 +770,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
inptr1 = zerobuff[1]; inptr1 = zerobuff[1];
inptr2 = zerobuff[2]; inptr2 = zerobuff[2];
inptr3 = zerobuff[3]; inptr3 = zerobuff[3];
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
out_offset);
transpose_interleave_4x4_4_b(
inptr0, inptr1, inptr2, inptr3, output, out_offset);
output += 64; output += 64;
} }
} }
@@ -790,21 +793,21 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
} }
} }


static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_mk4_s8_4x4_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int32_t zerobuff[4]; int32_t zerobuff[4];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0; const int ksize = kmax - k0;
const int ICB = (ksize) / 4; const int ICB = (ksize) / 4;
const int ksize4 = round_up<int>(ICB, 4) * 4; const int ksize4 = round_up<int>(ICB, 4) * 4;
int32_t* outptr = reinterpret_cast<int32_t*>(out); int32_t* outptr = reinterpret_cast<int32_t*>(out);
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0,
"mk4 matmul with k is not times of 4");
megdnn_assert(
kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0,
"mk4 matmul with k is not times of 4");


int k = k0 / 4; int k = k0 / 4;
for (; k + 3 < ICB; k += 4) { for (; k + 3 < ICB; k += 4) {
const int32_t* inptr0 =
reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr1 = const int32_t* inptr1 =
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
const int32_t* inptr2 = const int32_t* inptr2 =
@@ -829,8 +832,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
outptr += 4 * 4; outptr += 4 * 4;
} }
if (k < ICB) { if (k < ICB) {
const int32_t* inptr0 =
reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr1 = const int32_t* inptr1 =
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
const int32_t* inptr2 = const int32_t* inptr2 =
@@ -844,9 +846,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
if (k + 3 >= ICB) { if (k + 3 >= ICB) {
switch (k + 3 - ICB) { switch (k + 3 - ICB) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -861,9 +865,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
if (k + 3 >= ICB) { if (k + 3 >= ICB) {
switch (k + 3 - ICB) { switch (k + 3 - ICB) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -882,7 +888,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
} }
} }


} // namespace matmul_4x4x16
} // namespace matmul_mk4_4x4x16
} // namespace aarch64 } // namespace aarch64
} // namespace megdnn } // namespace megdnn




+ 77
- 82
dnn/src/aarch64/matrix_mul/int8/strategy.cpp View File

@@ -24,20 +24,19 @@ using namespace aarch64::matmul;
///////////////////////// gemm_s8_4x4 //////////////////////////////////// ///////////////////////// gemm_s8_4x4 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4);


void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8_4x4::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) { if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
void gemm_s8_4x4::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
@@ -45,16 +44,16 @@ void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
} }
} }


void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8_4x4::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
@@ -72,16 +71,15 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K4; cur_packB += K4;
} }


for (; n < N; n += B_INTERLEAVE) { for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC,
is_first_k, 4,
std::min<size_t>(N - n, 4));
matmul_4x4x16::kern_4x4_remain(
packA, cur_packB, K, output, LDC, is_first_k, 4,
std::min<size_t>(N - n, 4));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K4; cur_packB += K4;
} }
@@ -107,33 +105,32 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// ///////////////////////// gemm_mk4_s8_4x4 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4);


void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose,
"the gemm_mk4_s8_4x4 strategy is not support transpose A");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0,
kmax);
void gemm_mk4_s8_4x4::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
megdnn_assert(
!transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose A");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax);
} }


void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
megdnn_assert(!transpose,
"the gemm_mk4_s8_4x4 strategy is not support transpose B");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0,
kmax);
void gemm_mk4_s8_4x4::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(
!transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose B");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, kmax);
} }


void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_mk4_s8_4x4::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
@@ -151,57 +148,54 @@ void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output,
is_first_k);
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, is_first_k);
output += B_INTERLEAVE * 4; output += B_INTERLEAVE * 4;
cur_packB += K4; cur_packB += K4;
} }


if (n < N) { if (n < N) {
matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output,
is_first_k, N - n);
matmul_mk4_4x4x16::kern_4x4_remain(
packA, cur_packB, K, output, is_first_k, N - n);
} }


packA += K4; packA += K4;
} }
} }



///////////////////////// gemm_s8_8x8 //////////////////////////////////// ///////////////////////// gemm_s8_8x8 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8);


void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8_8x8::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
ymax, k0, kmax);
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(
outptr, inptr, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
void gemm_s8_8x8::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
k0, kmax);
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
} }
} }


void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8_8x8::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
@@ -220,15 +214,15 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
matmul_8x8x8::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -240,16 +234,17 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
matmul_8x8x8::kern_4x8(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_8x8x8::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }


+ 6
- 6
dnn/src/aarch64/matrix_mul/int8/strategy.h View File

@@ -16,14 +16,14 @@ namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul { namespace matmul {


MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true,
gemm_s8_4x4);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4);


MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false,
gemm_mk4_s8_4x4);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, gemm_mk4_s8_4x4);


MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true,
gemm_s8_8x8);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, gemm_s8_8x8);


} // namespace matmul } // namespace matmul
} // namespace aarch64 } // namespace aarch64


+ 62
- 59
dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h View File

@@ -52,8 +52,9 @@ namespace matmul_8x12x4 {


#if 1 #if 1
MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 4; K /= 4;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -410,8 +411,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
} }
#else #else
MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 4; K /= 4;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -612,18 +614,17 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
"stp q15, q23, [%[outptr7]]\n" "stp q15, q23, [%[outptr7]]\n"
"str q31, [%[outptr7], #32]\n" "str q31, [%[outptr7], #32]\n"


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0),
[a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0),
[b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC),
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1),
[a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1),
[b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7)
: :
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
} }


#endif #endif
@@ -653,8 +654,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
// //
// Accumulator // Accumulator
MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int m_remain) {
static void kern_4x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int m_remain) {
K /= 4; K /= 4;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -796,15 +798,15 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,


"4:\n" STORE_C "4:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k),
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain),
[LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0),
[b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a),
[b2a] "=w"(b2a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
[outptr3] "=r"(outptr3), [x0] "=r"(x0)
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [a0] "=w"(a0),
[a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2),
[b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "=r"(x0)
: :
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "memory", "cc");
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "memory", "cc");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -840,8 +842,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
// //
// Accumulator // Accumulator
MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int n_remain) {
static void kern_8x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int n_remain) {
K /= 4; K /= 4;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -1004,12 +1007,11 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0),
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a),
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "=r"(x0)
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4),
[outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7),
[x0] "=r"(x0)
: :
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory",
"cc");
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -1041,9 +1043,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
// //
// Accumulator // Accumulator
MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 4; K /= 4;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
@@ -1172,10 +1174,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"4:\n" STORE_C "4:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0),
[a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0),
[b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
[outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1)
: :
: "v4", "v5", "v6", "v7", "memory", "cc"); : "v4", "v5", "v6", "v7", "memory", "cc");
@@ -1186,9 +1187,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C #undef STORE_C
} }


static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8_8x12_pack_A_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);


@@ -1215,13 +1216,15 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int K = kmax - k0; int K = kmax - k0;
//! read 8 * 4 in each row //! read 8 * 4 in each row
for (; K > 15; K -= 16) { for (; K > 15; K -= 16) {
interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x4_4_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }


if (K > 0) { if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, K);
} }
} }
for (; y < ymax; y += 4) { for (; y < ymax; y += 4) {
@@ -1274,8 +1277,8 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
} }
} }


static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8_8x12_pack_A_t(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0; const int ksize = kmax - k0;
@@ -1361,8 +1364,8 @@ static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin,
} }
} }


static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8_8x12_pack_B_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0; const int ksize = kmax - k0;
@@ -1448,9 +1451,9 @@ static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
} }
} }


static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8_8x12_pack_B_t(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);


@@ -1485,15 +1488,15 @@ static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr,
int K = kmax - k0; int K = kmax - k0;
//! read 12 * 4 in each row //! read 12 * 4 in each row
for (; K > 15; K -= 16) { for (; K > 15; K -= 16) {
interleave_12x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, inptr8, inptr9, inptr10,
inptr11, outptr);
interleave_12x4_4_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr);
} }


if (K > 0) { if (K > 0) {
interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, inptr8, inptr9, inptr10, inptr11,
outptr, 4, K);
interleave_12(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr, 4, K);
} }
} }
for (; y < ymax; y += 4) { for (; y < ymax; y += 4) {


+ 30
- 32
dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h View File

@@ -40,8 +40,9 @@ namespace matmul_mk4_8x12x4 {
// Accumulator // Accumulator


MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 4; K /= 4;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -397,8 +398,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
// Accumulator // Accumulator


MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_4x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 4; K /= 4;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -514,13 +516,12 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
"stp q16, q17, [%[outptr0], #128]\n" "stp q16, q17, [%[outptr0], #128]\n"
"stp q18, q19, [%[outptr0], #160]\n" "stp q18, q19, [%[outptr0], #160]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k),
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a),
[b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a),
[b1a] "=w"(b1a), [b2a] "=w"(b2a)
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1),
[b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a)
: :
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "memory", "cc");
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "memory", "cc");
} }


// Overview of register layout: // Overview of register layout:
@@ -544,8 +545,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
// Accumulator // Accumulator


MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int n_remain) {
static void kern_8x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int n_remain) {
K /= 4; K /= 4;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -689,11 +691,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0),
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a),
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1),
[x0] "=r"(x0)
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [x0] "=r"(x0)
: :
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory",
"cc");
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -720,8 +720,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
// Accumulator // Accumulator


MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int n_remain) {
K /= 4; K /= 4;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
@@ -834,10 +835,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,


"4:\n" STORE_C "4:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
[LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k),
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a),
[x0] "=r"(x0)
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a),
[b0] "=w"(b0), [b0a] "=w"(b0a), [x0] "=r"(x0)
: :
: "v4", "v5", "v6", "v7", "memory", "cc"); : "v4", "v5", "v6", "v7", "memory", "cc");


@@ -847,13 +847,11 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C #undef STORE_C
} }


static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0,
"mk4 matmul with m is not times of 4");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
"mk4 matmul with k is not times of 4");
static void gemm_mk4_s8_8x12_pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, "mk4 matmul with m is not times of 4");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 matmul with k is not times of 4");
int y = y0; int y = y0;
int start_y = y0 / 4; int start_y = y0 / 4;
for (; y + 7 < ymax; y += 8, start_y += 2) { for (; y + 7 < ymax; y += 8, start_y += 2) {
@@ -869,15 +867,15 @@ static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr,
interleave_2x4_4_b(inptr0, inptr1, outptr); interleave_2x4_4_b(inptr0, inptr1, outptr);
} }
} }
for (; y + 3 < ymax; y += 4, start_y ++) {
for (; y + 3 < ymax; y += 4, start_y++) {
int K = kmax - k0; int K = kmax - k0;
const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2);
std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4);
} }
} }


static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_mk4_s8_8x12_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
const int ksize = kmax - k0; const int ksize = kmax - k0;
const int ksize12 = ksize * 12; const int ksize12 = ksize * 12;
const int ksize4 = ksize * 4; const int ksize4 = ksize * 4;


+ 59
- 59
dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp View File

@@ -12,10 +12,10 @@
#include "src/aarch64/matrix_mul/int8_dot/strategy.h" #include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#if MGB_ENABLE_DOT #if MGB_ENABLE_DOT
#include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"


using namespace megdnn; using namespace megdnn;
using namespace aarch64; using namespace aarch64;
@@ -24,20 +24,19 @@ using namespace aarch64::matmul;
/* ====================== gemm_s8_8x12 ===========================*/ /* ====================== gemm_s8_8x12 ===========================*/
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12);


void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8_8x12::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
void gemm_s8_8x12::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
@@ -45,16 +44,16 @@ void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
} }
} }


void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8_8x12::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());


MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
@@ -75,15 +74,15 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K12; cur_packB += K12;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_8x12x4::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -95,16 +94,17 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4));
matmul_8x12x4::kern_4x12(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K12; cur_packB += K12;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_8x12x4::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -115,32 +115,32 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
/* ====================== gemm_mk4_s8_8x12 ===========================*/ /* ====================== gemm_mk4_s8_8x12 ===========================*/
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12);


void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0,
kmax);
void gemm_mk4_s8_8x12::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
megdnn_assert(
!transpose, "matrix mul mk4 with transposed matrix A is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax);
} }


void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported");
void gemm_mk4_s8_8x12::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(
!transpose, "matrix mul mk4 with transposed matrix B is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
} }


void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_mk4_s8_8x12::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());


MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
@@ -161,15 +161,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
output += (B_INTERLEAVE << 2); output += (B_INTERLEAVE << 2);
cur_packB += K12; cur_packB += K12;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_mk4_8x12x4::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 16; output += 16;
cur_packB += K4; cur_packB += K4;
} }
@@ -181,15 +181,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k);
output += (B_INTERLEAVE << 2); output += (B_INTERLEAVE << 2);
cur_packB += K12; cur_packB += K12;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_mk4_8x12x4::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 16; output += 16;
cur_packB += K4; cur_packB += K4;
} }


+ 5
- 5
dnn/src/aarch64/matrix_mul/int8_dot/strategy.h View File

@@ -16,14 +16,14 @@ namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul { namespace matmul {


MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_s8_8x12);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12);


MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_mk4_s8_8x12);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_mk4_s8_8x12);


} // namespace aarch64
} // namespace matmul } // namespace matmul
} // namespace aarch64
} // namespace megdnn } // namespace megdnn
#endif #endif
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 54
- 38
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h View File

@@ -34,9 +34,9 @@ namespace matmul_4x4x16 {
* *
* Accumulator * Accumulator
*/ */
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 16; K /= 16;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -230,16 +230,14 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
// Store back into memory // Store back into memory
STORE_C STORE_C


: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain)
: :
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2",
"v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31");
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", "v3",
"v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29", "v30", "v31");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -247,9 +245,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C #undef STORE_C
} }


static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8x8x16_4x4_pack_A_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);


@@ -292,9 +290,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -309,9 +309,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -324,8 +326,8 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
} }
} }


static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8x8x16_4x4_pack_B_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0; const int ksize = kmax - k0;
@@ -362,19 +364,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (remain >= 0) { if (remain >= 0) {
switch (remain) { switch (remain) {
case 7: case 7:
inptr0 = zerobuff; MEGDNN_FALLTHRU
inptr0 = zerobuff;
MEGDNN_FALLTHRU
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -383,9 +392,9 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
} }
} }


transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3,
inptr4, inptr5, inptr6, inptr7,
outptr_inner);
transpose_4x16_1_b_helper(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_inner);
outptr_inner += ksize4; outptr_inner += ksize4;
} }


@@ -393,19 +402,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (remain >= 0) { if (remain >= 0) {
switch (remain) { switch (remain) {
case 7: case 7:
inptr0 = zerobuff; MEGDNN_FALLTHRU
inptr0 = zerobuff;
MEGDNN_FALLTHRU
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;


+ 159
- 111
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h View File

@@ -42,8 +42,9 @@ namespace matmul_8x8x8 {
* *
* Accumulator * Accumulator
*/ */
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k) {
static void kern_8x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k) {
K /= 8; K /= 8;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -217,13 +218,12 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"bne 2b\n" "bne 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3",
"x4", "x5", "x6", "x7", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "cc", "memory");
#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
#undef STORE_LINE #undef STORE_LINE
@@ -258,9 +258,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k,
size_t n_remain) {
static void kern_8x4(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, size_t n_remain) {
K /= 8; K /= 8;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -471,16 +471,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -514,9 +512,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
* Accumulator * Accumulator
*/ */


static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k,
size_t m_remain) {
static void kern_4x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, size_t m_remain) {
K /= 8; K /= 8;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -646,11 +644,10 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n" "cbnz %w[K], 2b\n"


"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
[m_remain] "+r"(m_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [m_remain] "+r"(m_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc",
"memory"); "memory");
@@ -686,9 +683,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
* *
* Accumulator * Accumulator
*/ */
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, size_t m_remain,
size_t n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, size_t m_remain, size_t n_remain) {
K /= 8; K /= 8;
const int8_t* a_ptr = packA; const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -853,11 +850,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"3:\n" STORE_C "3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr),
[K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC),
[x0] "+r"(x0), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
[x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1",
"cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", "cc",
"memory");


#undef LOAD_LINE #undef LOAD_LINE
#undef LOAD_C #undef LOAD_C
@@ -865,9 +861,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C #undef STORE_C
} }


static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8x8x16_8x8_pack_A_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);


@@ -893,13 +889,15 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,


int K = kmax - k0; int K = kmax - k0;
for (; K > 15; K -= 16) { for (; K > 15; K -= 16) {
interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x8_2_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
} }


if (K > 0) { if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 8, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 8, K);
} }
} }
for (; y < ymax; y += 4) { for (; y < ymax; y += 4) {
@@ -918,9 +916,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -936,9 +936,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -951,9 +953,8 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
} }
} }


static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax,
int k0, int kmax) {
static void gemm_s8x8x16_8x8_transpose_pack_A_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);


@@ -991,17 +992,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1009,8 +1016,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
megdnn_assert(0); megdnn_assert(0);
} }
} }
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
transpose_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += ksize8; outptr += ksize8;
} }


@@ -1019,17 +1027,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1038,8 +1052,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
} }
} }


transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, 4);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, 4);
outptr += ksize4; outptr += ksize4;
} }


@@ -1047,17 +1062,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1066,8 +1087,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
} }
} }


transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, xmax - x);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, xmax - x);
} }


outptr_base += 8 * 8; outptr_base += 8 * 8;
@@ -1075,8 +1097,8 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
} }
} }


static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8x8x16_8x8_pack_B_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0; const int ksize = kmax - k0;
@@ -1113,17 +1135,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1132,8 +1160,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
} }
} }
outptr_interleave = outptr; outptr_interleave = outptr;
interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr += ksize8; outptr += ksize8;
} }


@@ -1142,17 +1171,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1162,8 +1197,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
} }


outptr_interleave = outptr; outptr_interleave = outptr;
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr_interleave, 4, 4);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave, 4, 4);
outptr += ksize4; outptr += ksize4;
} }


@@ -1171,17 +1207,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (k + 7 >= kmax) { if (k + 7 >= kmax) {
switch (k + 7 - kmax) { switch (k + 7 - kmax) {
case 6: case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5: case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4: case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3: case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2: case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr7 = zerobuff; inptr7 = zerobuff;
break; break;
@@ -1191,8 +1233,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
} }


outptr_interleave = outptr; outptr_interleave = outptr;
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr_interleave, 4, xmax - x);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave, 4, xmax - x);
} }


outptr_base += 8 * 8; outptr_base += 8 * 8;
@@ -1200,10 +1243,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
} }
} }


static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr,
const dt_int8* inptr, int ldin,
int y0, int ymax, int k0,
int kmax) {
static void gemm_s8x8x16_8x8_transpose_pack_B_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16]; int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16); std::memset(zerobuff, 0, sizeof(int8_t) * 16);
constexpr int interleave4 = 32; constexpr int interleave4 = 32;
@@ -1231,14 +1273,16 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr,


int K = kmax - k0; int K = kmax - k0;
for (; K > 7; K -= 8) { for (; K > 7; K -= 8) {
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
transpose_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += interleave8; outptr += interleave8;
} }


if (K > 0) { if (K > 0) {
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 8, K);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 8, K);
outptr += interleave8; outptr += interleave8;
} }
} }
@@ -1259,9 +1303,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;
@@ -1278,9 +1324,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr,
if (y + 3 >= ymax) { if (y + 3 >= ymax) {
switch (y + 3 - ymax) { switch (y + 3 - ymax) {
case 2: case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1: case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0: case 0:
inptr3 = zerobuff; inptr3 = zerobuff;
break; break;


+ 34
- 41
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h View File

@@ -40,11 +40,9 @@ namespace matmul_mk4_16x12x4_a53 {
* Accumulator * Accumulator
*/ */
// clang-format on // clang-format on
static __attribute__((noinline)) void kern_16x12(const int16_t* packA,
const int8_t* packB, int K,
int16_t* output, int LDC,
bool is_first_k,
int remain_n) {
static __attribute__((noinline)) void kern_16x12(
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int remain_n) {
K /= 4; K /= 4;
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -521,15 +519,15 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA,
"6:\n" STORE_C "6:\n" STORE_C


"101:\n" "101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc",
"memory");


#undef STORE_C #undef STORE_C
#undef STORE_LINE #undef STORE_LINE
@@ -554,10 +552,9 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA,
* Accumulator * Accumulator
*/ */
// clang-format on // clang-format on
static __attribute__((noinline)) void kern_8x12(const int16_t* packA,
const int8_t* packB, int K,
int16_t* output, int LDC,
bool is_first_k, int remain_n) {
static __attribute__((noinline)) void kern_8x12(
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int remain_n) {
K /= 4; K /= 4;
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -858,14 +855,13 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA,
"6:\n" STORE_C "6:\n" STORE_C


"101:\n" "101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3",
"x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory");


#undef STORE_C #undef STORE_C
#undef STORE_LINE #undef STORE_LINE
@@ -890,10 +886,9 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA,
* Accumulator * Accumulator
*/ */
// clang-format on // clang-format on
static __attribute__((noinline)) void kern_4x12(const int16_t* packA,
const int8_t* packB, int K,
int16_t* output, int LDC,
bool is_first_k, int remain_n) {
static __attribute__((noinline)) void kern_4x12(
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int remain_n) {
K /= 4; K /= 4;
const int16_t* a_ptr = packA; const int16_t* a_ptr = packA;
const int8_t* b_ptr = packB; const int8_t* b_ptr = packB;
@@ -1162,22 +1157,21 @@ static __attribute__((noinline)) void kern_4x12(const int16_t* packA,
"6:\n" STORE_C "6:\n" STORE_C


"101:\n" "101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3",
"x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory");


#undef STORE_C #undef STORE_C
#undef STORE_LINE #undef STORE_LINE
} }


static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr,
const dt_int8* inptr, int ldin,
int m0, int mmax, int k0, int kmax) {
static void gemm_s8x8x16_mk4_16x12_pack_A(
dt_int16* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0,
int kmax) {
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
constexpr int pack_m = 16; constexpr int pack_m = 16;
@@ -1224,9 +1218,8 @@ static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr,
} }
} }


static void gemm_s8x8x16_mk4_16x12_pack_B(dt_int8* out, const dt_int8* in,
int ldin, int n0, int nmax, int k0,
int kmax) {
static void gemm_s8x8x16_mk4_16x12_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");


constexpr int pack_n = 12; constexpr int pack_n = 12;


+ 18
- 22
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h View File

@@ -43,8 +43,9 @@ namespace matmul_mk4_4x4x8_a72 {
*/ */


// clang-format on // clang-format on
static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool, int remain_n) {
static inline void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, bool,
int remain_n) {
K = div_ceil(K, 8); K = div_ceil(K, 8);
int oddk = (K & 1); int oddk = (K & 1);
K = ((K + 1) / 2) - 1; K = ((K + 1) / 2) - 1;
@@ -261,15 +262,14 @@ static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"7:\n" STORE_C "7:\n" STORE_C


"101:\n" "101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [oddk] "+r"(oddk),
[LDC] "+r"(LDC), [outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
: :
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc",
"memory");


#undef STORE_C #undef STORE_C
#undef STORE_LINE #undef STORE_LINE
@@ -282,26 +282,23 @@ static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) {
vst1_s8(outptr + 3 * 8, in0.val[3]); vst1_s8(outptr + 3 * 8, in0.val[3]);
} }


static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2,
dt_int8* outptr) {
static inline void interleve_8x4_b(
const dt_int8* inptr, const dt_int8* inptr2, dt_int8* outptr) {
int8x16_t in0 = vld1q_s8(inptr); int8x16_t in0 = vld1q_s8(inptr);
int8x16_t in1 = vld1q_s8(inptr2); int8x16_t in1 = vld1q_s8(inptr2);
int32x4x2_t in_x2 = {
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
} }


static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) {
int8x16_t in0 = vld1q_s8(inptr); int8x16_t in0 = vld1q_s8(inptr);
int8x16_t in1 = vdupq_n_s8(0); int8x16_t in1 = vdupq_n_s8(0);
int32x4x2_t in_x2 = {
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
} }


static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in,
int ldin, int m0, int mmax, int k0,
int kmax) {
static void gemm_s8x8x16_mk4_4x4x8_pack_A(
dt_int8* out, const dt_int8* in, int ldin, int m0, int mmax, int k0, int kmax) {
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
constexpr int pack_m = 4; constexpr int pack_m = 4;
@@ -330,9 +327,8 @@ static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in,
} }
} }


static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in,
int ldin, int n0, int nmax, int k0,
int kmax) {
static void gemm_s8x8x16_mk4_4x4x8_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");


constexpr int pack_n = 4; constexpr int pack_n = 4;


+ 48
- 55
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h View File

@@ -18,7 +18,6 @@ namespace megdnn {
namespace aarch64 { namespace aarch64 {
namespace matmul_mk4_8x8x8 { namespace matmul_mk4_8x8x8 {



/** /**
* Overview of register layout: * Overview of register layout:
* *
@@ -39,18 +38,18 @@ namespace matmul_mk4_8x8x8 {
* | v16 | | v28 | * | v16 | | v28 |
* | v17 | | v29 | * | v17 | | v29 |
* | v16 | | v30 | * | v16 | | v30 |
* | v17 | | v31 |
* | v17 | | v31 |
* +--------+ - - - - +---------------------------------+ * +--------+ - - - - +---------------------------------+
* *
* Accumulator * Accumulator
*/ */
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_8x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8; K /= 8;
LDC = LDC * sizeof(int16_t); LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packB;//packA;
const int8_t* b_ptr = packA;//packB;
const int8_t* a_ptr = packB; // packA;
const int8_t* b_ptr = packA; // packB;
// clang-format off // clang-format off
#define LOAD_C_8 \ #define LOAD_C_8 \
"ld1 {v0.8h}, [x0], #16\n" \ "ld1 {v0.8h}, [x0], #16\n" \
@@ -291,17 +290,17 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31"); "v29", "v30", "v31");
// clang-format on
// clang-format on
} }


static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_8x8_remain(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8; K /= 8;
LDC = LDC * sizeof(int16_t); LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packB; const int8_t* a_ptr = packB;
const int8_t* b_ptr = packA; const int8_t* b_ptr = packA;
// clang-format off
// clang-format off
register int16_t* outptr asm("x0") = output; register int16_t* outptr asm("x0") = output;
asm volatile( asm volatile(
"add x1, x0, %x[LDC]\n" "add x1, x0, %x[LDC]\n"
@@ -476,7 +475,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 1b\n" "cbnz %w[K], 1b\n"


"cmp %w[is_first_k], #1\n" "cmp %w[is_first_k], #1\n"
"beq 2f\n"
"beq 2f\n"
"cmp %x[m_remain], #8 \n" "cmp %x[m_remain], #8 \n"
"beq 8f \n" "beq 8f \n"
"cmp %x[m_remain], #4 \n" "cmp %x[m_remain], #4 \n"
@@ -633,7 +632,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
"zip2 v15.2d, v30.2d, v31.2d \n" "zip2 v15.2d, v30.2d, v31.2d \n"
"add v6.8h, v6.8h, v13.8h \n" "add v6.8h, v6.8h, v13.8h \n"
"add v7.8h, v7.8h, v15.8h \n" "add v7.8h, v7.8h, v15.8h \n"
//save to memory
// save to memory
"cmp %x[m_remain], #8 \n" "cmp %x[m_remain], #8 \n"
"beq 4f \n" "beq 4f \n"
"cmp %x[m_remain], #4 \n" "cmp %x[m_remain], #4 \n"
@@ -766,31 +765,27 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
"b 1000f \n" "b 1000f \n"


"1000: \n" "1000: \n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain)
: :
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
[ n_remain ] "+r"(n_remain)
:
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
// clang-format on
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0",
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
// clang-format on


#undef LOAD_C_8 #undef LOAD_C_8
#undef STORE_C_8 #undef STORE_C_8
} }



static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8; K /= 8;
LDC = LDC * sizeof(int16_t); LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packB;//packA;
const int8_t* b_ptr = packA;//packB;
const int8_t* a_ptr = packB; // packA;
const int8_t* b_ptr = packA; // packB;
// clang-format off // clang-format off
#define LOAD_C_4 \ #define LOAD_C_4 \
"ld1 {v0.8h}, [x0], #16\n" \ "ld1 {v0.8h}, [x0], #16\n" \
@@ -1018,14 +1013,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
#undef LOAD_C_4 #undef LOAD_C_4
#undef STORE_C_4 #undef STORE_C_4
} }
static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x8_remain(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8; K /= 8;
LDC = LDC * sizeof(int16_t); LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packB;//packA;
const int8_t* b_ptr = packA;//packB;
// clang-format off
const int8_t* a_ptr = packB; // packA;
const int8_t* b_ptr = packA; // packB;
// clang-format off
register int16_t* outptr asm("x0") = output; register int16_t* outptr asm("x0") = output;
asm volatile( asm volatile(


@@ -1324,13 +1319,12 @@ static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C_4 #undef STORE_C_4
} }



//! pack to icxoc //! pack to icxoc
//! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7)) //! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7))
//! if M K is not times of 8,pack 0 instead
static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
const dt_int8* inptr, int ldin,
int m0, int mmax, int k0, int kmax) {
//! if M K is not times of 8,pack 0 instead
static void gemm_s8x8x16_mk4_8x8x8_pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0,
int kmax) {
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
constexpr int pack_m = 8; constexpr int pack_m = 8;
@@ -1349,8 +1343,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
prefetch_2x(inptr0); prefetch_2x(inptr0);
prefetch_2x(inptr1); prefetch_2x(inptr1);
int k_idx = k0; int k_idx = k0;
for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
interleave_8x8_mk4_b(inptr0,inptr1,outptr);
for (; k_idx + 7 < kmax; k_idx += pack_k) {
interleave_8x8_mk4_b(inptr0, inptr1, outptr);
} }


if (k_idx < kmax) { if (k_idx < kmax) {
@@ -1368,9 +1362,9 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
prefetch_2x(inptr0); prefetch_2x(inptr0);
prefetch_2x(inptr1); prefetch_2x(inptr1);
int k_idx = k0; int k_idx = k0;
for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
for (; k_idx + 7 < kmax; k_idx += pack_k) {
inptr1 = zerobuff; inptr1 = zerobuff;
interleave_8x8_mk4_b(inptr0,inptr1,outptr);
interleave_8x8_mk4_b(inptr0, inptr1, outptr);
} }


if (k_idx < kmax) { if (k_idx < kmax) {
@@ -1383,9 +1377,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
} }
//! pack to nxic //! pack to nxic
//! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead. //! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead.
static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
int ldin, int n0, int nmax, int k0,
int kmax) {
static void gemm_s8x8x16_mk4_8x8x8_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");


constexpr int pack_n = 8; constexpr int pack_n = 8;
@@ -1394,14 +1387,14 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
int8_t tmpbuff0[pack_n * pack_size] = {0}; int8_t tmpbuff0[pack_n * pack_size] = {0};
int8_t tmpbuff1[pack_n * pack_size] = {0}; int8_t tmpbuff1[pack_n * pack_size] = {0};
int8_t zerobuff[pack_n * pack_size] = {0}; int8_t zerobuff[pack_n * pack_size] = {0};
const int ksize = round_up<int>((kmax - k0),8);
const int ksize = round_up<int>((kmax - k0), 8);
const int nsize = nmax - n0; const int nsize = nmax - n0;
const int n_end = nsize / pack_n * pack_n + n0; const int n_end = nsize / pack_n * pack_n + n0;
const int remain_n = nsize % pack_n; const int remain_n = nsize % pack_n;
int output_stride = ksize * pack_n; int output_stride = ksize * pack_n;
int8_t* outptr_base = out; int8_t* outptr_base = out;
int k_idx = k0; int k_idx = k0;
for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
for (; k_idx + 7 < kmax; k_idx += pack_k) {
const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
const int8_t* inptr1 = inptr0 + ldin; const int8_t* inptr1 = inptr0 + ldin;
prefetch_3x(inptr0); prefetch_3x(inptr0);
@@ -1410,7 +1403,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
auto outptr = outptr_base; auto outptr = outptr_base;
for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
transpose_8x8_mk4_b(inptr0, inptr1, outptr); transpose_8x8_mk4_b(inptr0, inptr1, outptr);
outptr += output_stride;
outptr += output_stride;
} }
if (remain_n > 0) { if (remain_n > 0) {
memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size); memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size);
@@ -1422,8 +1415,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
} }
outptr_base += pack_n * pack_k; outptr_base += pack_n * pack_k;
} }
if(k_idx < kmax){
if (k_idx < kmax) {
const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
const int8_t* inptr1 = nullptr; const int8_t* inptr1 = nullptr;
prefetch_3x(inptr0); prefetch_3x(inptr0);
@@ -1444,7 +1437,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
} }
} }


} // namespace matmul_mk4_16x12x4_a53
} // namespace matmul_mk4_8x8x8
} // namespace aarch64 } // namespace aarch64
} // namespace megdnn } // namespace megdnn




+ 128
- 142
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp View File

@@ -10,13 +10,13 @@
* implied. * implied.
*/ */


#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h" #include "src/fallback/matrix_mul/gemm_common.h"
@@ -28,39 +28,35 @@ using namespace aarch64::matmul;
// ===========================gemm_s8x8x16_4x4================================== // ===========================gemm_s8x8x16_4x4==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8);


void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8x8x16_8x8::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0,
ymax, k0, kmax);
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(
out, in, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
void gemm_s8x8x16_8x8::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0,
xmax, k0, kmax);
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(
out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
} }
} }


void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8x8x16_8x8::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
@@ -79,15 +75,15 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
matmul_8x8x8::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -99,16 +95,17 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
size_t n = 0; size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
matmul_8x8x8::kern_4x8(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K8; cur_packB += K8;
} }


for (; n < N; n += 4) { for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_8x8x8::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4; output += 4;
cur_packB += K4; cur_packB += K4;
} }
@@ -119,39 +116,33 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
// ===========================gemm_s8x8x16_4x4================================== // ===========================gemm_s8x8x16_4x4==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4);


void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8x8x16_4x4::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, kmax);
} else { } else {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
} }
} }


void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
void gemm_s8x8x16_4x4::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) { if (transpose) {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else { } else {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
} }
} }


void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8x8x16_4x4::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype); MEGDNN_MARK_USED_VAR(C_dtype);
@@ -169,16 +160,17 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE, B_INTERLEAVE);
matmul_4x4x16::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
B_INTERLEAVE);
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K4; cur_packB += K4;
} }


for (; n < N; n += B_INTERLEAVE) { for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
matmul_4x4x16::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K4; cur_packB += K4;
} }
@@ -191,10 +183,10 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0; size_t n = 0;
const dt_int8* cur_packB = packB; const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) { for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
matmul_4x4x16::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE; output += B_INTERLEAVE;
cur_packB += K4; cur_packB += K4;
} }
@@ -205,28 +197,26 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
// ===========================gemm_s8x8x16_mk4_16x12================================== // ===========================gemm_s8x8x16_mk4_16x12==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53);


void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in,
int ldin, int y0, int ymax, int k0,
int kmax, bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0,
ymax, k0, kmax);
void gemm_s8x8x16_mk4_16x12_a53::pack_A(
dt_int16* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(
out, in, ldin, y0, ymax, k0, kmax);
} }


void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0,
xmax, k0, kmax);
void gemm_s8x8x16_mk4_16x12_a53::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(
out, in, ldin, x0, xmax, k0, kmax);
} }


void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
const dt_int8* packB, size_t M, size_t N,
size_t K, dt_int16* C, size_t LDC,
bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
void gemm_s8x8x16_mk4_16x12_a53::kern(
const dt_int16* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k"); megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
@@ -246,14 +236,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
size_t n_idx = 0; size_t n_idx = 0;
const int8_t* cur_packB = packB; const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) { for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
matmul_mk4_16x12x4_a53::kern_16x12(
packA, cur_packB, K, output, LDC, is_first_k, pack_n);
output += pack_n * pack_size; output += pack_n * pack_size;
cur_packB += pack_n * K; cur_packB += pack_n * K;
} }
if (remain_n > 0) { if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
matmul_mk4_16x12x4_a53::kern_16x12(
packA, cur_packB, K, output, LDC, is_first_k, remain_n);
output += remain_n * pack_size; output += remain_n * pack_size;
cur_packB += pack_n * K; cur_packB += pack_n * K;
} }
@@ -265,14 +255,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
size_t n_idx = 0; size_t n_idx = 0;
const int8_t* cur_packB = packB; const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) { for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
matmul_mk4_16x12x4_a53::kern_8x12(
packA, cur_packB, K, output, LDC, is_first_k, pack_n);
output += pack_n * pack_size; output += pack_n * pack_size;
cur_packB += pack_n * K; cur_packB += pack_n * K;
} }
if (remain_n > 0) { if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
matmul_mk4_16x12x4_a53::kern_8x12(
packA, cur_packB, K, output, LDC, is_first_k, remain_n);
output += remain_n * pack_size; output += remain_n * pack_size;
cur_packB += pack_n * K; cur_packB += pack_n * K;
} }
@@ -286,14 +276,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
size_t n_idx = 0; size_t n_idx = 0;
const int8_t* cur_packB = packB; const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) { for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
matmul_mk4_16x12x4_a53::kern_4x12(
packA, cur_packB, K, output, LDC, is_first_k, pack_n);
output += pack_n * pack_size; output += pack_n * pack_size;
cur_packB += pack_n * K; cur_packB += pack_n * K;
} }
if (remain_n > 0) { if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
matmul_mk4_16x12x4_a53::kern_4x12(
packA, cur_packB, K, output, LDC, is_first_k, remain_n);
output += remain_n * pack_size; output += remain_n * pack_size;
cur_packB += pack_n * K; cur_packB += pack_n * K;
} }
@@ -303,27 +293,26 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
// ===========================gemm_s8x8x16_mk4_4x4_a72================================== // ===========================gemm_s8x8x16_mk4_4x4_a72==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72);


void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax,
k0, kmax);
void gemm_s8x8x16_mk4_4x4_a72::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(
out, in, ldin, y0, ymax, k0, kmax);
} }


void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax,
k0, kmax);
void gemm_s8x8x16_mk4_4x4_a72::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(
out, in, ldin, x0, xmax, k0, kmax);
} }


void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k,
const dt_int16*, dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
void gemm_s8x8x16_mk4_4x4_a72::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k"); megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
@@ -343,14 +332,14 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,


const int8_t* cur_packB = packB; const int8_t* cur_packB = packB;
for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) {
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
matmul_mk4_4x4x8_a72::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k, pack_n);
output += pack_n * pack_size; output += pack_n * pack_size;
cur_packB += pack_n * packed_k; cur_packB += pack_n * packed_k;
} }
if (remain_n > 0) { if (remain_n > 0) {
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
matmul_mk4_4x4x8_a72::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k, remain_n);
output += remain_n * pack_size; output += remain_n * pack_size;
cur_packB += pack_n * packed_k; cur_packB += pack_n * packed_k;
} }
@@ -361,27 +350,24 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
// ===========================gemm_s8x8x16_mk4_8x8x8================================== // ===========================gemm_s8x8x16_mk4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8);


void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in,
int ldin, int y0, int ymax, int k0,
int kmax, bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0,
ymax, k0, kmax);
void gemm_s8x8x16_mk4_8x8x8::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, ymax, k0, kmax);
} }


void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0,
xmax, k0, kmax);
void gemm_s8x8x16_mk4_8x8x8::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, xmax, k0, kmax);
} }


void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
void gemm_s8x8x16_mk4_8x8x8::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k"); megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(B_dtype);
@@ -402,14 +388,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n_idx = 0; size_t n_idx = 0;
const int8_t* cur_packB = packB; const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) { for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k, pack_m, pack_n);
matmul_mk4_8x8x8::kern_8x8(
packA, cur_packB, K, output, LDC, is_first_k, pack_m, pack_n);
output += pack_n * pack_size; output += pack_n * pack_size;
cur_packB += KSIZE8; cur_packB += KSIZE8;
} }
if (remain_n > 0) { if (remain_n > 0) {
matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, pack_m, remain_n);
matmul_mk4_8x8x8::kern_8x8_remain(
packA, cur_packB, K, output, LDC, is_first_k, pack_m, remain_n);
output += remain_n * pack_size; output += remain_n * pack_size;
cur_packB += KSIZE8; cur_packB += KSIZE8;
} }
@@ -421,14 +407,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n_idx = 0; size_t n_idx = 0;
const int8_t* cur_packB = packB; const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) { for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC,
is_first_k, 4, pack_n);
matmul_mk4_8x8x8::kern_4x8(
packA, cur_packB, K, output, LDC, is_first_k, 4, pack_n);
output += pack_n * pack_size; output += pack_n * pack_size;
cur_packB += pack_n * K; cur_packB += pack_n * K;
} }
if (remain_n > 0) { if (remain_n > 0) {
matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, 4, remain_n);
matmul_mk4_8x8x8::kern_4x8_remain(
packA, cur_packB, K, output, LDC, is_first_k, 4, remain_n);
output += remain_n * pack_size; output += remain_n * pack_size;
cur_packB += pack_n * K; cur_packB += pack_n * K;
} }


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save