
style(all): reformat c++ code

GitOrigin-RevId: 3ffd1b211f
release-1.7
Megvii Engine Team, 3 years ago
parent commit 369c2ccc5a
100 changed files with 6469 additions and 6810 deletions
  1. +0 -0 dnn/atlas-stub/include/acl/acl.h
  2. +0 -0 dnn/atlas-stub/include/acl/acl_base.h
  3. +0 -0 dnn/atlas-stub/include/acl/acl_mdl.h
  4. +0 -0 dnn/atlas-stub/include/acl/acl_op.h
  5. +0 -0 dnn/atlas-stub/include/acl/acl_rt.h
  6. +0 -0 dnn/atlas-stub/include/acl/ops/acl_cblas.h
  7. +0 -0 dnn/atlas-stub/include/acl/ops/acl_dvpp.h
  8. +0 -0 dnn/atlas-stub/include/acl/ops/acl_fv.h
  9. +2 -2 dnn/include/hip_header.h
  10. +52 -66 dnn/include/megcore.h
  11. +5 -6 dnn/include/megcore_atlas.h
  12. +2 -3 dnn/include/megcore_cambricon.h
  13. +1 -2 dnn/include/megcore_cdefs.h
  14. +3 -4 dnn/include/megcore_cuda.h
  15. +6 -5 dnn/include/megcore_rocm.h
  16. +1 -1 dnn/include/megdnn.h
  17. +73 -74 dnn/include/megdnn/arch.h
  18. +33 -42 dnn/include/megdnn/basic_types.h
  19. +6 -7 dnn/include/megdnn/common.h
  20. +4 -4 dnn/include/megdnn/cuda.h
  21. +353 -463 dnn/include/megdnn/dtype.h
  22. +18 -13 dnn/include/megdnn/dtype/half_common_epilogue.h
  23. +147 -144 dnn/include/megdnn/dtype/half_common_prologue.h
  24. +134 -137 dnn/include/megdnn/handle.h
  25. +3 -2 dnn/include/megdnn/heuristic_cache.h
  26. +5 -6 dnn/include/megdnn/internal/defs.h
  27. +20 -19 dnn/include/megdnn/internal/opr_header_prologue.h
  28. +0 -1 dnn/include/megdnn/internal/visibility_epilogue.h
  29. +16 -20 dnn/include/megdnn/opr_result_defs.h
  30. +2 -4 dnn/include/megdnn/oprs.h
  31. +91 -124 dnn/include/megdnn/oprs/base.h
  32. +123 -99 dnn/include/megdnn/oprs/cv.h
  33. +847 -822 dnn/include/megdnn/oprs/general.h
  34. +164 -180 dnn/include/megdnn/oprs/imgproc.h
  35. +55 -54 dnn/include/megdnn/oprs/linalg.h
  36. +629 -618 dnn/include/megdnn/oprs/nn.h
  37. +5 -8 dnn/include/megdnn/oprs/nn_int.h
  38. +103 -87 dnn/include/megdnn/oprs/utils.h
  39. +36 -47 dnn/include/megdnn/tensor_format.h
  40. +3 -5 dnn/include/megdnn/tensor_iter.h
  41. +5 -5 dnn/include/megdnn/thin/function.h
  42. +42 -76 dnn/include/megdnn/thin/small_vector.h
  43. +6 -6 dnn/include/megdnn/version.h
  44. +26 -24 dnn/src/aarch64/conv_bias/fp16/algos.cpp
  45. +5 -5 dnn/src/aarch64/conv_bias/fp16/algos.h
  46. +48 -67 dnn/src/aarch64/conv_bias/fp16/stride2_kern.h
  47. +30 -31 dnn/src/aarch64/conv_bias/fp32/algos.cpp
  48. +5 -5 dnn/src/aarch64/conv_bias/fp32/algos.h
  49. +33 -38 dnn/src/aarch64/conv_bias/fp32/stride2_kern.h
  50. +36 -37 dnn/src/aarch64/conv_bias/int8/algos.cpp
  51. +6 -8 dnn/src/aarch64/conv_bias/int8/algos.h
  52. +61 -70 dnn/src/aarch64/conv_bias/int8/strategy.cpp
  53. +25 -26 dnn/src/aarch64/conv_bias/int8/strategy.h
  54. +12 -13 dnn/src/aarch64/conv_bias/opr_impl.cpp
  55. +1 -1 dnn/src/aarch64/conv_bias/opr_impl.h
  56. +36 -37 dnn/src/aarch64/conv_bias/quint8/algos.cpp
  57. +6 -8 dnn/src/aarch64/conv_bias/quint8/algos.h
  58. +93 -97 dnn/src/aarch64/conv_bias/quint8/strategy.cpp
  59. +26 -28 dnn/src/aarch64/conv_bias/quint8/strategy.h
  60. +4 -4 dnn/src/aarch64/handle.cpp
  61. +10 -12 dnn/src/aarch64/handle.h
  62. +327 -438 dnn/src/aarch64/matrix_mul/algos.cpp
  63. +28 -76 dnn/src/aarch64/matrix_mul/algos.h
  64. +393 -458 dnn/src/aarch64/matrix_mul/asm/common.h
  65. +534 -555 dnn/src/aarch64/matrix_mul/fp16/strategy.cpp
  66. +4 -4 dnn/src/aarch64/matrix_mul/fp16/strategy.h
  67. +21 -20 dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp
  68. +13 -11 dnn/src/aarch64/matrix_mul/fp32/common.h
  69. +124 -118 dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h
  70. +54 -60 dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
  71. +34 -37 dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h
  72. +34 -37 dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h
  73. +28 -27 dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
  74. +23 -22 dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h
  75. +23 -22 dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h
  76. +83 -86 dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
  77. +5 -8 dnn/src/aarch64/matrix_mul/fp32/strategy.h
  78. +19 -23 dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp
  79. +99 -96 dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h
  80. +35 -36 dnn/src/aarch64/matrix_mul/int16/strategy.cpp
  81. +4 -4 dnn/src/aarch64/matrix_mul/int16/strategy.h
  82. +23 -21 dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp
  83. +121 -103 dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h
  84. +33 -34 dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp
  85. +2 -2 dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h
  86. +57 -39 dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h
  87. +163 -113 dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
  88. +47 -41 dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
  89. +77 -82 dnn/src/aarch64/matrix_mul/int8/strategy.cpp
  90. +6 -6 dnn/src/aarch64/matrix_mul/int8/strategy.h
  91. +62 -59 dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
  92. +30 -32 dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h
  93. +59 -59 dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
  94. +5 -5 dnn/src/aarch64/matrix_mul/int8_dot/strategy.h
  95. +54 -38 dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h
  96. +159 -111 dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h
  97. +34 -41 dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h
  98. +18 -22 dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h
  99. +48 -55 dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h
  100. +128 -142 dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp

+0 -0 dnn/atlas-stub/include/acl/acl.h


+0 -0 dnn/atlas-stub/include/acl/acl_base.h


+0 -0 dnn/atlas-stub/include/acl/acl_mdl.h


+0 -0 dnn/atlas-stub/include/acl/acl_op.h


+0 -0 dnn/atlas-stub/include/acl/acl_rt.h


+0 -0 dnn/atlas-stub/include/acl/ops/acl_cblas.h


+0 -0 dnn/atlas-stub/include/acl/ops/acl_dvpp.h


+0 -0 dnn/atlas-stub/include/acl/ops/acl_fv.h


+2 -2 dnn/include/hip_header.h

@@ -23,9 +23,9 @@
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#pragma GCC diagnostic ignored "-Wsign-compare"
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#pragma GCC diagnostic pop

#if !defined(__HIP_PLATFORM_HCC__)


+52 -66 dnn/include/megcore.h

@@ -11,10 +11,10 @@

#pragma once

#include "megdnn/thin/function.h"
#include "megcore_cdefs.h"
#include <cstddef>
#include <memory>
#include "megcore_cdefs.h"
#include "megdnn/thin/function.h"

#include "megdnn/internal/visibility_prologue.h"

@@ -26,36 +26,35 @@ namespace megcore {
* the caller thread immediately.
*/
class CPUDispatcher {
public:
using Task = megdnn::thin_function<void()>;
using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>;
virtual ~CPUDispatcher() noexcept;

/*!
* \brief dispatch a task on the computing thread
* \param task the task that would be moved away
*/
virtual void dispatch(Task&& task) = 0;

/*!
* \brief dispatch a multithreading task on the computing thread
* \param task the task would be moved away
* \param parallelism the parallelism of the task.
*/
virtual void dispatch(MultiThreadingTask&& task,
size_t parallelism) = 0;

/*!
* \brief synchronize the calling thread with the computing thread
*/
virtual void sync() = 0;

/*!
* \brief the computing thread number.
*/
virtual size_t nr_threads() = 0;
public:
using Task = megdnn::thin_function<void()>;
using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>;
virtual ~CPUDispatcher() noexcept;

/*!
* \brief dispatch a task on the computing thread
* \param task the task that would be moved away
*/
virtual void dispatch(Task&& task) = 0;

/*!
* \brief dispatch a multithreading task on the computing thread
* \param task the task would be moved away
* \param parallelism the parallelism of the task.
*/
virtual void dispatch(MultiThreadingTask&& task, size_t parallelism) = 0;

/*!
* \brief synchronize the calling thread with the computing thread
*/
virtual void sync() = 0;

/*!
* \brief the computing thread number.
*/
virtual size_t nr_threads() = 0;
};
} // namespace megcore
} // namespace megcore

using MegcoreCPUDispatcher = megcore::CPUDispatcher;

@@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher;
* \brief Layer 1: device handle
*/
struct megcoreDeviceContext;
typedef struct megcoreDeviceContext *megcoreDeviceHandle_t;
typedef struct megcoreDeviceContext* megcoreDeviceHandle_t;

megcoreStatus_t megcoreCreateDeviceHandle(
megcoreDeviceHandle_t *handle,
megcorePlatform_t platform,
int deviceID = -1,
megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID = -1,
unsigned int flags = 0);
megcoreStatus_t megcoreDestroyDeviceHandle(
megcoreDeviceHandle_t handle);

megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle,
megcorePlatform_t *platform);
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle,
int *deviceID);
megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle,
size_t *memAlignmentInBytes);
megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle);

megcoreStatus_t megcoreGetPlatform(
megcoreDeviceHandle_t handle, megcorePlatform_t* platform);
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID);
megcoreStatus_t megcoreGetMemAlignment(
megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes);
megcoreStatus_t megcoreGetDeviceFlags(
megcoreDeviceHandle_t handle,
unsigned int *flags);
megcoreDeviceHandle_t handle, unsigned int* flags);

megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle,
void **devPtr, size_t sizeInBytes);
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle,
void *devPtr);
megcoreStatus_t megcoreMalloc(
megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes);
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr);

/**
* \brief Layer 2: computing handle
*/
struct megcoreComputingContext;
typedef struct megcoreComputingContext *megcoreComputingHandle_t;
typedef struct megcoreComputingContext* megcoreComputingHandle_t;

megcoreStatus_t megcoreCreateComputingHandle(
megcoreComputingHandle_t *compHandle,
megcoreDeviceHandle_t devHandle,
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags = 0);

megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher(
megcoreComputingHandle_t *compHandle,
megcoreDeviceHandle_t devHandle,
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher,
unsigned int flags = 0);

megcoreStatus_t megcoreDestroyComputingHandle(
megcoreComputingHandle_t handle);
megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle);

megcoreStatus_t megcoreGetDeviceHandle(
megcoreComputingHandle_t compHandle,
megcoreDeviceHandle_t *devHandle);
megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle);
megcoreStatus_t megcoreGetComputingFlags(
megcoreComputingHandle_t handle,
unsigned int *flags);
megcoreComputingHandle_t handle, unsigned int* flags);

MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle);

megcoreStatus_t megcoreMemcpy(
megcoreComputingHandle_t handle,
void *dst, const void *src, size_t sizeInBytes,
megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes,
megcoreMemcpyKind_t kind);
megcoreStatus_t megcoreMemset(
megcoreComputingHandle_t handle,
void *dst, int value, size_t sizeInBytes);
megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes);
megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle);

/**
* \brief Miscellaneous
*/
const char *megcoreGetErrorName(megcoreStatus_t status);
const char* megcoreGetErrorName(megcoreStatus_t status);

#include "megdnn/internal/visibility_epilogue.h"



+5 -6 dnn/include/megcore_atlas.h

@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const AtlasContext& ctx);

megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle,
AtlasContext* ctx);
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, AtlasContext* ctx);

namespace atlas {
//! convert acl error code to error string
@@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, aclrtStream stream) {
megcore::AtlasContext ctx{stream};
return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle,
flags, ctx);
return megcore::createComputingHandleWithAtlasContext(
compHandle, devHandle, flags, ctx);
}

inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle,
aclrtStream* stream) {
inline megcoreStatus_t megcoreGetACLStream(
megcoreComputingHandle_t handle, aclrtStream* stream) {
megcore::AtlasContext ctx;
auto ret = megcore::getAtlasContext(handle, &ctx);
*stream = ctx.stream;


+2 -3 dnn/include/megcore_cambricon.h

@@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CambriconContext& ctx);

megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle,
CambriconContext* ctx);
megcoreStatus_t getCambriconContext(
megcoreComputingHandle_t handle, CambriconContext* ctx);

} // namespace megcore

@@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue(
#include "megdnn/internal/visibility_epilogue.h"

// vim: syntax=cpp.doxygen


+1 -2 dnn/include/megcore_cdefs.h

@@ -40,7 +40,6 @@ typedef enum {
megcoreErrorInternalError = 5,
} megcoreStatus_t;


/**
* \brief Memcpy kind
*/
@@ -70,6 +69,6 @@ struct AsyncErrorInfo {
char msg[228];
int msg_args[4];
};
} // namespace megcore
} // namespace megcore

// vim: syntax=cpp.doxygen

+3 -4 dnn/include/megcore_cuda.h

@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const CudaContext& ctx);

megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle,
CudaContext* ctx);
megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx);

} // namespace megcore

@@ -43,8 +42,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream(
unsigned int flags, cudaStream_t stream) {
megcore::CudaContext ctx;
ctx.stream = stream;
return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle,
flags, ctx);
return megcore::createComputingHandleWithCUDAContext(
compHandle, devHandle, flags, ctx);
}

static inline megcoreStatus_t megcoreGetCUDAStream(
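
Combined with the device and computing handle functions from megcore.h above, wrapping an existing CUDA stream would look roughly like the following sketch. megcorePlatformCUDA is an assumed enumerator of megcorePlatform_t (the enum's members are not part of this diff), and error checking is omitted.

#include <cuda_runtime.h>
#include "megcore.h"
#include "megcore_cuda.h"

void run_on_existing_stream() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    megcoreDeviceHandle_t dev;
    megcoreCreateDeviceHandle(&dev, megcorePlatformCUDA, /*deviceID=*/0);
    megcoreComputingHandle_t comp;
    megcoreCreateComputingHandleWithCUDAStream(&comp, dev, /*flags=*/0, stream);
    // ... enqueue megdnn work on `comp` ...
    megcoreSynchronize(comp);
    megcoreDestroyComputingHandle(comp);
    megcoreDestroyDeviceHandle(dev);
    cudaStreamDestroy(stream);
}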


+6 -5 dnn/include/megcore_rocm.h

@@ -23,7 +23,9 @@ struct ROCMContext {
hipStream_t stream = nullptr;

static std::atomic_bool sm_miopen_algo_search;
static inline bool enable_miopen_algo_search() { return sm_miopen_algo_search.load(); }
static inline bool enable_miopen_algo_search() {
return sm_miopen_algo_search.load();
}
static inline void enable_miopen_algo_search(bool enable_algo_search) {
sm_miopen_algo_search.store(enable_algo_search);
}
@@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext(
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
unsigned int flags, const ROCMContext& ctx);

megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle,
ROCMContext* ctx);
megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, ROCMContext* ctx);

// Set MIOpen algo search enabled or disabled
megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true);
@@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream(
unsigned int flags, hipStream_t stream) {
megcore::ROCMContext ctx;
ctx.stream = stream;
return megcore::createComputingHandleWithROCMContext(compHandle, devHandle,
flags, ctx);
return megcore::createComputingHandleWithROCMContext(
compHandle, devHandle, flags, ctx);
}

static inline megcoreStatus_t megcoreGetROCMStream(


+1 -1 dnn/include/megdnn.h

@@ -10,7 +10,7 @@
*/
#pragma once

#include "megdnn/version.h"
#include "megdnn/oprs.h"
#include "megdnn/version.h"

// vim: syntax=cpp.doxygen

+73 -74 dnn/include/megdnn/arch.h

@@ -14,20 +14,20 @@
#include "megdnn/config/config.h"

#if defined(__GNUC__) || defined(__clang__)
#if !defined (__clang__)
// gcc specific
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#if GCC_VERSION < 40800
#error "GCC version should be at least 4.8.0."
#endif // GCC_VERSION < 40800
#endif // !defined(__clang__)
#if !defined(__clang__)
// gcc specific
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#if GCC_VERSION < 40800
#error "GCC version should be at least 4.8.0."
#endif // GCC_VERSION < 40800
#endif // !defined(__clang__)

#ifndef megdnn_trap
#define megdnn_trap() __builtin_trap()
#endif
#ifndef megdnn_trap
#define megdnn_trap() __builtin_trap()
#endif

#define megdnn_likely(v) __builtin_expect(bool(v), 1)
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0)
#define megdnn_likely(v) __builtin_expect(bool(v), 1)
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0)

#if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG)
//! Thumb2 limit code length
@@ -36,123 +36,122 @@
#define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__))
#endif

#define MEGDNN_DEPRECATED __attribute__((deprecated))
#define MEGDNN_PACKED __attribute__((packed))
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_NORETURN __attribute__((noreturn))
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#if defined(__clang_major__) && (__clang_major__ >= 7)
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#else
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]]
#endif
#define MEGDNN_NOINLINE __attribute__((noinline))
#define megdnn_isatty(x) isatty(x)
#define MEGDNN_DEPRECATED __attribute__((deprecated))
#define MEGDNN_PACKED __attribute__((packed))
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_NORETURN __attribute__((noreturn))
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#if defined(__clang_major__) && (__clang_major__ >= 7)
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#else
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]]
#endif
#define MEGDNN_NOINLINE __attribute__((noinline))
#define megdnn_isatty(x) isatty(x)
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)

#ifndef megdnn_trap
#define megdnn_trap() __debugbreak()
#endif

#define megdnn_likely(v) (bool(v))
#define megdnn_likely(v) (bool(v))
#define megdnn_unlikely(v) (bool(v))

#define MEGDNN_DEPRECATED
#define MEGDNN_PACKED
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_FINAL final

#if defined(_MSC_VER)
#define MEGDNN_NORETURN __declspec(noreturn)
#define MEGDNN_NOINLINE __declspec(noinline)
#define MEGDNN_NORETURN __declspec(noreturn)
#define MEGDNN_NOINLINE __declspec(noinline)
#else
#define MEGDNN_NORETURN
#define MEGDNN_FORCE_NOINLINE
#endif // _MSC_VER
#define MEGDNN_NORETURN
#define MEGDNN_FORCE_NOINLINE
#endif // _MSC_VER

#define MEGDNN_WARN_UNUSED_RESULT

#define megdnn_isatty(x) _isatty(x)

#else
#error "unknown compiler"
#endif // __GNUC__
#error "unknown compiler"
#endif // __GNUC__

// __cpp_exceptions and __cpp_rtti are referred from
// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// similar for __GXX_RTTI
// _CPPUNWIND and _CPPRTTI are used by MSVC, see
// https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=vs-2019
#ifndef MEGDNN_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS || \
(defined(_MSC_VER) && defined(_CPPUNWIND))
#define MEGDNN_ENABLE_EXCEPTIONS 1
#else
#define MEGDNN_ENABLE_EXCEPTIONS 0
#endif
#if __cpp_exceptions || __EXCEPTIONS || (defined(_MSC_VER) && defined(_CPPUNWIND))
#define MEGDNN_ENABLE_EXCEPTIONS 1
#else
#define MEGDNN_ENABLE_EXCEPTIONS 0
#endif
#endif
#ifndef MEGDNN_ENABLE_RTTI
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI))
#define MEGDNN_ENABLE_RTTI 1
#else
#define MEGDNN_ENABLE_RTTI 0
#endif
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI))
#define MEGDNN_ENABLE_RTTI 1
#else
#define MEGDNN_ENABLE_RTTI 0
#endif
#endif

#ifdef __CUDACC__
#define MEGDNN_CC_CUDA 1
#undef MEGDNN_CONSTEXPR
#define MEGDNN_CONSTEXPR const
#define MEGDNN_CC_CUDA 1
#undef MEGDNN_CONSTEXPR
#define MEGDNN_CONSTEXPR const

#if defined(__CUDACC_VER_MAJOR__)
#if __CUDACC_VER_MAJOR__ >= 9
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg);
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg);
#else
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg)
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg)
#endif
#endif

#define nullptr NULL
#undef MEGDNN_FINAL
#define MEGDNN_FINAL
#define nullptr NULL
#undef MEGDNN_FINAL
#define MEGDNN_FINAL
#elif defined(__HIPCC__)
#define MEGDNN_CC_CUDA 1
#define MEGDNN_CC_CUDA 1
#else
#define MEGDNN_CC_HOST 1
#endif // __CUDACC__
#define MEGDNN_CC_HOST 1
#endif // __CUDACC__

// MEGDNN_HOST and MEGDNN_DEVICE
#if MEGDNN_CC_CUDA
#define MEGDNN_HOST __host__
#define MEGDNN_DEVICE __device__
#define MEGDNN_HOST __host__
#define MEGDNN_DEVICE __device__
#else
#define MEGDNN_HOST
#define MEGDNN_DEVICE
#define MEGDNN_HOST
#define MEGDNN_DEVICE
#endif

#if MEGDNN_CC_CUDA
#define MEGDNN_FORCE_INLINE __forceinline__
#define MEGDNN_FORCE_INLINE __forceinline__
#else
#if __GNUC__ || __has_attribute(always_inline)
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#else
#define MEGDNN_FORCE_INLINE inline
#define MEGDNN_FORCE_INLINE inline
#endif
#endif

#if defined(_MSC_VER) || defined(WIN32)
#define ATTR_ALIGNED(v) __declspec(align(v))
#define ATTR_ALIGNED(v) __declspec(align(v))
#else
#define ATTR_ALIGNED(v) __attribute__((aligned(v)))
#define ATTR_ALIGNED(v) __attribute__((aligned(v)))
#endif
// vim: syntax=cpp.doxygen
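
Most of the churn in arch.h is re-indentation of nested preprocessor branches; the macros themselves are unchanged. For reference, an illustrative consumer of the portability macros (not from this commit):

#include <cstddef>
#include "megdnn/arch.h"

// Pin a SIMD target and give the compiler branch hints via the macros above.
MEGDNN_ATTRIBUTE_TARGET("sse4.2")
void scale(float* dst, const float* src, size_t n, float k) {
    if (megdnn_unlikely(n == 0))
        return;
    for (size_t i = 0; i < n; ++i)
        dst[i] = src[i] * k;
}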

+33 -42 dnn/include/megdnn/basic_types.h

@@ -16,10 +16,10 @@
#include "megdnn/internal/defs.h"

#if MEGDNN_CC_HOST
#include <cstdarg>
#include <string>
#include <type_traits>
#include <vector>
#include <cstdarg>
#include "megdnn/thin/small_vector.h"
#endif // MEGDNN_CC_HOST

@@ -35,8 +35,7 @@ class ErrorHandler {
protected:
MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0;

MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(
const std::string& msg) {
MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(const std::string& msg) {
on_megdnn_error(msg);
}

@@ -70,8 +69,9 @@ public:
#if MEGDNN_CC_HOST
enum class LogLevel { DEBUG, INFO, WARN, ERROR };

typedef void (*LogHandler)(LogLevel level, const char* file, const char* func,
int line, const char* fmt, va_list ap);
typedef void (*LogHandler)(
LogLevel level, const char* file, const char* func, int line, const char* fmt,
va_list ap);

/*!
* \brief set the callback to receive all log messages
@@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape {
ptrdiff_t low_elem, low_byte;
size_t high_elem, high_byte;

Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem,
size_t high_byte)
Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte)
: low_elem(low_elem),
low_byte(low_byte),
high_elem(high_elem),
@@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape {
TensorLayout(const TensorShape& shape, DType dtype, Format format);

//! creating layout with user-specified shape and stride.
TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
DType dtype);
TensorLayout(
const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
DType dtype);

TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
DType dtype, Format format);
TensorLayout(
const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype,
Format format);

/* =================== inplace modifiers =================== */

@@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape {
*
* \throw TensorReshapeError if no stride exists for target shape.
*/
TensorLayout reshape(const TensorShape& shape) const
MEGDNN_WARN_UNUSED_RESULT;
TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;

/*!
* \brief try to reshape to another view; return whether these two shapes
@@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape {
* \return true iff there exists target stride so this layout can be
* converted to target shape and the elements can match.
*/
bool try_reshape(TensorLayout& output,
const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
bool try_reshape(TensorLayout& output, const TensorShape& shape) const
MEGDNN_WARN_UNUSED_RESULT;

/*!
* \brief Broadcast on dims with shape == 1 to match target *shape*.
* \throw TensorReshapeError if could not be satisfied
*/
TensorLayout broadcast(const TensorShape& shape) const
MEGDNN_WARN_UNUSED_RESULT;
TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;

/*!
* \brief Collapse consecutive axes with contiguous layout together
@@ -441,8 +440,7 @@ struct Workspace {

Workspace() : raw_ptr(NULL), size(0) {}

Workspace(dt_byte* raw_ptr_, size_t size_)
: raw_ptr(raw_ptr_), size(size_) {}
Workspace(dt_byte* raw_ptr_, size_t size_) : raw_ptr(raw_ptr_), size(size_) {}

template <typename T>
T* ptr(size_t offset_in_bytes = 0) const {
@@ -467,9 +465,8 @@ public:
* \param shape requested output shape
* \param user_data extra user data passed in DynOutMallocPolicyCall
*/
virtual TensorND alloc_output(size_t id, DType dtype,
const TensorShape& shape,
void* user_data) = 0;
virtual TensorND alloc_output(
size_t id, DType dtype, const TensorShape& shape, void* user_data) = 0;

/*!
* \brief allocate workspace memory
@@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall {
*/
template <typename T = void, typename elem = T>
T* alloc_workspace(size_t nr_elem) {
using real_elem =
typename std::conditional<std::is_same<elem, void>::value,
uint8_t, elem>::type;
return static_cast<T*>(policy->alloc_workspace(
nr_elem * sizeof(real_elem), user_data));
using real_elem = typename std::conditional<
std::is_same<elem, void>::value, uint8_t, elem>::type;
return static_cast<T*>(
policy->alloc_workspace(nr_elem * sizeof(real_elem), user_data));
}

void free_workspace(void* ptr) {
return policy->free_workspace(ptr, user_data);
}
void free_workspace(void* ptr) { return policy->free_workspace(ptr, user_data); }
};


template <typename T>
class EnumClassBit {
std::underlying_type_t<T> m_val;
@@ -528,8 +521,7 @@ class EnumClassBit {
constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {}

public:
constexpr EnumClassBit(T v)
: m_val(static_cast<std::underlying_type_t<T>>(v)) {}
constexpr EnumClassBit(T v) : m_val(static_cast<std::underlying_type_t<T>>(v)) {}

constexpr operator T() const { return static_cast<T>(m_val); }

@@ -542,7 +534,7 @@ public:

DEF_OPR(&)
DEF_OPR(|)
DEF_OPR (^)
DEF_OPR(^)

constexpr EnumClassBit operator~() const { return ~m_val; }

@@ -553,14 +545,13 @@ public:

} // namespace megdnn

#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) \
op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
return ::megdnn::EnumClassBit<cls>(x) op ::megdnn::EnumClassBit<cls>(y); \
} \
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
::megdnn::EnumClassBit<cls> x, cls y) { \
return x op ::megdnn::EnumClassBit<cls>(y); \
}

#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
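
The _MEGDNN_DECBO_SINGLE_OPR macro reformatted above generates bitwise operators that round-trip through EnumClassBit, letting enum class flags combine without manual casts. A sketch with a hypothetical flag enum (Access is not a real megdnn type):

#include <cstdint>
#include "megdnn/basic_types.h"

enum class Access : uint32_t { READ = 1, WRITE = 2 };
_MEGDNN_DECBO_SINGLE_OPR(Access, &)

bool writable(Access a) {
    // a & Access::WRITE yields an EnumClassBit<Access>, which converts back
    // to Access through its operator T()
    return static_cast<uint32_t>(static_cast<Access>(a & Access::WRITE)) != 0;
}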


+6 -7 dnn/include/megdnn/common.h

@@ -14,14 +14,14 @@
#include "megbrain_build_config.h"

#if MGB_ENABLE_GETENV
#define MGB_GETENV ::std::getenv
#define MGB_GETENV ::std::getenv
#else
#define MGB_GETENV(_name) static_cast<char*>(nullptr)
#define MGB_GETENV(_name) static_cast<char*>(nullptr)
#endif

#ifdef WIN32
#define unsetenv(_name) _putenv_s(_name, "");
#define setenv(name,value,overwrite) _putenv_s(name,value)
#define unsetenv(_name) _putenv_s(_name, "");
#define setenv(name, value, overwrite) _putenv_s(name, value)
#endif

namespace megdnn {
@@ -32,8 +32,7 @@ namespace megdnn {
*/
template <class Opr, typename... Args>
bool has_available_algo(Opr* opr, Args&&... args) {
const typename Opr::AlgoBase::SizeArgs size_args(
opr, std::forward<Args>(args)...);
const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
for (auto i : Opr::algo_pack().all_algos) {
if (i->is_available(size_args)) {
return true;
@@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... args) {
return false;
}

}
} // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
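
has_available_algo() instantiates the operator's SizeArgs from the forwarded arguments and scans the registered algorithms. A hedged call-site sketch; the exact SizeArgs constructor signature is operator-specific and not shown in this diff:

#include "megdnn/common.h"

// Forward whatever layouts the operator's AlgoBase::SizeArgs accepts.
template <typename Opr, typename... Layouts>
bool can_run(Opr* opr, const Layouts&... layouts) {
    return megdnn::has_available_algo(opr, layouts...);
}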

+4 -4 dnn/include/megdnn/cuda.h

@@ -17,11 +17,11 @@
#include "megdnn/internal/visibility_prologue.h"
namespace megdnn {

std::unique_ptr<Handle> make_cuda_handle_with_stream(cudaStream_t stream,
int device_id = -1);
cudaStream_t get_cuda_stream(Handle *handle);
std::unique_ptr<Handle> make_cuda_handle_with_stream(
cudaStream_t stream, int device_id = -1);
cudaStream_t get_cuda_stream(Handle* handle);

} // namespace megdnn
} // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h"

// vim: syntax=cpp.doxygen

+353 -463 dnn/include/megdnn/dtype.h (file diff suppressed because it is too large)


+18 -13 dnn/include/megdnn/dtype/half_common_epilogue.h

@@ -3,17 +3,22 @@
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* Permission is hereby granted, free of charge, to any person obtaining a copy of this
* software and associated documentation files (the "Software"), to deal in the Software
* without restriction, including without limitation the rights to use, copy, modify,
* merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be included in all copies
* or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
* CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
* OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
@@ -41,8 +46,8 @@
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif

// vim: syntax=cpp.doxygen

+147 -144 dnn/include/megdnn/dtype/half_common_prologue.h

@@ -3,17 +3,22 @@
*
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
* Permission is hereby granted, free of charge, to any person obtaining a copy of this
* software and associated documentation files (the "Software"), to deal in the Software
* without restriction, including without limitation the rights to use, copy, modify,
* merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
* The above copyright notice and this permission notice shall be included in all copies
* or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
* CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
* OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* Version 1.11.0
* \file
@@ -39,166 +44,164 @@
#include "megdnn/arch.h"

/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
#define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)

//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522: multiple copy/assignment operators specified
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
// check C++11 language features
#if defined(__clang__) // clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \
!defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER)
//Intel C++ #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
???????? #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >=
1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define
HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 &&
!defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define
HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 &&
!defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define
HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/
#elif defined(__GNUC__) // gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) // Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522: multiple copy/assignment operators specified
#pragma warning(disable : 4099 4127 4146 4521 4522) // struct vs class, constant in if,
// negative unsigned
#endif

//check C++11 library features
// check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#if defined(_LIBCPP_VERSION) // libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) // libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION

//support constexpr
// support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif

//support noexcept
// support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif

#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#include <ostream>
#include <istream>
#include <limits>
#include <ostream>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#include <functional>
#endif

// vim: syntax=cpp.doxygen

+134 -137 dnn/include/megdnn/handle.h

@@ -12,8 +12,8 @@
#pragma once

#include "megcore.h"
#include "megdnn/config/config.h"
#include "megdnn/basic_types.h"
#include "megdnn/config/config.h"

#include <functional>
#include <memory>
@@ -24,150 +24,147 @@ namespace megdnn {
class OperatorBase;

class Handle {
public:
enum class HandleType {
NAIVE = 0,
FALLBACK = 1,
X86 = 2,
ARM_COMMON = 3,
ARMV7 = 4,
AARCH64 = 5,
CUDA = 6,
ROCM = 11,
ATLAS = 13,
CAMBRICON = 12,
};

//! Device vendor
enum class HandleVendorType : uint32_t {
NOT_SPEC = 0,
MALI = 1,
ADRENO = 2,
CUDA = 3,
INTEL = 4,
POWERVR = 5,
AMD = 6,
};

protected:
Handle(megcoreComputingHandle_t computing_handle, HandleType type);

public:
/**
* \brief Create a MegDNN handle from a MegCore Computing handle.
*
* \param[in] computing_handle MegCore computing handle. Please note
* that computing_handle would not be released when this Handle is
* destructed
* \param[in] debug_level
* Applicable for CPU computing handle.
* 0 means taking the fastest possible code path; it may contain
* platform-specific instructions such as SSE for x86_64 or NEON for
* armv7.
* 1 means taking the fastest possible code path without
* platform-specific instructions in C++ code. Note that the compiled
* binary file still contains platform-specific codes.
* 2 means taking the naive code path. Performance is severely
* hampered, but it is less error-prone since the internal
* implementation is rather straightforward.
*
* **Debug level 1 and 2 should not be used in productions.**
*/
static std::unique_ptr<Handle> make(
megcoreComputingHandle_t computing_handle,
int debug_level = 0);
public:
enum class HandleType {
NAIVE = 0,
FALLBACK = 1,
X86 = 2,
ARM_COMMON = 3,
ARMV7 = 4,
AARCH64 = 5,
CUDA = 6,
ROCM = 11,
ATLAS = 13,
CAMBRICON = 12,
};

//! Device vendor
enum class HandleVendorType : uint32_t {
NOT_SPEC = 0,
MALI = 1,
ADRENO = 2,
CUDA = 3,
INTEL = 4,
POWERVR = 5,
AMD = 6,
};

protected:
Handle(megcoreComputingHandle_t computing_handle, HandleType type);

public:
/**
* \brief Create a MegDNN handle from a MegCore Computing handle.
*
* \param[in] computing_handle MegCore computing handle. Please note
* that computing_handle would not be released when this Handle is
* destructed
* \param[in] debug_level
* Applicable for CPU computing handle.
* 0 means taking the fastest possible code path; it may contain
* platform-specific instructions such as SSE for x86_64 or NEON for
* armv7.
* 1 means taking the fastest possible code path without
* platform-specific instructions in C++ code. Note that the compiled
* binary file still contains platform-specific codes.
* 2 means taking the naive code path. Performance is severely
* hampered, but it is less error-prone since the internal
* implementation is rather straightforward.
*
* **Debug level 1 and 2 should not be used in productions.**
*/
static std::unique_ptr<Handle> make(
megcoreComputingHandle_t computing_handle, int debug_level = 0);

#if MEGDNN_WITH_CUDA
static std::unique_ptr<Handle> make_cuda_handle(
megcoreComputingHandle_t computing_handle);
template <typename opr>
std::unique_ptr<opr> create_cuda_operator();
static std::unique_ptr<Handle> make_cuda_handle(
megcoreComputingHandle_t computing_handle);
template <typename opr>
std::unique_ptr<opr> create_cuda_operator();
#endif
#if MEGDNN_WITH_ROCM
static std::unique_ptr<Handle> make_rocm_handle(
megcoreComputingHandle_t computing_handle);
template <typename opr>
std::unique_ptr<opr> create_rocm_operator();
static std::unique_ptr<Handle> make_rocm_handle(
megcoreComputingHandle_t computing_handle);
template <typename opr>
std::unique_ptr<opr> create_rocm_operator();
#endif

virtual ~Handle();

/*!
* \brief Get the underlying megcore computing handle.
*/
megcoreComputingHandle_t megcore_computing_handle() const {
return m_computing_handle;
}

/*!
* \brief set a callback function to be invoked when this handle is
* destructed, so associated resources can be released (e.g.
* computing handle)
*
* This function can be called at most once.
*/
void set_destructor(const thin_function<void()> &d);

/*!
* \brief set a callback to be invoked when an operator is destructed
* \param[in,out] cb the callback function; it would be set to the
* previous callback function
*/
void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) {
cb.swap(m_on_opr_destructed);
}

void on_opr_destructed(OperatorBase* opr);

/**
* \brief Create operator of Opr type.
*/
template <typename Opr>
std::unique_ptr<Opr> create_operator();

/*
* =============================================================
* Users should call functions below to query memory requirement.
* =============================================================
*/

/**
* \brief The internal data pointer of TensorND should be aligned to
* alignment_requirement() in bytes.
*/
virtual size_t alignment_requirement() const;

//! get alignment in bytes for rows of image 2D tensor format
virtual size_t image2d_pitch_alignment() const;

//! get vendor type
virtual HandleVendorType vendor_type() const;

HandleType type() const {
return m_handle_type;
}

/**
* \brief Check whether the layout satisfies the cross-device copy constraint.
* 1. The handle of the src and the dst is of the same kind.
* 2. The dst is contiguous.
*/
virtual bool check_cross_dev_copy_constraint(const TensorLayout &src);

private:
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;
volatile uint32_t m_alive_magic = ALIVE_MAGIC;
megcoreComputingHandle_t m_computing_handle;
const HandleType m_handle_type;
thin_function<void()> m_destructor;
thin_function<void(OperatorBase*)> m_on_opr_destructed;

Handle() = delete;
Handle(const Handle &rhs) = delete;
Handle &operator=(const Handle &rhs) = delete;
virtual ~Handle();

/*!
* \brief Get the underlying megcore computing handle.
*/
megcoreComputingHandle_t megcore_computing_handle() const {
return m_computing_handle;
}

/*!
* \brief set a callback function to be invoked when this handle is
* destructed, so associated resources can be released (e.g.
* computing handle)
*
* This function can be called at most once.
*/
void set_destructor(const thin_function<void()>& d);

/*!
* \brief set a callback to be invoked when an operator is destructed
* \param[in,out] cb the callback function; it would be set to the
* previous callback function
*/
void set_opr_destruct_callback(thin_function<void(OperatorBase*)>& cb) {
cb.swap(m_on_opr_destructed);
}

void on_opr_destructed(OperatorBase* opr);

/**
* \brief Create operator of Opr type.
*/
template <typename Opr>
std::unique_ptr<Opr> create_operator();

/*
* =============================================================
* Users should call functions below to query memory requirement.
* =============================================================
*/

/**
* \brief The internal data pointer of TensorND should be aligned to
* alignment_requirement() in bytes.
*/
virtual size_t alignment_requirement() const;

//! get alignment in bytes for rows of image 2D tensor format
virtual size_t image2d_pitch_alignment() const;

//! get vendor type
virtual HandleVendorType vendor_type() const;

HandleType type() const { return m_handle_type; }

/**
* \brief Check whether the layout satisfies the cross-device copy constraint.
* 1. The handle of the src and the dst is of the same kind.
* 2. The dst is contiguous.
*/
virtual bool check_cross_dev_copy_constraint(const TensorLayout& src);

private:
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;
volatile uint32_t m_alive_magic = ALIVE_MAGIC;
megcoreComputingHandle_t m_computing_handle;
const HandleType m_handle_type;
thin_function<void()> m_destructor;
thin_function<void(OperatorBase*)> m_on_opr_destructed;

Handle() = delete;
Handle(const Handle& rhs) = delete;
Handle& operator=(const Handle& rhs) = delete;
};

} // namespace megdnn
} // namespace megdnn

#include "megdnn/internal/visibility_epilogue.h"



+3 -2 dnn/include/megdnn/heuristic_cache.h

@@ -49,8 +49,9 @@ public:
mutable std::string m_input;

public:
Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr,
size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0)
Key(Handle* opr_handle, Algorithm::OprType opr_type,
const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size,
const void* param_ptr = nullptr, size_t param_size = 0)
: m_handle{opr_handle},
m_opr_type{static_cast<uint32_t>(opr_type)},
m_inp_layouts_ptr{inp_layouts_ptr},


+5 -6 dnn/include/megdnn/internal/defs.h

@@ -16,20 +16,19 @@
* \brief iterate through small (usually used) ndim values
*/
#define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \
cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__)
cb(1, ##__VA_ARGS__) cb(2, ##__VA_ARGS__) cb(3, ##__VA_ARGS__)

/*!
* \brief iterate through large (rarely used) ndim values
*/
#define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \
cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \
cb(7, ##__VA_ARGS__)
cb(4, ##__VA_ARGS__) cb(5, ##__VA_ARGS__) cb(6, ##__VA_ARGS__) cb(7, ##__VA_ARGS__)

/*!
* \brief iterate through all ndim values
*/
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \
MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \
MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__)
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \
MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ##__VA_ARGS__) \
MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ##__VA_ARGS__)

// vim: syntax=cpp.doxygen
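
These X-macros expand a callback once per supported ndim: 1 through 3 for the small set, 4 through 7 for the large one. An illustrative expansion with a hypothetical callback:

#include <cstdio>
#include "megdnn/internal/defs.h"

// MEGDNN_FOREACH_TENSOR_NDIM(PRINT_NDIM) expands to
// PRINT_NDIM(1) PRINT_NDIM(2) ... PRINT_NDIM(7)
#define PRINT_NDIM(n) void print_ndim_##n() { std::printf("ndim %d\n", n); }
MEGDNN_FOREACH_TENSOR_NDIM(PRINT_NDIM)
#undef PRINT_NDIM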

+20 -19 dnn/include/megdnn/internal/opr_header_prologue.h

@@ -11,14 +11,14 @@
// intentional no header guard here

#include "megdnn/handle.h"
#include "megdnn/oprs/base.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/opr_result_defs.h"
#include "megdnn/oprs/base.h"

#include "./visibility_prologue.h"

#include <limits>
#include <array>
#include <limits>

#ifndef _megdnn_in
#define _megdnn_in
@@ -29,36 +29,37 @@
#endif

#ifndef _megdnn_tensor_in
#define _megdnn_tensor_in const TensorND &
#define _megdnn_tensor_in const TensorND&
#endif

#ifndef _megdnn_tensor_out
#define _megdnn_tensor_out const TensorND &
#define _megdnn_tensor_out const TensorND&
#endif

#ifndef _megdnn_tensor_inout
#define _megdnn_tensor_inout const TensorND &
#define _megdnn_tensor_inout const TensorND&
#endif

#ifndef _megdnn_workspace
#define _megdnn_workspace const Workspace &
#define _megdnn_workspace const Workspace&
#endif

#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
public: \
_opr_name(Handle *handle): _base_name(handle) {} \
#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
public: \
_opr_name(Handle* handle) : _base_name(handle) {}

#define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs;

#define DEF_OPR_PARAM(_pname) \
public: \
using Param = param::_pname; \
Param& param() { return m_param; } \
const Param& param() const { return m_param; } \
protected: \
Param m_param
#define DEF_OPR_PARAM(_pname) \
public: \
using Param = param::_pname; \
Param& param() { return m_param; } \
const Param& param() const { return m_param; } \
\
protected: \
Param m_param

// vim: syntax=cpp.doxygen
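
A minimal sketch of how these macros compose in an operator header; MyOprForward is a hypothetical operator, and param::Empty stands in for a real param struct.

class MyOprForward : public OperatorBase {
    // injects the Handle* constructor plus the NR_INPUTS/NR_OUTPUTS constants
    DEF_OPR_IMPL(MyOprForward, OperatorBase, 1, 1);
    // injects the Param alias, the param() accessors and the m_param member
    DEF_OPR_PARAM(Empty);

public:
    // _megdnn_tensor_in/_megdnn_tensor_out expand to const TensorND&
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
};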

+ 0
- 1
dnn/include/megdnn/internal/visibility_epilogue.h View File

@@ -20,4 +20,3 @@
#endif

// vim: syntax=cpp.doxygen


+ 16
- 20
dnn/include/megdnn/opr_result_defs.h View File

@@ -16,25 +16,21 @@
namespace megdnn {
namespace opr_result {

struct Checksum {
uint32_t checksum;
union {
int32_t iv;
float fv;
} last_val;

bool operator == (const Checksum &rhs) const {
return checksum == rhs.checksum &&
last_val.iv == rhs.last_val.iv;
}

bool operator != (const Checksum &rhs) const {
return !operator==(rhs);
}
};

} // namespace opr_result
} // namespace megdnn

struct Checksum {
uint32_t checksum;
union {
int32_t iv;
float fv;
} last_val;

bool operator==(const Checksum& rhs) const {
return checksum == rhs.checksum && last_val.iv == rhs.last_val.iv;
}

bool operator!=(const Checksum& rhs) const { return !operator==(rhs); }
};

} // namespace opr_result
} // namespace megdnn

// vim: syntax=cpp.doxygen
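
A small sketch of the equality contract above, with illustrative values: two results compare equal only when both the checksum and the integer view of last_val match.

megdnn::opr_result::Checksum a{0xdeadbeefu, {42}};
megdnn::opr_result::Checksum b = a;
bool same = (a == b);  // true: same checksum, same last_val.iv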

+ 2
- 4
dnn/include/megdnn/oprs.h View File

@@ -12,11 +12,11 @@

#include "megdnn/oprs/cv.h"
#include "megdnn/oprs/general.h"
#include "megdnn/oprs/imgproc.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h"
#include "megdnn/oprs/nn_int.h"
#include "megdnn/oprs/imgproc.h"
#include "megdnn/oprs/utils.h"
#include "megdnn/oprs/linalg.h"

template <typename Opr>
struct OprArityTrait;
@@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1);

#undef INST_ARITY



// vim: syntax=cpp.doxygen

+ 91
- 124
dnn/include/megdnn/oprs/base.h View File

@@ -90,7 +90,7 @@ enum class AlgoDataType : uint32_t {
INT8X8X16 = 1 << 4,
INT16X16X32 = 1 << 5,
INT4X4X16 = 1 << 6,
QINT4x4x32 = 1 << 7,
QINT4x4x32 = 1 << 7,
};

/*!
@@ -195,16 +195,16 @@ public:
Handle::HandleType handle_type() const { return m_handle_type; }

Info::Desc desc() const { return {handle_type(), type(), param(), name()}; }
Info info() const {
return {desc(), attribute()};
}
Info info() const { return {desc(), attribute()}; }

template <typename T>
static void serialize_write_pod(const T& val, std::string& result) {
static_assert(std::is_trivially_copyable<T>::value,
"type should be trivially copyable");
static_assert(!std::is_pointer<T>::value,
"serialize pointer is unsafe in eager execution mode");
static_assert(
std::is_trivially_copyable<T>::value,
"type should be trivially copyable");
static_assert(
!std::is_pointer<T>::value,
"serialize pointer is unsafe in eager execution mode");
result.append(reinterpret_cast<const char*>(&val), sizeof(T));
}

@@ -231,9 +231,8 @@ public:
return ret;
}

static std::string deserialize_read_pod(const std::string& data,
size_t offset = 0,
size_t size = 0) {
static std::string deserialize_read_pod(
const std::string& data, size_t offset = 0, size_t size = 0) {
return std::string(data.data() + offset, size);
}
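
A hedged round-trip sketch for the POD (de)serializers above; PackedParam and its values are illustrative, and the templated deserialize_read_pod<T> overload whose tail is visible earlier in this hunk is assumed.

struct PackedParam {  // trivially copyable, as the static_asserts require
    int32_t window;
    float scale;
};
std::string buf;
Algorithm::serialize_write_pod(PackedParam{3, 0.5f}, buf);  // appends raw bytes
auto back = Algorithm::deserialize_read_pod<PackedParam>(buf, /*offset=*/0);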

@@ -286,8 +285,8 @@ public:
* \param layouts origin layouts of the parent opr
* \param opr parent opr
*/
virtual std::vector<SearchItem> get_subopr_list(const TensorLayoutArray&,
const OperatorBase*) const {
virtual std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray&, const OperatorBase*) const {
return {};
}

@@ -333,9 +332,7 @@ public:

ExecutionPolicy& execution_policy() { return m_execution_policy; }

const ExecutionPolicy& execution_policy() const {
return m_execution_policy;
}
const ExecutionPolicy& execution_policy() const { return m_execution_policy; }

virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0;

@@ -355,8 +352,8 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute;

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1) {
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1)) {
ret.emplace_back(algo->info());
@@ -364,8 +361,8 @@ public:
return ret;
}

std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1) {
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1)) {
ret.emplace_back(algo->info());
@@ -382,12 +379,11 @@ public:
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, workspace_limit_in_bytes,
positive_attr, negative_attr)
return get_algorithm_heuristic(
p0, p1, workspace_limit_in_bytes, positive_attr, negative_attr)
->info();
}

@@ -408,8 +404,7 @@ protected:
*/
virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};
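
A hypothetical call-site sketch for the heuristic query above (opr, src_layout and dst_layout are illustrative; the pattern is the same for the higher-arity variants below): pick the best reproducible algorithm within a 64 MiB workspace budget.

auto info = opr->get_algorithm_info_heuristic(
        src_layout, dst_layout,
        size_t(64) << 20,             // workspace_limit_in_bytes: 64 MiB
        AlgoAttribute::REPRODUCIBLE,  // positive_attr: attributes an algo must have
        AlgoAttribute::DEFAULT);      // negative_attr: attributes it must not have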
@@ -423,9 +418,8 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute;

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2) {
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2)) {
ret.emplace_back(algo->info());
@@ -433,9 +427,8 @@ public:
return ret;
}

std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2) {
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) {
ret.emplace_back(algo->info());
@@ -451,14 +444,13 @@ public:
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes,
positive_attr, negative_attr)
return get_algorithm_heuristic(
p0, p1, p2, workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}

@@ -467,11 +459,9 @@ protected:

//! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0;

/**
* \brief Returns the best algorithm by heuristic.
@@ -480,10 +470,8 @@ protected:
* \p workspace_limit_in_bytes.
*/
virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};
@@ -497,10 +485,9 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute;

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) {
ret.emplace_back(algo->info());
@@ -508,10 +495,9 @@ public:
return ret;
}

std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) {
ret.emplace_back(algo->info());
@@ -527,14 +513,14 @@ public:
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes,
positive_attr, negative_attr)
return get_algorithm_heuristic(
p0, p1, p2, p3, workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}

@@ -543,11 +529,11 @@ protected:

//! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3) = 0;

/**
* \brief Returns the best algorithm by heuristic.
@@ -556,10 +542,9 @@ protected:
* \p workspace_limit_in_bytes.
*/
virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};
@@ -573,11 +558,9 @@ public:
using AlgoAttribute = detail::Algorithm::Attribute;

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3,
const TensorLayout& p4) {
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) {
ret.emplace_back(algo->info());
@@ -585,11 +568,9 @@ public:
return ret;
}

std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3,
const TensorLayout& p4) {
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) {
ret.emplace_back(algo->info());
@@ -605,16 +586,14 @@ public:
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4,
workspace_limit_in_bytes, positive_attr,
negative_attr)
return get_algorithm_heuristic(
p0, p1, p2, p3, p4, workspace_limit_in_bytes, positive_attr,
negative_attr)
->info();
}

@@ -622,14 +601,12 @@ protected:
~MultiAlgoOpr() = default;

//! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4) = 0;
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4) = 0;
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4) = 0;

/**
* \brief Returns the best algorithm by heuristic.
@@ -638,11 +615,9 @@ protected:
* \p workspace_limit_in_bytes.
*/
virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4,
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};
@@ -657,9 +632,8 @@ public:

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) {
@@ -669,9 +643,8 @@ public:
}

std::vector<AlgorithmInfo> get_all_algorithms_info_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) {
@@ -687,17 +660,15 @@ public:
* The selected algorithm should not use workspace more than
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7,
workspace_limit_in_bytes, positive_attr,
negative_attr)
return get_algorithm_heuristic(
p0, p1, p2, p3, p4, p5, p6, p7, workspace_limit_in_bytes,
positive_attr, negative_attr)
->info();
}

@@ -705,15 +676,13 @@ protected:
~MultiAlgoOpr() = default;

//! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
virtual std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) = 0;
virtual std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) = 0;

/**
@@ -723,12 +692,10 @@ protected:
* \p workspace_limit_in_bytes.
*/
virtual Algorithm* get_algorithm_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2,
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(),
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0;
};


+ 123
- 99
dnn/include/megdnn/oprs/cv.h View File

@@ -31,15 +31,17 @@ class FlipForward : public FlipBase {
DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using Flip = FlipForward;

@@ -56,15 +58,17 @@ class RotateForward : public RotateBase {
DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using Rotate = RotateForward;

@@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase {
DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using ROICopy = ROICopyForward;

@@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase {
DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using CvtColor = CvtColorForward;

@@ -130,8 +138,9 @@ public:
using BorderMode = Param::BorderMode;

protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst);
std::string param_msg() const;
int get_real_coord(int p, int len);
};
@@ -148,15 +157,17 @@ public:
* \warning src, trans, border_value, dst should be contiguous
* The size of trans is N * 2 * 3
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& trans,
const TensorLayout& dst) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& trans,
const TensorLayout& dst, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using WarpAffine = WarpAffineForward;

@@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase {
DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using GaussianBlur = GaussianBlurForward;

@@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase {
DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;

virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using Resize = ResizeForward;

@@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase {
DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);

public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;

virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& mat) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& mat) = 0;

protected:
void check_exec(const TensorLayout& diff, const TensorLayout& mat,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& diff, const TensorLayout& mat,
size_t workspace_in_bytes);
};

/**
@@ -251,29 +268,32 @@ public:
using BorderMode = Param::BorderMode;

protected:
void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst);
void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy,
TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst);
void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
};

class RemapForward : public RemapBase {
DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;

void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy,
TensorLayout& dst);
void deduce_layout(
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);

virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& map_xy,
const TensorLayout& dst) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& dst, size_t workspace_in_bytes);
};
using Remap = RemapForward;

@@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase {
DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);

public:
virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;

virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad) = 0;

protected:
void check_exec(const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& map_xy, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
};

class RemapBackwardMat : public RemapBase {
DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;

virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& map_xy,
const TensorLayout& diff,
const TensorLayout& grad) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& map_xy,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
};

class SeparableFilterBase : public OperatorBase {
@@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase {
DEF_OPR_PARAM(SeparableFilter);

protected:
void deduce_layout_fwd(const TensorLayout& src,
const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y,
const TensorLayout& dst);
void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst);
};

class SeparableFilterForward : public SeparableFilterBase {
DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter_x,
const TensorLayout& filter_y,
const TensorLayout& dst) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in filter_x,
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, TensorLayout& dst);
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst,
size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& filter_x,
const TensorLayout& filter_y, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using SeparableFilter = SeparableFilterForward;



+ 847
- 822
dnn/include/megdnn/oprs/general.h
File diff suppressed because it is too large
View File


+ 164
- 180
dnn/include/megdnn/oprs/imgproc.h View File

@@ -13,173 +13,162 @@

namespace megdnn {

class WarpPerspectiveBase: public OperatorBase {
class WarpPerspectiveBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase);
DEF_OPR_PARAM(WarpPerspective);
public:
using InterpolationMode = Param::InterpolationMode;
using BorderMode = Param::BorderMode;

protected:
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat,
const TensorLayout &dst) {
check_layout_fwd(src, mat, {}, dst);
}

void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat,
const TensorLayout &mat_idx, const TensorLayout &dst);
std::string param_msg() const;
int get_real_coord(int p, int len);

public:
using InterpolationMode = Param::InterpolationMode;
using BorderMode = Param::BorderMode;

protected:
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
check_layout_fwd(src, mat, {}, dst);
}

void check_layout_fwd(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst);
std::string param_msg() const;
int get_real_coord(int p, int len);
};

class WarpPerspectiveForward: public WarpPerspectiveBase {
class WarpPerspectiveForward : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1);
public:
/**
* \param[in] src (n, channel, in_height, in_width)
* \param[in] mat (n, 3, 3)
* \param[out] dst (n, channel, out_height, out_width)
*
* \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
*
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
*
* src and dst can have different shapes, as long as their n and c agree.
* src, mat and dst should be contiguous.
*/
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
exec(src, mat, {}, dst, workspace);
}

/**
* \p src should have batch size m, and \p mat and \p mat_idx should
* both have batch size n. Each item in \p mat_idx must be in the range
* of [0, m-1].
*
* \param mat_idx the indices of input image that each matrix in \p mat
* should act on. It can also be empty and in such case \p mat
* should have the same batch size as \p src.
*/
virtual void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &dst) {
return get_workspace_in_bytes(src, mat, {}, dst);
}

virtual size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst) = 0;
protected:
void check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst,
size_t workspace_in_bytes);

void check_exec_allow_nhwc_mat_idx(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst,
size_t workspace_in_bytes);

public:
/**
* \param[in] src (n, channel, in_height, in_width)
* \param[in] mat (n, 3, 3)
* \param[out] dst (n, channel, out_height, out_width)
*
* \see
* http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine
*
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2]
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator,
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator)
*
* src and dst can have different shapes, as long as their n and c agree.
* src, mat and dst should be contiguous.
*/
void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
exec(src, mat, {}, dst, workspace);
}

/**
* \p src should have batch size m, and \p mat and \p mat_idx should
* both have batch size n. Each item in \p mat_idx must be in the range
* of [0, m-1].
*
* \param mat_idx the indices of input image that each matrix in \p mat
* should act on. It can also be empty and in such case \p mat
* should have the same batch size as \p src.
*/
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) {
return get_workspace_in_bytes(src, mat, {}, dst);
}

virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst,
size_t workspace_in_bytes);

void check_exec_allow_nhwc_mat_idx(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& dst,
size_t workspace_in_bytes);
};
using WarpPerspective = WarpPerspectiveForward;
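
A scalar sketch of the coordinate mapping documented above, written as a plain reimplementation for illustration rather than the library kernel:

// Maps an output pixel (h, w) back to its source coordinate via the 3x3 mat.
static inline void warp_perspective_src_coord(
        const float mat[3][3], float h, float w, float& src_h, float& src_w) {
    float denom = mat[2][0] * w + mat[2][1] * h + mat[2][2];
    src_h = (mat[1][0] * w + mat[1][1] * h + mat[1][2]) / denom;
    src_w = (mat[0][0] * w + mat[0][1] * h + mat[0][2]) / denom;
}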

class WarpPerspectiveBackwardData: public WarpPerspectiveBase {
class WarpPerspectiveBackwardData : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1);
public:
/**
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
* \param[out] workspace temporary workspace to perform backward
*/
void exec(_megdnn_tensor_in mat,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
exec(mat, {}, diff, grad, workspace);
}

virtual void exec(_megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad) {
return get_workspace_in_bytes(mat, {}, diff, grad);
}

virtual size_t get_workspace_in_bytes(const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad) = 0;
protected:
void check_exec(const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes);

public:
/**
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. src
* \param[out] workspace temporary workspace to perform backward
*/
void exec(
_megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) {
exec(mat, {}, diff, grad, workspace);
}

virtual void exec(
_megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(
const TensorLayout& mat, const TensorLayout& diff,
const TensorLayout& grad) {
return get_workspace_in_bytes(mat, {}, diff, grad);
}

virtual size_t get_workspace_in_bytes(
const TensorLayout& mat, const TensorLayout& mat_idx,
const TensorLayout& diff, const TensorLayout& grad) = 0;

protected:
void check_exec(
const TensorLayout& mat, const TensorLayout& mat_idx,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_in_bytes);
};

class WarpPerspectiveBackwardMat: public WarpPerspectiveBase {
class WarpPerspectiveBackwardMat : public WarpPerspectiveBase {
DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1);
public:
/**
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. mat
* \param[out] workspace temporary workspace to perform backward
*/
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
exec(src, mat, {}, diff, grad, workspace);
}

virtual void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &diff,
const TensorLayout &grad) {
return get_workspace_in_bytes(src, mat, {}, diff, grad);
}

virtual size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad) = 0;
protected:
void check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &diff,
const TensorLayout &grad,
size_t workspace_in_bytes);

public:
/**
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec
* \param[in] diff the backpropagated gradient wrt. dst
* \param[out] grad the backpropagated gradient wrt. mat
* \param[out] workspace temporary workspace to perform backward
*/
void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) {
exec(src, mat, {}, diff, grad, workspace);
}

virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;

size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff,
const TensorLayout& grad) {
return get_workspace_in_bytes(src, mat, {}, diff, grad);
}

virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& diff,
const TensorLayout& grad) = 0;

protected:
void check_exec(
const TensorLayout& src, const TensorLayout& mat,
const TensorLayout& mat_idx, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_in_bytes);
};

class DctChannelSelectForward : public OperatorBase {
@@ -194,37 +183,32 @@ public:
* \param[out] dst DctChannelSelectForward output, default fp32 nchw tensor
* \param[out] workspace temporary workspace to perform forward
*/
virtual void exec(_megdnn_tensor_in src,
_megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;

void deduce_layout(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
const TensorLayout& dst) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_in mask_offset,
_megdnn_tensor_in mask_val, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;

void deduce_layout(
const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, TensorLayout& dst);

virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, const TensorLayout& dst) = 0;

protected:
void check_layout_fwd(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
const TensorLayout& dst);
void deduce_layout_fwd(const TensorLayout& src,
const TensorLayout& mask_offset,
const TensorLayout& mask_val,
TensorLayout& dst);
void check_layout_fwd(
const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, const TensorLayout& dst);

void deduce_layout_fwd(
const TensorLayout& src, const TensorLayout& mask_offset,
const TensorLayout& mask_val, TensorLayout& dst);

std::string param_msg() const;
};

} // namespace megdnn
} // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"



+ 55
- 54
dnn/include/megdnn/oprs/linalg.h View File

@@ -33,22 +33,22 @@ public:
* op(A) = A if transposeA is false, otherwise op(A) = A^t.
* op(B) = B if transposeB is false, otherwise op(B) = B^t.
*/
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType &C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B,
TensorLayout& C);
virtual size_t get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) = 0;
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType& C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
}

protected:
void check_exec(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
};
using BatchedMatrixMul = BatchedMatrixMulForward;

@@ -70,24 +70,24 @@ public:
* op(A) = A if transposeA is false, otherwise op(A) = A^t.
* op(B) = B if transposeB is false, otherwise op(B) = B^t.
*/
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_dtype(DType A, DType B, DType& C);
void deduce_layout(const TensorLayout& A, const TensorLayout& B,
TensorLayout& C);
virtual size_t get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

static size_t pack_size (const Param::Format format);
static size_t pack_size(const Param::Format format);

static Algorithm::OprType get_opr_type() {
return Algorithm::OprType::MATRIX_MUL_FORWARD;
}

protected:
void check_exec(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
};
using MatrixMul = MatrixMulForward;
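
A hedged shape sketch for the transpose flags documented above (sizes and the opr instance are illustrative): with transposeA set, A is stored as (k, m) but participates as op(A) = A^t, so C is still deduced as (m, n).

size_t m = 32, k = 64, n = 16;
TensorLayout A({k, m}, dtype::Float32());  // stored (k, m); op(A) = A^t is (m, k)
TensorLayout B({k, n}, dtype::Float32());  // op(B) = B is (k, n)
TensorLayout C;
opr->param().transposeA = true;  // opr: an existing MatrixMul instance
opr->deduce_layout(A, B, C);     // C becomes (m, n)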

@@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase {
DEF_OPR_PARAM(Empty);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst);
size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);

protected:
/*!
@@ -116,8 +116,7 @@ protected:
*
* Note that \p batch and \p n can be null
*/
static void canonize_params(const TensorLayout& layout, size_t* batch,
size_t* n);
static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);

/*!
* \brief canonize and validate input params for exec() impls
@@ -125,11 +124,12 @@ protected:
* Since get_workspace_in_bytes() would be called, \p batch and \p n can not
* be null
*/
void check_exec(const TensorLayout& src, const TensorLayout& dst,
_megdnn_workspace workspace, size_t* batch, size_t* n);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
_megdnn_workspace workspace, size_t* batch, size_t* n);

virtual size_t get_workspace_in_bytes(size_t batch, size_t n,
size_t dtype_size) = 0;
virtual size_t get_workspace_in_bytes(
size_t batch, size_t n, size_t dtype_size) = 0;
};

//! inner product of two vectors
@@ -147,17 +147,17 @@ public:
* A, B, C must be contiguous. A and B must have the same 1-dimensional
* shape and non-negative strides. C must be scalar.
*/
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B,
TensorLayout& C);
virtual size_t get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) = 0;
virtual void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
virtual size_t get_workspace_in_bytes(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

protected:
void check_exec(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_in_bytes);
};
using Dot = DotForward;
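
A small sketch of the contract documented above (the dot opr instance is illustrative): both inputs share one 1-dimensional shape and the deduced output is scalar.

TensorLayout A({128}, dtype::Float32());
TensorLayout B({128}, dtype::Float32());
TensorLayout C;
dot->deduce_layout(A, B, C);  // C is deduced as a scalar layout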

@@ -193,23 +193,24 @@ public:
* if compute_uv is false (default to true).
*
*/
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u,
_megdnn_tensor_out s, _megdnn_tensor_out vt,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& u,
TensorLayout& s, TensorLayout& vt);
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& u, const TensorLayout& s,
const TensorLayout& vt);
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
_megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& src, TensorLayout& u, TensorLayout& s,
TensorLayout& vt);
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
const TensorLayout& vt);

protected:
static void canonize_params(const TensorLayout& layout, size_t* batch,
size_t* m, size_t* n);
virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n,
size_t dtype_size) = 0;
void check_exec(const TensorLayout& src, const TensorLayout& u,
const TensorLayout& s, const TensorLayout& vt,
size_t workspace_in_bytes);
static void canonize_params(
const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
virtual size_t get_workspace_in_bytes(
size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
void check_exec(
const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
const TensorLayout& vt, size_t workspace_in_bytes);
};

using SVD = SVDForward;


+ 629
- 618
dnn/include/megdnn/oprs/nn.h
File diff suppressed because it is too large
View File


+ 5
- 8
dnn/include/megdnn/oprs/nn_int.h View File

@@ -36,7 +36,7 @@ public:
struct ModeTrait {
uint32_t arity = 0; //!< number of inputs needed
CheckDtypeFunc check_inp[MAX_ARITY];
SetOrCheckDtypeFunc check_out; //!< dtype of output var
SetOrCheckDtypeFunc check_out; //!< dtype of output var
bool need_specify_out_dtype =
false;  //!< the dtype should be set up externally, otherwise it
//!< would be inferred by check_out(dtype, false)
@@ -46,13 +46,10 @@ public:
static const ModeTrait& from_mode(Mode mode);
};

virtual void exec(_megdnn_in const TensorNDArray& src,
_megdnn_tensor_out dst) = 0;
virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0;

//! get trait of current mode
const ModeTrait& mode_trait() const {
return ModeTrait::from_mode(m_param.mode);
}
const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); }

//! deduce output layout
void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst);
@@ -60,8 +57,8 @@ public:
protected:
//! throw exception if incorrect layout; broadcast input shape to
//! output shape
void check_layout_and_broadcast(const TensorLayoutPtrArray& src,
const TensorLayout& dst);
void check_layout_and_broadcast(
const TensorLayoutPtrArray& src, const TensorLayout& dst);
};

} // namespace megdnn


+ 103
- 87
dnn/include/megdnn/oprs/utils.h View File

@@ -15,84 +15,97 @@
namespace megdnn {

//! base class for random number generators
class RNGBase: public OperatorBase {
class RNGBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase);
public:
virtual void exec(_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0;
protected:
virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0;

public:
virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0;
};

//! sample from poisson distribution
class PoissonRNG: public OperatorBase {
class PoissonRNG : public OperatorBase {
DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1);
DEF_OPR_PARAM(PoissonRNG);
public:
virtual void exec(_megdnn_tensor_in lam,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &lam,
const TensorLayout &dst) = 0;
protected:
void check_exec(const TensorLayout &lam, const TensorLayout &dst,
size_t workspace_in_bytes);

public:
virtual void exec(
_megdnn_tensor_in lam, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& lam, const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& lam, const TensorLayout& dst,
size_t workspace_in_bytes);
};

//! sample from beta distribution
class BetaRNG: public OperatorBase {
class BetaRNG : public OperatorBase {
DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1);
DEF_OPR_PARAM(BetaRNG);
public:
virtual void exec(_megdnn_tensor_in alpha,
_megdnn_tensor_in beta,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &alpha,
const TensorLayout &beta, const TensorLayout &dst) = 0;
protected:
void check_exec(const TensorLayout &alpha, const TensorLayout &beta,
const TensorLayout &dst, size_t workspace_in_bytes);

public:
virtual void exec(
_megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& alpha, const TensorLayout& beta,
const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& alpha, const TensorLayout& beta,
const TensorLayout& dst, size_t workspace_in_bytes);
};

//! sample from gamma distribution
class GammaRNG: public OperatorBase {
class GammaRNG : public OperatorBase {
DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1);
DEF_OPR_PARAM(GammaRNG);
public:
virtual void exec(_megdnn_tensor_in shape,
_megdnn_tensor_in scale,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout &shape,
const TensorLayout &scale, const TensorLayout &dst) = 0;
protected:
void check_exec(const TensorLayout &shape, const TensorLayout &scale,
const TensorLayout &dst, size_t workspace_in_bytes);

public:
virtual void exec(
_megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& shape, const TensorLayout& scale,
const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& shape, const TensorLayout& scale,
const TensorLayout& dst, size_t workspace_in_bytes);
};

//! sample from uniform distribution on the interval (0, 1]
class UniformRNG: public RNGBase {
class UniformRNG : public RNGBase {
DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(UniformRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);

protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

//! sample from gaussian distribution
class GaussianRNG: public RNGBase {
class GaussianRNG : public RNGBase {
DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(GaussianRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);

protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class PermutationRNG: public RNGBase {
class PermutationRNG : public RNGBase {
DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1);
DEF_OPR_PARAM(PermutationRNG);
protected:
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);

protected:
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class ShuffleRNGForward : public OperatorBase {
@@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase {
DEF_OPR_PARAM(ShuffleRNG);

public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst,
TensorLayout& indices);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& indices) = 0;
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& src, TensorLayout& dst, TensorLayout& indices);
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices) = 0;

protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices, size_t workspace_in_bytes);
};
using ShuffleRNG = ShuffleRNGForward;

@@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase {
DEF_OPR_PARAM(ShuffleRNG);

public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& indices,
const TensorLayout& grad) = 0;
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad,
_megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad) = 0;

protected:
void check_exec(const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad, size_t workspace_in_bytes);
void check_exec(
const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad, size_t workspace_in_bytes);
};

/*!
* \brief sleep for specific time on the computing device; useful for testing
* async problems
*/
class SleepForward: public OperatorBase {
class SleepForward : public OperatorBase {
DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0);
DEF_OPR_PARAM(Sleep);

public:
virtual void exec() = 0;
public:
virtual void exec() = 0;
};
using Sleep = SleepForward;
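Since Sleep takes no tensors at all, it is the smallest possible illustration of the operator interface being reformatted here. A hypothetical usage sketch, not part of this commit: `handle` is assumed to be an existing megdnn::Handle*, and the name of the duration field on param::Sleep is an assumption to verify against the param definitions.

    // Hypothetical sketch: stall the device queue to surface async bugs.
    auto sleep_opr = handle->create_operator<megdnn::Sleep>();
    sleep_opr->param().time = 0.5f;  // assumed field: duration in seconds
    sleep_opr->exec();               // enqueues the delay on the computing device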

@@ -149,20 +165,19 @@ using Sleep = SleepForward;
*
* data must be a one-dimensional contiguous tensor with dtype byte
*/
class ChecksumForward: public OperatorBase {
class ChecksumForward : public OperatorBase {
DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1);

public:
using Result = opr_result::Checksum;
public:
using Result = opr_result::Checksum;

virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;

virtual Result exec(_megdnn_tensor_in data,
_megdnn_workspace workspace) = 0;
virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0;

protected:
void check_exec(const TensorLayout &layout, size_t workspace_in_bytes);
protected:
void check_exec(const TensorLayout& layout, size_t workspace_in_bytes);
};
using Checksum = ChecksumForward;

@@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase {
DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2);

public:
virtual size_t get_workspace_in_bytes(const TensorLayout& layout1,
const TensorLayout& layout2) = 0;
public:
virtual size_t get_workspace_in_bytes(
const TensorLayout& layout1, const TensorLayout& layout2) = 0;

virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2,
_megdnn_workspace workspace) = 0;
virtual float exec(
_megdnn_tensor_in src1, _megdnn_tensor_in src2,
_megdnn_workspace workspace) = 0;

protected:
void check_exec(const TensorLayout& layout1,
const TensorLayout& layout2, size_t workspace_in_bytes);
protected:
void check_exec(
const TensorLayout& layout1, const TensorLayout& layout2,
size_t workspace_in_bytes);
};


bool check_bias_share_in_channel(const TensorLayout& bias,
const param::ConvBias::Format format);
bool check_bias_share_in_channel(
const TensorLayout& bias, const param::ConvBias::Format format);

} // namespace megdnn
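Every operator in this header follows the same three-step calling convention that the signatures above make visible: derive the workspace size from the layouts, provide at least that much scratch memory, then call exec(). A minimal sketch for GammaRNG, illustrative only: `handle`, `shape`, `scale` and `dst` are an assumed megdnn::Handle* and pre-built TensorNDs whose layouts already agree.

    auto rng = handle->create_operator<megdnn::GammaRNG>();
    size_t ws_bytes =
            rng->get_workspace_in_bytes(shape.layout, scale.layout, dst.layout);
    std::vector<uint8_t> storage(ws_bytes);  // caller owns the workspace memory
    megdnn::Workspace workspace{
            reinterpret_cast<megdnn::dt_byte*>(storage.data()), ws_bytes};
    rng->exec(shape, scale, dst, workspace);  // implementations validate via check_exec()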



dnn/include/megdnn/tensor_format.h (+36 -47)

@@ -18,9 +18,9 @@
namespace megdnn {

enum class TensorFormat::Type {
DEFAULT = 0, //!< see DefaultTensorFormat
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
LOWBITS_ALIGNED_TO_BYTE = 2, //!<
};

class TensorFormat::ImplBase {
@@ -33,8 +33,7 @@ public:

virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;

virtual TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const = 0;
virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0;

virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0;

@@ -79,8 +78,7 @@ public:
*/
bool is_contiguous_spec(const TensorLayout& layout) const override;

TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;

TensorLayout::Span span_spec(const TensorLayout& layout) const override;

@@ -88,8 +86,7 @@ public:
void serialize_append(std::string& result) const override;

static TensorFormat make();
static TensorFormat deserialize(const Handle* handle, const void* buf,
size_t size);
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);
};

namespace detail {
@@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase {
size_t m_align_axis, m_align_size_in_elements_log2;

protected:
Image2DTensorFormatBase(Type type, size_t align_axis,
size_t align_size_in_elements);
Image2DTensorFormatBase(
Type type, size_t align_axis, size_t align_size_in_elements);
virtual ~Image2DTensorFormatBase() = default;

public:
@@ -129,9 +126,7 @@ public:

size_t align_axis() const { return m_align_axis; }

size_t align_size_in_elements_log2() const {
return m_align_size_in_elements_log2;
}
size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; }

std::string to_string() const override;

@@ -145,6 +140,7 @@ public:
size_t image_height(const TensorLayout& layout) const;

void serialize_append(std::string& result) const override;

protected:
struct SerializePack {
uint8_t align_axis;
@@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
* align COUNT, but mdl needs the align size in bytes, which equals
* (image_width align count) * sizeof(data_type) * pixel_size
*/
size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements,
const TensorLayout& layout) const;
size_t image_pitch_alignment_in_bytes(
size_t align_size_in_elements, const TensorLayout& layout) const;

protected:
Image2DPackedTensorFormatBase(Type type, size_t align_axis,
size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DTensorFormatBase(type, align_axis,
align_size_in_elements),
Image2DPackedTensorFormatBase(
Type type, size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements),
m_vendor_type(vendor_type) {}

virtual ~Image2DPackedTensorFormatBase() = default;
@@ -197,13 +192,12 @@ public:

bool is_contiguous_spec(const TensorLayout& layout) const override;

TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
};
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;

/*!
* \brief used for tensors storing lowbit data
*
* \param m_size_nbits size in bits of elements in the tensor
* \param m_align_size_in_bits aligned size in bits
@@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase {
size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements;

protected: //?
LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits,
size_t align_size_in_bits);
LowbitsAlignedTensorFormatBase(
Type type, size_t size_nbits, size_t align_size_in_bits);

virtual ~LowbitsAlignedTensorFormatBase() = default;

public:
size_t align_size_in_bits() const { return m_align_size_in_bits; }
size_t size_nbits() const { return m_size_nbits; }

std::string to_string() const override;
@@ -238,8 +232,8 @@ public:

bool is_contiguous_spec(const TensorLayout& layout) const override;

TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override;
protected:
struct SerializePack {
uint8_t size_nbits;
@@ -254,16 +248,14 @@ protected:
*
* This is used for OpenCL.
*/
class Image2DPack4TensorFormat final
: public detail::Image2DPack4TensorFormatBase {
class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase {
public:
static constexpr Type TYPE = Type::IMAGE2D_PACK4;

//! for internal usage or test purposes
static TensorFormat make_raw(size_t align_axis,
size_t align_size_in_elements,
Handle::HandleVendorType vendor_type =
Handle::HandleVendorType::NOT_SPEC);
static TensorFormat make_raw(
size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC);

static TensorFormat make(size_t align_axis, const Handle* handle);

@@ -273,13 +265,11 @@ public:
* Note that the alignment may be different if deserialized on another
* handle
*/
static TensorFormat deserialize(const Handle* handle, const void* buf,
size_t size);
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);

static bool is_valid_image(const TensorLayout& layout) {
if (layout.format.type() == TYPE) {
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(
layout);
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout);
return true;
}
return false;
@@ -288,8 +278,9 @@ public:
TensorFormat change_axis(size_t axis) const override;

private:
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
Image2DPack4TensorFormat(
size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DPack4TensorFormatBase(
TYPE, align_axis, align_size_in_elements, vendor_type) {}
};
@@ -306,13 +297,12 @@ public:

static TensorFormat make(size_t size_nbits);

static TensorFormat deserialize(const Handle* handle, const void* buf,
size_t size);
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size);

static bool is_valid_layout(const TensorLayout& layout) {
if (layout.format.type() == TYPE) {
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>()
.assert_valid(layout);
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid(
layout);
return true;
}
return false;
@@ -320,8 +310,7 @@ public:

private:
LowbitsAlignedToBytesTensorFormat(size_t size_nbits)
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits,
BYTE_IN_BITS) {}
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {}
};
} // namespace megdnn



dnn/include/megdnn/tensor_iter.h (+3 -5)

@@ -167,13 +167,11 @@ public:

TensorIter(const TensorND& tensor) : m_tensor(tensor) {}

Iter begin() const {
return Iter::make(const_cast<TensorND&>(m_tensor), 0);
}
Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); }

Iter end() const {
return Iter::make(const_cast<TensorND&>(m_tensor),
m_tensor.layout.total_nr_elems());
return Iter::make(
const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems());
}
};
/*!
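The begin()/end() pair reformatted just above is what lets TensorIter drive ordinary iterator loops. A sketch under stated assumptions: `t` is an existing float32 TensorND, the iterator follows the usual ++/*/!= protocol that begin() and end() imply, and the template parameterization below is illustrative rather than confirmed by this diff.

    megdnn::TensorIter<float> iter(t);  // assumed template parameter
    for (auto it = iter.begin(); it != iter.end(); ++it) {
        float v = *it;  // elements visited in layout order
        (void)v;        // placeholder use
    }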


dnn/include/megdnn/thin/function.h (+5 -5)

@@ -11,19 +11,19 @@

#pragma once

#include <type_traits>
#include <cstdlib>
#include <functional>
#include <utility>
#include <memory>
#include <cstdlib>
#include <type_traits>
#include <utility>

#include "megdnn/internal/visibility_prologue.h"

namespace megdnn {
template<typename Signature>
template <typename Signature>
using thin_function = ::std::function<Signature>;

} // namespace megdnn

#include "megdnn/internal/visibility_epilogue.h"



dnn/include/megdnn/thin/small_vector.h (+42 -76)

@@ -58,18 +58,16 @@ protected:
m_end_ptr(first_elm),
m_capacity_ptr(static_cast<char*>(first_elm) + size) {}

void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes,
size_t type_size);
void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size);

public:
size_t size_in_bytes() const {
return size_t(static_cast<char*>(m_end_ptr) -
static_cast<char*>(m_begin_ptr));
return size_t(static_cast<char*>(m_end_ptr) - static_cast<char*>(m_begin_ptr));
}

size_t capacity_in_bytes() const {
return size_t(static_cast<char*>(m_capacity_ptr) -
static_cast<char*>(m_begin_ptr));
return size_t(
static_cast<char*>(m_capacity_ptr) - static_cast<char*>(m_begin_ptr));
}

bool empty() const { return m_begin_ptr == m_end_ptr; }
@@ -85,20 +83,15 @@ private:
U m_first_elm;

protected:
SmallVectorTemplateCommon(size_t size)
: SmallVectorBase(&m_first_elm, size) {}
SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {}

void grow_pod(size_t min_sz_in_bytes, size_t type_size) {
SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size);
}

bool is_small() {
return m_begin_ptr == static_cast<const void*>(&m_first_elm);
}
bool is_small() { return m_begin_ptr == static_cast<const void*>(&m_first_elm); }

void reset_to_small() {
m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm;
}
void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; }

void set_end(T* p) { m_end_ptr = p; }

@@ -128,20 +121,12 @@ protected:
public:
// forwarding iterator creation
iterator begin() { return static_cast<iterator>(m_begin_ptr); }
const_iterator begin() const {
return static_cast<const_iterator>(m_begin_ptr);
}
const_iterator cbegin() const {
return static_cast<const_iterator>(m_begin_ptr);
}
const_iterator begin() const { return static_cast<const_iterator>(m_begin_ptr); }
const_iterator cbegin() const { return static_cast<const_iterator>(m_begin_ptr); }

iterator end() { return static_cast<iterator>(m_end_ptr); }
const_iterator end() const {
return static_cast<const_iterator>(m_end_ptr);
}
const_iterator cend() const {
return static_cast<const_iterator>(m_end_ptr);
}
const_iterator end() const { return static_cast<const_iterator>(m_end_ptr); }
const_iterator cend() const { return static_cast<const_iterator>(m_end_ptr); }

reference at(size_type idx) {
if (idx >= size()) {
@@ -167,13 +152,9 @@ public:

// reverse iterator creation method.
reverse_iterator rbegin() { return reverse_iterator(end()); }
const_reverse_iterator rbegin() const {
return const_reverse_iterator(end());
}
const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); }
reverse_iterator rend() { return reverse_iterator(begin()); }
const_reverse_iterator rend() const {
return const_reverse_iterator(begin());
}
const_reverse_iterator rend() const { return const_reverse_iterator(begin()); }

pointer data() { return pointer(begin()); }
const_pointer data() const { return const_pointer(begin()); }
@@ -207,8 +188,8 @@ protected:

template <typename It1, typename It2>
static void uninitialized_move(It1 first, It1 last, It2 dest) {
std::uninitialized_copy(std::make_move_iterator(first),
std::make_move_iterator(last), dest);
std::uninitialized_copy(
std::make_move_iterator(first), std::make_move_iterator(last), dest);
}

template <typename It1, typename It2>
@@ -293,9 +274,7 @@ protected:
memcpy(dest, first, (last - first) * sizeof(T));
}

void grow(size_t min_sz = 0) {
this->grow_pod(min_sz * sizeof(T), sizeof(T));
}
void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); }

public:
void push_back(const T& _elm) {
@@ -318,8 +297,7 @@ public:
* SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N
*/
template <typename T>
class SmallVectorImpl
: public SmallVectorTemplateBase<T, std::is_pod<T>::value> {
class SmallVectorImpl : public SmallVectorTemplateBase<T, std::is_pod<T>::value> {
using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>;

public:
@@ -329,8 +307,7 @@ public:

protected:
explicit SmallVectorImpl(unsigned n)
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {
}
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {}

public:
SmallVectorImpl(const SmallVectorImpl&) = delete;
@@ -354,8 +331,7 @@ public:
} else if (n > this->size()) {
if (this->capacity() < n)
this->grow(n);
for (auto it = this->end(), end = this->begin() + n; it != end;
++it)
for (auto it = this->end(), end = this->begin() + n; it != end; ++it)
new (&*it) T();
this->set_end(this->begin() + n);
}
@@ -389,10 +365,11 @@ public:
void swap(SmallVectorImpl<T>& rhs);

/// Add the specified range to the end of the SmallVector.
template <typename in_iter,
typename = typename std::enable_if<std::is_convertible<
typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
template <
typename in_iter,
typename = typename std::enable_if<std::is_convertible<
typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
void append(in_iter in_start, in_iter in_end) {
size_type num_inputs = std::distance(in_start, in_end);
// Grow allocated space if needed.
@@ -432,10 +409,11 @@ public:
std::uninitialized_fill(this->begin(), this->end(), elm);
}

template <typename in_iter,
typename = typename std::enable_if<std::is_convertible<
typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
template <
typename in_iter,
typename = typename std::enable_if<std::is_convertible<
typename std::iterator_traits<in_iter>::iterator_category,
std::input_iterator_tag>::value>::type>
void assign(in_iter in_start, in_iter in_end) {
clear();
append(in_start, in_end);
@@ -571,8 +549,7 @@ public:
std::fill_n(it, num_overwritten, elm);

// Insert the non-overwritten middle part.
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten,
elm);
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm);
return it;
}

@@ -646,8 +623,7 @@ public:
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) {
this->grow();
}
new (static_cast<void*>(this->end()))
T(std::forward<ArgTypes>(args)...);
new (static_cast<void*>(this->end())) T(std::forward<ArgTypes>(args)...);
this->set_end(this->end() + 1);
}

@@ -661,13 +637,11 @@ public:
return std::equal(this->begin(), this->end(), rhs.begin());
}

bool operator!=(const SmallVectorImpl<T>& rhs) const {
return !(*this == rhs);
}
bool operator!=(const SmallVectorImpl<T>& rhs) const { return !(*this == rhs); }

bool operator<(const SmallVectorImpl<T>& rhs) const {
return std::lexicographical_compare(this->begin(), this->end(),
rhs.begin(), rhs.end());
return std::lexicographical_compare(
this->begin(), this->end(), rhs.begin(), rhs.end());
}
};

@@ -698,15 +672,13 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) {
// Copy over the extra elms.
if (this->size() > rhs.size()) {
size_t elm_diff = this->size() - rhs.size();
this->uninitialized_move(this->begin() + num_shared, this->end(),
rhs.end());
this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end());
rhs.set_end(rhs.end() + elm_diff);
this->destroy_range(this->begin() + num_shared, this->end());
this->set_end(this->begin() + num_shared);
} else if (rhs.size() > this->size()) {
size_t elm_diff = rhs.size() - this->size();
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(),
this->end());
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end());
this->set_end(this->end() + elm_diff);
this->destroy_range(rhs.begin() + num_shared, rhs.end());
rhs.set_end(rhs.begin() + num_shared);
@@ -714,8 +686,7 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) {
}

template <typename T>
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(
const SmallVectorImpl<T>& rhs) {
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(const SmallVectorImpl<T>& rhs) {
if (this == &rhs)
return *this;
size_t rhs_sz = rhs.size();
@@ -740,8 +711,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(
} else if (cur_sz) {
std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin());
}
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(),
this->begin() + cur_sz);
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz);
this->set_end(this->begin() + rhs_sz);
return *this;
}
@@ -785,8 +755,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) {
std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin());
}

this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(),
this->begin() + cur_sz);
this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz);

this->set_end(this->begin() + rhs_sz);

@@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl<T> {
public:
SmallVector() : SmallVectorImpl<T>(N) {}

explicit SmallVector(size_t size, const T& value = T())
: SmallVectorImpl<T>(N) {
explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl<T>(N) {
this->assign(size, value);
}

@@ -901,15 +869,13 @@ namespace std {

/// Implement std::swap in terms of SmallVector swap.
template <typename T>
inline void swap(megdnn::SmallVectorImpl<T>& lhs,
megdnn::SmallVectorImpl<T>& rhs) {
inline void swap(megdnn::SmallVectorImpl<T>& lhs, megdnn::SmallVectorImpl<T>& rhs) {
lhs.swap(rhs);
}

/// Implement std::swap in terms of SmallVector swap.
template <typename T, unsigned N>
inline void swap(megdnn::SmallVector<T, N>& lhs,
megdnn::SmallVector<T, N>& rhs) {
inline void swap(megdnn::SmallVector<T, N>& lhs, megdnn::SmallVector<T, N>& rhs) {
lhs.swap(rhs);
}
} // end namespace std
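The two std::swap overloads above exist so that generic code calling swap lands on SmallVector's own member swap rather than the generic three-move version. A usage sketch, assuming the initializer-list constructor that LLVM-style SmallVectors conventionally provide:

    megdnn::SmallVector<int, 4> a{1, 2, 3};        // fits in inline storage
    megdnn::SmallVector<int, 4> b{4, 5, 6, 7, 8};  // five elements: spilled to heap
    std::swap(a, b);  // dispatches to the overload above, i.e. a.swap(b)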


dnn/include/megdnn/version.h (+6 -6)

@@ -17,13 +17,13 @@
#include "megdnn/internal/visibility_prologue.h"

namespace megdnn {
struct Version {
int major, minor, patch;
};
struct Version {
int major, minor, patch;
};

//! get megdnn version of the binary
Version get_version();
}
//! get megdnn version of the binary
Version get_version();
} // namespace megdnn

#include "megdnn/internal/visibility_epilogue.h"



dnn/src/aarch64/conv_bias/fp16/algos.cpp (+26 -24)

@@ -22,18 +22,17 @@ using namespace aarch64;
/* ===================== stride-2 algo ===================== */
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16)

bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
bool ConvBiasImpl::AlgoF16DirectStride2::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.filter_meta.format == param::Convolution::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] &&
param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
}
MIDOUT_END();
@@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
return get_kimpls(param);
@@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
return {};
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
@@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
size_t, size_t, size_t, size_t, size_t)>;
using Func = std::function<void(
const __fp16*, const __fp16*, __fp16*, size_t, size_t, size_t, size_t,
size_t)>;
Func conv = nullptr;
if (FH == 2) {
conv = fp16::conv_stride2::do_conv_2x2_stride2;
@@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
bundle.set(kern_param.workspace_ptr);
for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
copy_padding_kern_stride(
bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
do_conv_kern_stride(
bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto copy_padding = [bundle](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
copy_padding_kern_stride(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto do_conv = [bundle, conv](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
ncb_index.ndrange_id);
do_conv_kern_stride(
bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}


dnn/src/aarch64/conv_bias/fp16/algos.h (+5 -5)

@@ -18,13 +18,13 @@ namespace aarch64 {
/* ===================== stride-2 algo ===================== */
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;

public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "ARMV8F16STRD2"; }
bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;

size_t get_workspace(const NCBKernSizeParam& param) const override;



dnn/src/aarch64/conv_bias/fp16/stride2_kern.h (+48 -67)

@@ -20,9 +20,9 @@ namespace aarch64 {
namespace fp16 {
namespace conv_stride2 {

static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter,
__fp16* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_2x2_stride2(
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 3;
size_t mod4_left = width & 3;
@@ -162,10 +162,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter,
"5: \n"
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1)
: "r"(mod4_left), "w"(_k0123)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30",
"v31");
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v28", "v29", "v30", "v31");

r0 += tail_step;
r1 += tail_step;
@@ -175,9 +174,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter,
}
}

static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter,
__fp16* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_3x3_stride2(
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 3;
size_t mod3_left = width % 3;
@@ -352,10 +351,10 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter,
"3: \n"
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2)
: "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29");
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29");

r0 += tail_step;
r1 += tail_step;
@@ -366,9 +365,9 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter,
}
}

static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter,
__fp16* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_5x5_stride2(
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 3;
size_t mod2_left = width & 1;
@@ -384,18 +383,12 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter,
const __fp16* r4 = src_ptr + IW * 4;

register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter);
register MEGDNN_SIMD_TYPE _k4567 asm("v1") =
MEGDNN_SIMD_LOADU(filter + 4);
register MEGDNN_SIMD_TYPE _k891011 asm("v2") =
MEGDNN_SIMD_LOADU(filter + 8);
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") =
MEGDNN_SIMD_LOADU(filter + 12);
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") =
MEGDNN_SIMD_LOADU(filter + 16);
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") =
MEGDNN_SIMD_LOADU(filter + 20);
register MEGDNN_SIMD_TYPE _k24242424 asm("v6") =
MEGDNN_SIMD_SET1(filter[24]);
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4);
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8);
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12);
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16);
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20);
register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = MEGDNN_SIMD_SET1(filter[24]);

for (size_t i = 0; i < OH; i++) {
asm volatile(
@@ -592,15 +585,14 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter,
"bne 2b \n"
"3: \n"

: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2),
"+r"(r3), "+r"(r4)
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3),
"+r"(r4)
: "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415),
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424),
"r"(mod2_left)
: "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31");
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left)
: "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31");

r0 += tail_step;
r1 += tail_step;
@@ -613,9 +605,9 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter,
}
}

static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter,
__fp16* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_7x7_stride2(
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 3;

@@ -632,30 +624,20 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter,
const __fp16* r6 = src_ptr + IW * 6;

register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter);
register MEGDNN_SIMD_TYPE _k4567 asm("v1") =
MEGDNN_SIMD_LOADU(filter + 4);
register MEGDNN_SIMD_TYPE _k891011 asm("v2") =
MEGDNN_SIMD_LOADU(filter + 8);
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") =
MEGDNN_SIMD_LOADU(filter + 12);
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") =
MEGDNN_SIMD_LOADU(filter + 16);
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") =
MEGDNN_SIMD_LOADU(filter + 20);
register MEGDNN_SIMD_TYPE _k24252627 asm("v6") =
MEGDNN_SIMD_LOADU(filter + 24);
register MEGDNN_SIMD_TYPE _k28293031 asm("v7") =
MEGDNN_SIMD_LOADU(filter + 28);
register MEGDNN_SIMD_TYPE _k32333435 asm("v8") =
MEGDNN_SIMD_LOADU(filter + 32);
register MEGDNN_SIMD_TYPE _k36373839 asm("v9") =
MEGDNN_SIMD_LOADU(filter + 36);
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4);
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8);
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12);
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16);
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20);
register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = MEGDNN_SIMD_LOADU(filter + 24);
register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = MEGDNN_SIMD_LOADU(filter + 28);
register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = MEGDNN_SIMD_LOADU(filter + 32);
register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = MEGDNN_SIMD_LOADU(filter + 36);
register MEGDNN_SIMD_TYPE _k40414243 asm("v10") =
MEGDNN_SIMD_LOADU(filter + 40);
register MEGDNN_SIMD_TYPE _k44454647 asm("v11") =
MEGDNN_SIMD_LOADU(filter + 44);
register MEGDNN_SIMD_TYPE _k48484848 asm("v12") =
MEGDNN_SIMD_SET1(filter[48]);
register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = MEGDNN_SIMD_SET1(filter[48]);

for (size_t i = 0; i < OH; i++) {
asm volatile(
@@ -1005,16 +987,15 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter,
"bne 2b \n"
"3: \n"

: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3),
"+r"(r4), "+r"(r5), "+r"(r6)
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4),
"+r"(r5), "+r"(r6)
: "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011),
"w"(_k12131415), "w"(_k16171819), "w"(_k20212223),
"w"(_k24252627), "w"(_k28293031), "w"(_k32333435),
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647),
"w"(_k48484848)
: "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17",
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848)
: "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31");

r0 += tail_step;
r1 += tail_step;


dnn/src/aarch64/conv_bias/fp32/algos.cpp (+30 -31)

@@ -21,18 +21,17 @@ using namespace megdnn;
using namespace aarch64;

MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32)
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] &&
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
}
MIDOUT_END();
@@ -50,8 +49,7 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
return get_kimpls(param);
@@ -60,8 +58,7 @@ ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
return {};
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
@@ -70,8 +67,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
bool large_group = group >= param.nr_threads;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
using Func = std::function<void(
const float*, const float*, float*, size_t, size_t, size_t, size_t,
size_t)>;
Func conv = nullptr;
if (FH == 2) {
conv = fp32::conv_stride2::do_conv_2x2_stride2;
@@ -83,8 +81,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
conv = fp32::conv_stride2::do_conv_7x7_stride2;
}

WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon<
float, float>::get_bundle_stride(param, large_group);
WorkspaceBundle bundle =
arm_common::MultithreadDirectConvCommon<float, float>::get_bundle_stride(
param, large_group);
SmallVector<NCBKern> ret_kerns;

//! Dense conv and small group
@@ -99,34 +98,34 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
bundle.set(kern_param.workspace_ptr);
for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
copy_padding_kern_stride(
bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<
float, float>::do_conv_kern_stride(bundle, kern_param,
ncb_index, conv,
{ncb_index.thread_id,
0, oc});
arm_common::MultithreadDirectConvCommon<float, float>::
do_conv_kern_stride(
bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto copy_padding = [bundle](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
copy_padding_kern_stride(
bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
auto do_conv = [bundle, conv](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
bundle.set(kern_param.workspace_ptr);
arm_common::MultithreadDirectConvCommon<
float, float>::do_conv_kern_stride(bundle, kern_param,
ncb_index, conv,
ncb_index.ndrange_id);
arm_common::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride(
bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}


dnn/src/aarch64/conv_bias/fp32/algos.h (+5 -5)

@@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl;

class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;

public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "ARMV8F32STRD2"; }

bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;

size_t get_workspace(const NCBKernSizeParam& param) const override;



dnn/src/aarch64/conv_bias/fp32/stride2_kern.h (+33 -38)

@@ -16,16 +16,15 @@

namespace megdnn {
namespace aarch64 {
namespace fp32{
namespace fp32 {
namespace conv_stride2 {


//! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp`

// refer to function do_conv_2x2_stride2_asm_unroll4
static void do_conv_2x2_stride2(const float* src, const float* filter,
float* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_2x2_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 2;
size_t mod4_left = width & 3;
@@ -165,10 +164,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter,
"5: \n"
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1)
: "r"(mod4_left), "w"(_k0123)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30",
"v31");
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v28", "v29", "v30", "v31");

r0 += tail_step;
r1 += tail_step;
@@ -179,9 +177,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter,
}

// refer to function do_conv_3x3_stride2_asm_unroll3
static void do_conv_3x3_stride2(const float* src, const float* filter,
float* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_3x3_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 2;
size_t mod3_left = width % 3;
@@ -269,7 +267,7 @@ static void do_conv_3x3_stride2(const float* src, const float* filter,
"ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6

"ld2 {v5.4s, v6.4s}, [%3], #32 \n"
"ld1 {v3.4s}, [%2] \n" // load src 8 12 ...
"ld1 {v3.4s}, [%2] \n" // load src 8 12 ...
"fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i]
"ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8
"fmla v0.4s, v2.4s, v22.4s \n"
@@ -356,10 +354,10 @@ static void do_conv_3x3_stride2(const float* src, const float* filter,
"3: \n"
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2)
: "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29");
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29");

r0 += tail_step;
r1 += tail_step;
@@ -371,9 +369,9 @@ static void do_conv_3x3_stride2(const float* src, const float* filter,
}

// refer to function do_conv_5x5_stride2_asm_unroll2
static void do_conv_5x5_stride2(const float* src, const float* filter,
float* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_5x5_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 2;
size_t mod2_left = width & 1;
@@ -591,15 +589,13 @@ static void do_conv_5x5_stride2(const float* src, const float* filter,
"bne 2b \n"
"3: \n"

: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2),
"+r"(r3), "+r"(r4)
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3),
"+r"(r4)
: "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415),
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424),
"r"(mod2_left)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24");
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24");

r0 += tail_step;
r1 += tail_step;
@@ -613,9 +609,9 @@ static void do_conv_5x5_stride2(const float* src, const float* filter,
}

// refer to function do_conv_7x7_stride2_asm_unroll2
static void do_conv_7x7_stride2(const float* src, const float* filter,
float* dst, size_t IH, size_t IW, size_t OH,
size_t OW, size_t IC) {
static void do_conv_7x7_stride2(
const float* src, const float* filter, float* dst, size_t IH, size_t IW,
size_t OH, size_t OW, size_t IC) {
const size_t tail_step = IW - 2 * OW + IW;
size_t width = OW >> 2;

@@ -993,16 +989,15 @@ static void do_conv_7x7_stride2(const float* src, const float* filter,
"bne 2b \n"
"3: \n"

: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3),
"+r"(r4), "+r"(r5), "+r"(r6)
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4),
"+r"(r5), "+r"(r6)
: "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011),
"w"(_k12131415), "w"(_k16171819), "w"(_k20212223),
"w"(_k24252627), "w"(_k28293031), "w"(_k32333435),
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647),
"w"(_k48484848)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5",
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18");
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848)
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18");

r0 += tail_step;
r1 += tail_step;


dnn/src/aarch64/conv_bias/int8/algos.cpp (+36 -37)

@@ -68,9 +68,9 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
size_t N = OH * OW;

#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
@@ -84,11 +84,12 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
DISPATCH_GEMM_BIAS(s8_4x4, 0)
}
#else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN( \
megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
@@ -104,8 +105,8 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
return {nullptr, {part0, part1, part2}};
}

void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
void ConvBiasImpl::AlgoS8MatrixMul::kimpl(
const NCBKernParam& param, const NCBKernIndex& ncb_index) {
auto is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param);
@@ -157,29 +158,28 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
} else {
if (is_xcorr)
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
img2col_stride<true>(
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
else
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
img2col_stride<false>(
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
}
}
{
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
bundle.get_size(2));
Workspace workspace(
static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2));
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;

#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);

if (cpuinfo_has_arm_neon_dot()) {
@@ -188,19 +188,18 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
DISPATCH_GEMM_BIAS(s8_4x4, 0)
}
#else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
bias); \
} \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN( \
megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \
} \
MIDOUT_END()
DISPATCH_GEMM_BIAS(s8_4x4, 0)
#endif


dnn/src/aarch64/conv_bias/int8/algos.h (+6 -8)

@@ -12,8 +12,8 @@
#pragma once

#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace aarch64 {
@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index);

public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "S8MATMUL"; }

bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override {
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}};
}


dnn/src/aarch64/conv_bias/int8/strategy.cpp (+61 -70)

@@ -29,9 +29,10 @@ struct KernCaller;
#if MGB_ENABLE_DOT
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 12> {
static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
Op op, const dt_int32* bias, dt_int32* workspace) {
static void run(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace) {
megdnn_assert(is_first_k);

constexpr size_t A_INTERLEAVE = 8;
@@ -49,19 +50,19 @@ struct KernCaller<bmode, Op, 8, 12> {
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12,
is_first_k);
matmul_8x12x4::kern_8x12(
packA, cur_packB, K, workspace, 12, is_first_k);

arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8,
12>::postprocess(bias, workspace,
output, LDC, op);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, 12>::
postprocess(bias, workspace, output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4));
matmul_8x12x4::kern_8x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(N - n, 4));

#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \
@@ -83,9 +84,9 @@ struct KernCaller<bmode, Op, 8, 12> {
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12,
is_first_k,
std::min<size_t>(M - m, 4));
matmul_8x12x4::kern_4x12(
packA, cur_packB, K, workspace, 12, is_first_k,
std::min<size_t>(M - m, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \
bias, workspace, output, LDC, op);
@@ -97,14 +98,13 @@ struct KernCaller<bmode, Op, 8, 12> {
}

for (; n < N; n += 4) {
matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_8x12x4::kern_4x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#undef cb

output += 4;
@@ -122,9 +122,10 @@ struct KernCaller<bmode, Op, 8, 12> {

template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 4, 4> {
static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
Op op, const dt_int32* bias, dt_int32* workspace) {
static void run(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace) {
megdnn_assert(is_first_k);

constexpr size_t A_INTERLEAVE = 4;
@@ -140,20 +141,18 @@ struct KernCaller<bmode, Op, 4, 4> {
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4,
4>::postprocess(bias, workspace,
output, LDC, op);
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, is_first_k);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, 4>::postprocess(
bias, workspace, output, LDC, op);

output += B_INTERLEAVE;
cur_packB += K4;
}

for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace,
4, is_first_k, 4,
std::min<size_t>(N - n, 4));
matmul_4x4x16::kern_4x4_remain(
packA, cur_packB, K, workspace, 4, is_first_k, 4,
std::min<size_t>(N - n, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \
bias, workspace, output, LDC, op);
@@ -182,8 +181,7 @@ struct KernCaller<bmode, Op, 4, 4> {
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#undef cb
output += B_INTERLEAVE;
cur_packB += K4;
@@ -200,21 +198,19 @@ struct KernCaller<bmode, Op, 4, 4> {

MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity)

void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
void gemm_s8_4x4_nobias_identity::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} else {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
void gemm_s8_4x4_nobias_identity::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
@@ -229,23 +225,21 @@ size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const {
#if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity)

void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
void gemm_s8_8x12_nobias_identity::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t);
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t);
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} else {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
void gemm_s8_8x12_nobias_identity::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
@@ -259,18 +253,17 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const {

#endif

#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \
size_t K, dt_int8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace); \
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, \
dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \
dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace); \
}

#define DEFINE_OP(_Op) \
@@ -286,18 +279,16 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#endif
#undef DEFINE_OP

#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C);
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op( \
scale_A* scale_B, scale_A* scale_B, scale_C);
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
#if MGB_ENABLE_DOT
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
#endif
#undef DEFINE_OP
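
The KERN macro above stamps out one kern() per (block size, bias mode, nonlinearity) combination via preprocessor token pasting. A minimal standalone sketch of the same pattern, with hypothetical names:

#include <cstdio>

#define DEMO_KERN(_block_m, _block_n, _bias, _nonline)                            \
    void demo_##_block_m##x##_block_n##_##_bias##_##_nonline() {                  \
        std::printf("kern %dx%d %s %s\n", _block_m, _block_n, #_bias, #_nonline); \
    }

DEMO_KERN(4, 4, nobias, identity)     // defines demo_4x4_nobias_identity()
DEMO_KERN(8, 12, bias_channel, relu)  // defines demo_8x12_bias_channel_relu()
#undef DEMO_KERN

int main() {
    demo_4x4_nobias_identity();
    demo_8x12_bias_channel_relu();
}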



+25 -26 dnn/src/aarch64/conv_bias/int8/strategy.h

@@ -20,43 +20,42 @@ namespace matmul {
*
* \name gemm_<type>_<block>_biasmode_nolinemode
*/
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16,
false, true,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(
dt_int8, dt_int8, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_nobias_relu, gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_nobias_hswish, gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_bias_channel_identity, gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_bias_channel_relu, gemm_s8_4x4_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_4x4_bias_channel_hswish, gemm_s8_4x4_nobias_identity);
#if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4,
false, true,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(
dt_int8, dt_int8, dt_int32, 8, 12, 4, false, true,
gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_nobias_relu, gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_nobias_hswish, gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_bias_channel_identity, gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_bias_channel_relu, gemm_s8_8x12_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_s8_8x12_bias_channel_hswish, gemm_s8_8x12_nobias_identity);
#endif

} // namespace matmul
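
These registrations pair each base strategy with relu/hswish/bias variants that share its packing and differ only in the epilogue. The macro bodies are not shown in this diff; a standalone sketch of the intended structure, with hypothetical names:

#include <algorithm>
#include <cstdint>

struct GemmNobiasIdentity {
    // Packing and the main kern loop are shared by all variants.
    virtual int32_t postprocess(int32_t acc) const { return acc; }
    virtual ~GemmNobiasIdentity() = default;
};

struct GemmNobiasRelu : GemmNobiasIdentity {
    int32_t postprocess(int32_t acc) const override {
        return std::max<int32_t>(acc, 0);  // only the nonlinearity changes
    }
};

int main() {
    GemmNobiasRelu relu;
    return relu.postprocess(-5) == 0 ? 0 : 1;
}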


+12 -13 dnn/src/aarch64/conv_bias/opr_impl.cpp

@@ -13,13 +13,13 @@
#include "src/aarch64/conv_bias/int8/algos.h"
#include "src/aarch64/conv_bias/quint8/algos.h"

#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "src/fallback/convolution/opr_impl.h"
#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/aarch64/conv_bias/fp16/algos.h"
#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/fallback/convolution/opr_impl.h"

using namespace megdnn;
using namespace aarch64;
@@ -56,12 +56,10 @@ public:
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const {
return m_direct_algos;
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos()
const {
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() const {
return m_matmul_algos;
}
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }

};

const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
@@ -71,15 +69,16 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {

MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)

SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() {
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().direct_algos().begin(),
algo_pack().direct_algos().end());
algos.insert(
algos.begin(), algo_pack().direct_algos().begin(),
algo_pack().direct_algos().end());
    //! We put matmul algos at the beginning, because a matmul algo gets
    //! privilege when its prefer() returns true. See
algos.insert(algos.begin(), algo_pack().matmul_algos().begin(),
algo_pack().matmul_algos().end());
algos.insert(
algos.begin(), algo_pack().matmul_algos().begin(),
algo_pack().matmul_algos().end());
return std::move(algos);
}
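
get_all_packed_algo() prepends the matmul algos so they are examined before direct algos during selection. A standalone sketch of why front insertion grants that privilege (hypothetical types):

#include <cstdio>
#include <string>
#include <vector>

struct AlgoSketch {
    std::string name;
    bool preferred;  // stands in for prefer() returning true
};

int main() {
    std::vector<AlgoSketch> algos{{"direct", true}};
    algos.insert(algos.begin(), {"matmul", true});  // matmul now examined first
    for (const auto& a : algos)
        if (a.preferred) {
            std::printf("selected %s\n", a.name.c_str());  // prints "matmul"
            break;
        }
}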



+1 -1 dnn/src/aarch64/conv_bias/opr_impl.h

@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/utils.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/common/utils.h"

namespace megdnn {
namespace aarch64 {


+36 -37 dnn/src/aarch64/conv_bias/quint8/algos.cpp

@@ -70,9 +70,9 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
size_t N = OH * OW;

#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
@@ -86,11 +86,12 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0);
}
#else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN( \
megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
@@ -106,8 +107,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
return {nullptr, {part0, part1, part2}};
}
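
get_bundle() returns a three-part workspace ({nullptr, {part0, part1, part2}}) whose total_size_in_bytes() backs get_workspace(). A simplified standalone model; the real WorkspaceBundle also aligns each part, which this hypothetical version omits:

#include <cstddef>
#include <numeric>
#include <vector>

struct BundleSketch {
    std::vector<std::size_t> parts;
    std::size_t total_size_in_bytes() const {
        return std::accumulate(parts.begin(), parts.end(), std::size_t{0});
    }
};

int main() {
    BundleSketch bundle{{1024, 512, 2048}};  // part0..part2, sizes hypothetical
    return bundle.total_size_in_bytes() == 3584 ? 0 : 1;
}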

void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(
const NCBKernParam& param, const NCBKernIndex& ncb_index) {
auto is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param);
@@ -160,29 +161,28 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
} else {
if (is_xcorr)
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
img2col_stride<true>(
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
else
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
img2col_stride<false>(
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW);
}
}
{
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
bundle.get_size(2));
Workspace workspace(
static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2));
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;

#if MGB_ENABLE_DOT
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);

if (cpuinfo_has_arm_neon_dot()) {
@@ -191,19 +191,18 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
}
#else
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
bias); \
} \
#define DISPATCH_GEMM_STRATEGY( \
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN( \
megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \
} \
MIDOUT_END()

DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
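
kimpl() lowers convolution to GEMM via img2col with M = OC, K = IC * FH * FW, N = OH * OW. A standalone stride-1, no-padding sketch of how the column matrix B is built (the real img2col also handles flipping and strides):

#include <cstddef>

// B is (IC*FH*FW) x (OH*OW), row-major: one row per filter tap.
void im2col_sketch(
        const float* src, float* B, std::size_t IC, std::size_t IH, std::size_t IW,
        std::size_t FH, std::size_t FW, std::size_t OH, std::size_t OW) {
    std::size_t col = 0;
    for (std::size_t oh = 0; oh < OH; ++oh)
        for (std::size_t ow = 0; ow < OW; ++ow, ++col) {
            std::size_t row = 0;
            for (std::size_t ic = 0; ic < IC; ++ic)
                for (std::size_t fh = 0; fh < FH; ++fh)
                    for (std::size_t fw = 0; fw < FW; ++fw, ++row)
                        B[row * (OH * OW) + col] =
                                src[(ic * IH + oh + fh) * IW + (ow + fw)];
        }
}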


+6 -8 dnn/src/aarch64/conv_bias/quint8/algos.h

@@ -12,8 +12,8 @@
#pragma once

#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/opr_impl.h"

namespace megdnn {
namespace aarch64 {
@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase {
static void kimpl(const NCBKernParam& param, const NCBKernIndex&);

public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "QU8MATMUL"; }

bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
bool usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override {
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}};
}
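
dispatch_kerns() returns a single kernel with grid {group, 1_z, 1_z}, i.e. kimpl runs once per group. A hypothetical driver loop showing that contract:

#include <cstddef>

// Hypothetical dispatcher: invokes the kernel once per coordinate in the grid.
void run_grid_sketch(void (*kern)(std::size_t g), std::size_t group) {
    for (std::size_t g = 0; g < group; ++g)
        kern(g);  // the remaining two grid dimensions are both 1
}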


+93 -97 dnn/src/aarch64/conv_bias/quint8/strategy.cpp

@@ -14,8 +14,8 @@
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h"
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h"
#include "src/arm_common/conv_bias/matmul_postprocess.h"

using namespace megdnn;
@@ -29,10 +29,10 @@ struct KernCaller;
#if MGB_ENABLE_DOT
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8, true> {
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_uint8* C, size_t LDC,
bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
static void run(
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K,
dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
megdnn_assert(is_first_k);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
@@ -50,20 +50,19 @@ struct KernCaller<bmode, Op, 8, 8, true> {
size_t n = 0;
const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8,
is_first_k, zp_A, zp_B, zAB);
matmul_8x8x4::kern_8x8(
packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B, zAB);

arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
8>::postprocess(bias, workspace,
output, LDC, op);
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>::
postprocess(bias, workspace, output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4),
zp_A, zp_B, zAB);
matmul_8x8x4::kern_8x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(N - n, 4), zp_A, zp_B, zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op);
@@ -84,9 +83,9 @@ struct KernCaller<bmode, Op, 8, 8, true> {
const dt_uint8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8,
is_first_k, std::min<size_t>(M - m, 4),
zp_A, zp_B, zAB);
matmul_8x8x4::kern_4x8(
packA, cur_packB, K, workspace, 8, is_first_k,
std::min<size_t>(M - m, 4), zp_A, zp_B, zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
bias, workspace, output, LDC, op);
@@ -98,15 +97,14 @@ struct KernCaller<bmode, Op, 8, 8, true> {
}

for (; n < N; n += 4) {
matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zp_A, zp_B,
zAB);
matmul_8x8x4::kern_4x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A,
zp_B, zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#undef cb

output += 4;
@@ -124,10 +122,10 @@ struct KernCaller<bmode, Op, 8, 8, true> {

template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8, false> {
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_uint8* C, size_t LDC,
bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
static void run(
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K,
dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
megdnn_assert(is_first_k);

constexpr size_t A_INTERLEAVE = 8;
@@ -144,27 +142,25 @@ struct KernCaller<bmode, Op, 8, 8, false> {
size_t n = 0;
const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8,
is_first_k, zp_A, zp_B);
matmul_8x8x8::kern_8x8(
packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B);

arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
8>::postprocess(bias, workspace,
output, LDC, op);
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>::
postprocess(bias, workspace, output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4),
zp_A, zp_B);
matmul_8x8x8::kern_8x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(N - n, 4), zp_A, zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb


output += 4;
cur_packB += K4;
}
@@ -179,9 +175,9 @@ struct KernCaller<bmode, Op, 8, 8, false> {
const dt_uint8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8,
is_first_k, std::min<size_t>(M - m, 4),
zp_A, zp_B);
matmul_8x8x8::kern_4x8(
packA, cur_packB, K, workspace, 8, is_first_k,
std::min<size_t>(M - m, 4), zp_A, zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
bias, workspace, output, LDC, op);
@@ -193,17 +189,16 @@ struct KernCaller<bmode, Op, 8, 8, false> {
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zp_A, zp_B);
matmul_8x8x8::kern_4x4(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A,
zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#undef cb


output += 4;
cur_packB += K4;
}
@@ -219,27 +214,27 @@ struct KernCaller<bmode, Op, 8, 8, false> {
#if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity)

void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
void gemm_u8_8x8_dot_nobias_identity::pack_A(
uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0,
ymax, k0, kmax);
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(
outptr, inptr, ldin, y0, ymax, k0, kmax);
} else {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin,
y0, ymax, k0, kmax);
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(
outptr, inptr, ldin, y0, ymax, k0, kmax);
}
}

void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
void gemm_u8_8x8_dot_nobias_identity::pack_B(
uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0,
xmax, k0, kmax);
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(
out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax,
k0, kmax);
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(
out, in, ldin, x0, xmax, k0, kmax);
}
}

@@ -249,30 +244,27 @@ size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const {

#endif
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity)
void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr,
const dt_uint8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_u8_8x8_nodot_nobias_identity::pack_A(
dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
ymax, k0, kmax, zA);
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(
outptr, inptr, ldin, y0, ymax, k0, kmax, zA);
} else {
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax, zA);
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA);
}
}

void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
void gemm_u8_8x8_nodot_nobias_identity::pack_B(
dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
k0, kmax, zB);
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(
out, in, ldin, x0, xmax, k0, kmax, zB);
} else {
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax,
zB);
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB);
}
}

@@ -280,22 +272,21 @@ size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const {
return 8 * 8 * sizeof(dt_int32);
}

#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \
_OP) \
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \
kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \
size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace, zp_A, zp_B); \
#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, _OP) \
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline::kern( \
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \
size_t K, dt_uint8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \
dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace, zp_A, \
zp_B); \
}

#define DEFINE_OP(_Op) \
@@ -311,17 +302,22 @@ KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#undef DEFINE_OP

#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C, zp_C);
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op( \
scale_A* scale_B, scale_A* scale_B, scale_C, zp_C);
#if MGB_ENABLE_DOT
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu,
FuseAddReluOp)
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#endif
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity,
AddOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu,
FuseAddReluOp)
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#undef DEFINE_OP

#undef KERN
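
The bias_channel DEFINE_OP above builds its op from scale_A * scale_B, scale_C and zp_C. A standalone sketch of the underlying asymmetric requantization arithmetic (simplified scalar version; the real NEON ops saturate in-register):

#include <algorithm>
#include <cmath>
#include <cstdint>

// acc holds sums of (a - zp_A) * (b - zp_B) products in int32; dequantize with
// scale_A * scale_B, then requantize to uint8 with scale_C and zp_C.
uint8_t requantize_sketch(
        int32_t acc, float scale_A, float scale_B, float scale_C, uint8_t zp_C) {
    float real = acc * scale_A * scale_B;  // dequantized result
    int32_t q = static_cast<int32_t>(std::round(real / scale_C)) + zp_C;
    return static_cast<uint8_t>(std::clamp<int32_t>(q, 0, 255));  // saturate
}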


+26 -28 dnn/src/aarch64/conv_bias/quint8/strategy.h

@@ -16,46 +16,44 @@ namespace aarch64 {
namespace matmul {

#if MGB_ENABLE_DOT
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4,
false, true,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(
dt_uint8, dt_uint8, dt_int32, 8, 8, 4, false, true,
gemm_u8_8x8_dot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_nobias_relu, gemm_u8_8x8_dot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_nobias_hswish, gemm_u8_8x8_dot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_bias_channel_identity, gemm_u8_8x8_dot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu,
gemm_u8_8x8_dot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish,
gemm_u8_8x8_dot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_bias_channel_relu, gemm_u8_8x8_dot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_dot_bias_channel_hswish, gemm_u8_8x8_dot_nobias_identity);

#endif
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8,
false, true,
gemm_u8_8x8_nodot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(
dt_uint8, dt_uint8, dt_int32, 8, 8, 8, false, true,
gemm_u8_8x8_nodot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_nobias_relu, gemm_u8_8x8_nodot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_nobias_hswish, gemm_u8_8x8_nodot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_bias_channel_identity, gemm_u8_8x8_nodot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish,
gemm_u8_8x8_nodot_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_bias_channel_relu, gemm_u8_8x8_nodot_nobias_identity);

MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(
gemm_u8_8x8_nodot_bias_channel_hswish, gemm_u8_8x8_nodot_nobias_identity);

} // namespace matmul
} // namespace aarch64


+4 -4 dnn/src/aarch64/handle.cpp

@@ -11,11 +11,11 @@

#include "src/common/handle_impl.h"

#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/aarch64/handle.h"
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/aarch64/rotate/opr_impl.h"
#include "src/aarch64/relayout/opr_impl.h"
#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/aarch64/rotate/opr_impl.h"
#include "src/aarch64/warp_perspective/opr_impl.h"

namespace megdnn {
@@ -38,7 +38,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective)
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop

} // namespace aarch64
} // namespace megdnn
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+10 -12 dnn/src/aarch64/handle.h

@@ -14,20 +14,18 @@
namespace megdnn {
namespace aarch64 {

class HandleImpl: public arm_common::HandleImpl {
public:
HandleImpl(megcoreComputingHandle_t computing_handle,
HandleType type = HandleType::AARCH64):
arm_common::HandleImpl::HandleImpl(computing_handle, type)
{}
class HandleImpl : public arm_common::HandleImpl {
public:
HandleImpl(
megcoreComputingHandle_t computing_handle,
HandleType type = HandleType::AARCH64)
: arm_common::HandleImpl::HandleImpl(computing_handle, type) {}

template <typename Opr>
std::unique_ptr<Opr> create_operator();
template <typename Opr>
std::unique_ptr<Opr> create_operator();
};

} // namespace aarch64
} // namespace megdnn
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen
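
HandleImpl declares a generic create_operator<Opr>() that handle.cpp then specializes per operator (MEGDNN_SPECIALIZE_CREATE_OPERATOR). A standalone sketch of that pattern, with hypothetical operator types:

#include <memory>

struct OprBase {
    virtual ~OprBase() = default;
};
struct SpecializedOpr : OprBase {};

struct HandleSketch {
    template <typename Opr>
    std::unique_ptr<Opr> create_operator() {
        return std::make_unique<Opr>();  // generic fallback
    }
};

// MEGDNN_SPECIALIZE_CREATE_OPERATOR-style override for one operator.
template <>
std::unique_ptr<OprBase> HandleSketch::create_operator<OprBase>() {
    return std::make_unique<SpecializedOpr>();
}

int main() {
    HandleSketch h;
    auto op = h.create_operator<OprBase>();  // returns the specialized type
}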



+327 -438 dnn/src/aarch64/matrix_mul/algos.cpp (diff suppressed: too large)


+28 -76 dnn/src/aarch64/matrix_mul/algos.h

@@ -21,9 +21,7 @@ namespace aarch64 {

class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F32K8X12X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -35,8 +33,7 @@ public:
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
bool usable(const KernSizeParam&) const override;
@@ -48,9 +45,7 @@ public:

class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F32K4X16X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -61,9 +56,7 @@ public:

class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F32_MK4_4x16"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -73,8 +66,7 @@ public:
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16)
};

class MatrixMulImpl::AlgoF32Gemv final
: public arm_common::MatrixMulImpl::AlgoF32Gemv {
class MatrixMulImpl::AlgoF32Gemv final : public arm_common::MatrixMulImpl::AlgoF32Gemv {
public:
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() {
m_handle_type = Handle::HandleType::AARCH64;
@@ -85,9 +77,7 @@ public:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F16_K8X24X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -98,9 +88,7 @@ public:

class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_F16_MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -115,12 +103,8 @@ public:
#if MGB_ENABLE_DOT
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD";
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
@@ -130,12 +114,8 @@ public:

class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD";
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
@@ -147,8 +127,7 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; }
bool usable(const KernSizeParam&) const override;
@@ -163,9 +142,7 @@ public:

class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
@@ -178,9 +155,7 @@ public:

class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
@@ -192,9 +167,7 @@ public:

class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
@@ -207,9 +180,7 @@ public:

class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
@@ -222,8 +193,7 @@ public:
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
@@ -238,12 +208,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_16X12X4";
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X16_MK4_16X12X4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -257,12 +224,9 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_K8X8X8";
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X16_MK4_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -276,8 +240,7 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; }
bool usable(const KernSizeParam&) const override;
@@ -292,9 +255,7 @@ public:

class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
@@ -306,9 +267,7 @@ public:

class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -321,12 +280,8 @@ public:
#if MGB_ENABLE_DOT
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override {
return "AARCH64_QUINT8_K8X8X4_DOTPROD";
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_QUINT8_K8X8X4_DOTPROD"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
@@ -336,8 +291,7 @@ public:
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
bool usable(const KernSizeParam&) const override;
@@ -352,9 +306,7 @@ public:
#endif
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
}
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
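
Throughout algos.h, attribute() composes flags such as AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE. A standalone sketch of that flag-enum composition (hypothetical enum; the real definitions live in the megdnn headers):

#include <cstdint>

enum class AttrSketch : std::uint32_t {
    REPRODUCIBLE = 1u << 0,
    USABLE_DEPEND_ON_SHAPE = 1u << 1,
};

constexpr AttrSketch operator|(AttrSketch a, AttrSketch b) {
    return static_cast<AttrSketch>(
            static_cast<std::uint32_t>(a) | static_cast<std::uint32_t>(b));
}

// Mirrors `AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE`.
constexpr AttrSketch kShapeDependent =
        AttrSketch::REPRODUCIBLE | AttrSketch::USABLE_DEPEND_ON_SHAPE;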


+393 -458 dnn/src/aarch64/matrix_mul/asm/common.h (diff suppressed: too large)


+534 -555 dnn/src/aarch64/matrix_mul/fp16/strategy.cpp (diff suppressed: too large)


+4 -4 dnn/src/aarch64/matrix_mul/fp16/strategy.h

@@ -16,11 +16,11 @@ namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false,
true, hgemm_8x24);
MEGDNN_REG_GEMM_STRATEGY(
dt_float16, dt_float16, dt_float16, 8, 24, 1, false, true, hgemm_8x24);

MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1,
false, true, gemm_nopack_f16_8x8);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true, gemm_nopack_f16_8x8);

} // namespace matmul
} // namespace aarch64


+21 -20 dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp

@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/fp16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/fp16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

@@ -21,8 +21,9 @@ using namespace aarch64::matmul;

namespace {

void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
void kern_8x1(
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
LDB *= sizeof(dt_float16);
asm volatile(
".arch armv8.2-a+fp16\n"
@@ -86,9 +87,8 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
}

// Overview of register layout:
@@ -115,8 +115,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
// |v23[0-7]| |v27[0-7]|
// +--------+ +--------+
// Accumulator
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
void kern_8x4(
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
    //! LDB is the number of elements in one block of B. We read 24 numbers
    //! first, so subtract 24 * 2 bytes here.
LDB = (LDB - 24) * sizeof(dt_float16);
@@ -263,8 +264,8 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory");
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "cc", "memory");
}
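
The comment above kern_8x4 explains the (LDB - 24) * sizeof(dt_float16) adjustment: the asm's post-increment loads have already consumed 24 fp16 elements of the row, so only the remainder of the row stride is added. The same arithmetic in plain C++ (illustrative only):

#include <cstddef>
#include <cstdint>

// After post-increment loads consume `consumed` elements of a row, the byte
// offset to the start of the next row is (ldb_elems - consumed) * elem_size.
const std::uint8_t* next_row_sketch(
        const std::uint8_t* p, std::ptrdiff_t ldb_elems,
        std::ptrdiff_t consumed = 24, std::size_t elem_size = 2 /* fp16 */) {
    return p + (ldb_elems - consumed) * static_cast<std::ptrdiff_t>(elem_size);
}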

// Overview of register layout:
@@ -295,8 +296,9 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
// | v7[0-7]| |v31[0-7]|
// +--------+ +--------+
// Accumulator
void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
void kern_8x8(
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
    //! Each iteration loads 128 numbers from B, but the position is advanced by
    //! only 112 * 2, so we subtract the difference here.
LDB = (LDB - 32) * sizeof(dt_float16);
@@ -467,20 +469,19 @@ void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31", "cc", "memory");
}

} // anonymous namespace

MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8);

void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA,
const dt_float16* B, size_t LDB, dt_float16* C,
size_t LDC, size_t M, size_t K, size_t N,
const dt_float16*, void*, bool trA,
bool trB) const {
void gemm_nopack_f16_8x8::kern(
const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C,
size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA,
bool trB) const {
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 8;


+13 -11 dnn/src/aarch64/matrix_mul/fp32/common.h

@@ -17,21 +17,23 @@
namespace megdnn {
namespace aarch64 {

MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M,
size_t K, size_t LDA, const float* alpha);
MEGDNN_NOINLINE void sgemm_packA_n(
const float* A, float* Apacked, size_t M, size_t K, size_t LDA,
const float* alpha);

MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M,
size_t K, size_t LDA, const float* alpha);
MEGDNN_NOINLINE void sgemm_packA_t(
const float* A, float* Apacked, size_t M, size_t K, size_t LDA,
const float* alpha);

MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K,
size_t N, size_t LDB);
MEGDNN_NOINLINE void sgemm_packB_n(
const float* B, float* Bpacked, size_t K, size_t N, size_t LDB);

MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K,
size_t N, size_t LDB);
MEGDNN_NOINLINE void sgemm_packB_t(
const float* B, float* Bpacked, size_t K, size_t N, size_t LDB);

MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C,
size_t LDC, size_t M, size_t N, size_t K,
int type, const float* beta);
MEGDNN_NOINLINE void sgemm_kernel12x8(
const float* A, const float* B, float* C, size_t LDC, size_t M, size_t N,
size_t K, int type, const float* beta);

} // namespace aarch64
} // namespace megdnn
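
A hedged usage sketch for the declarations above: pack A and B, then run the 12x8 kernel. The signatures are taken verbatim from this header, but the semantics of alpha, beta and type, and the packed-buffer sizes, are assumptions not confirmed by the diff:

#include <cstddef>
#include <vector>

#include "src/aarch64/matrix_mul/fp32/common.h"  // path as in this diff

void sgemm_sketch(
        const float* A, const float* B, float* C, size_t M, size_t N, size_t K,
        size_t LDA, size_t LDB, size_t LDC) {
    // Packed sizes are illustrative; real code rounds up to block multiples.
    std::vector<float> Apacked(M * K), Bpacked(K * N);
    const float alpha = 1.f, beta = 0.f;  // assumed conventions
    megdnn::aarch64::sgemm_packA_n(A, Apacked.data(), M, K, LDA, &alpha);
    megdnn::aarch64::sgemm_packB_n(B, Bpacked.data(), K, N, LDB);
    megdnn::aarch64::sgemm_kernel12x8(
            Apacked.data(), Bpacked.data(), C, LDC, M, N, K, /*type=*/0, &beta);
}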


+124 -118 dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h

@@ -12,7 +12,6 @@
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"


namespace megdnn {
namespace aarch64 {
namespace matmul_general_4x16 {
@@ -39,8 +38,9 @@ namespace matmul_general_4x16 {
// +--+ - - - - +--------+--------+--------+--------+
//
// Accumulator
void kern_4x16(const float* packA, const float* packB, int K,
float* output, int LDC, bool is_first_k, int m_remain) {
void kern_4x16(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -224,14 +224,14 @@ void kern_4x16(const float* packA, const float* packB, int K,

"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[m_remain] "+r"(m_remain), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9",
"x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", "x10", "cc",
"memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -263,8 +263,9 @@ void kern_4x16(const float* packA, const float* packB, int K,
// +--+--+ - - - - +--------+
//
// Accumulator
void kern_4x4(const float* packA, const float* packB, int K, float* output,
int LDC, bool is_first_k, int m_remain, int n_remain) {
void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -330,99 +331,100 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output,
STORE_LINE("6", "2") \
STORE_LINE("7", "3") \
"105:\n"
// clang-format on
asm volatile(
// load accumulator C
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"b 2f\n"
"1:\n"
"eor v4.16b, v4.16b, v4.16b\n"
"eor v5.16b, v5.16b, v5.16b\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"
"2: \n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3:\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"
"subs %w[K], %w[K], #1\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"
"b 6f\n"
// odd tail
"5:\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"6:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[oddk] "+r"(oddk), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1",
"x2", "x3", "x10", "cc", "memory");
// clang-format on
asm volatile(
// load accumulator C
"add x1, x0, %x[LDC]\n"
"add x2, x1, %x[LDC]\n"
"add x3, x2, %x[LDC]\n"
"cmp %w[is_first_k], #1\n"
"beq 1f\n" LOAD_C
"b 2f\n"
"1:\n"
"eor v4.16b, v4.16b, v4.16b\n"
"eor v5.16b, v5.16b, v5.16b\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"
"2: \n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"cmp %w[K], #0\n"
"beq 4f\n"
"3:\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"
"subs %w[K], %w[K], #1\n"
"bne 3b\n"
"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail
"ld1 {v1.4s}, [%[a_ptr]], 16\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"fmla v4.4s, v3.4s, v1.s[0]\n"
"fmla v5.4s, v3.4s, v1.s[1]\n"
"fmla v6.4s, v3.4s, v1.s[2]\n"
"fmla v7.4s, v3.4s, v1.s[3]\n"
"b 6f\n"
// odd tail
"5:\n"
"fmla v4.4s, v2.4s, v0.s[0]\n"
"fmla v5.4s, v2.4s, v0.s[1]\n"
"fmla v6.4s, v2.4s, v0.s[2]\n"
"fmla v7.4s, v2.4s, v0.s[3]\n"
"6:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", "x2", "x3", "x10",
"cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}

void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
int ymax, int k0, int kmax) {
void sgemm_4x16_pack_A_n(
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
float zerobuff[4];
std::memset(zerobuff, 0, sizeof(float) * 4);
constexpr int PACK_SIZE = 4*4;
constexpr int PACK_SIZE = 4 * 4;

int y = y0;
for (; y + 3 < ymax; y += 4) {
// printf("main loop pack_a_n %p \n",outptr);
// printf("main loop pack_a_n %p \n",outptr);
const float* inptr0 = inptr + y * ldin + k0;
const float* inptr1 = inptr0 + ldin;
const float* inptr2 = inptr1 + ldin;
@@ -459,9 +461,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
switch ((y + 3) - ymax) {
/* Everything falls through in here */
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -478,9 +482,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -493,8 +499,8 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0,
}
}
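
The zerobuff/MEGDNN_FALLTHRU pattern recurring in these pack routines pads the last
partial block: when fewer than four rows remain, the out-of-range row pointers are
redirected to a shared zero buffer, so the 4-row interleave kernels can run
unconditionally and simply pack zeros. A condensed sketch of the idiom (helper name
hypothetical; [[fallthrough]] plays the role of MEGDNN_FALLTHRU):

// Redirect out-of-range row pointers to a zero buffer before interleaving.
static void redirect_tail_rows(
        int y, int ymax, const float* zerobuff, const float*& inptr1,
        const float*& inptr2, const float*& inptr3) {
    if (y + 3 >= ymax) {
        switch (y + 3 - ymax) {
            case 2:
                inptr1 = zerobuff;  // rows 1..3 all lie past ymax
                [[fallthrough]];
            case 1:
                inptr2 = zerobuff;
                [[fallthrough]];
            case 0:
                inptr3 = zerobuff;  // only the last row is padding
                break;
            default:
                break;
        }
    }
}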

void sgemm_4x16_pack_A_t(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0;
int ksize4 = (ksize << 2);
float* outptr_base = out;
@@ -515,8 +521,7 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0,
auto outptr = outptr_base;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize4;
}

@@ -546,8 +551,8 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0,
}
}

void sgemm_4x16_pack_B_n(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0;
int ksize16 = ksize * 16;
int ksize4 = (ksize << 2);
@@ -570,15 +575,13 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin,
auto outptr = outptr_base;
for (; x + 16 <= xmax; x += 16) {
auto outptr_interleave = outptr;
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize16;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize4;
}

@@ -616,8 +619,8 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin,
}
}

void sgemm_4x16_pack_B_t(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) {
float* outptr = out;
const float* inptr = in;
float zerobuff[4];
@@ -642,8 +645,7 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,

int x = (kmax - k0);
for (; x > 3; x -= 4) {
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 64);
outptr_inner += 64;
}
for (; x > 0; x--) {
@@ -676,9 +678,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
switch ((y + 3) - ymax) {
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -696,9 +700,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
switch ((y + 3) - ymax) {
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -711,8 +717,8 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin,
}
}

} // namespace matmul_general_4x16
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h (+54 -60)

@@ -43,8 +43,9 @@ struct matmul_general_8x12 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -306,14 +307,13 @@ struct matmul_general_8x12 {
"6:\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -348,9 +348,9 @@ struct matmul_general_8x12 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -520,13 +520,12 @@ struct matmul_general_8x12 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20",
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23",
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc",
"memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -557,9 +556,9 @@ struct matmul_general_8x12 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -717,13 +716,12 @@ struct matmul_general_8x12 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x2", "x3", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x2", "x3", "x10", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -754,9 +752,9 @@ struct matmul_general_8x12 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -895,20 +893,21 @@ struct matmul_general_8x12 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain),
[m_remain] "+r"(m_remain)
:
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2",
"x3", "x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3",
"x10", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}

static void sgemm_8x12_pack_A_n(
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
float zerobuff[8];
std::memset(zerobuff, 0, sizeof(float) * 8);
constexpr int PACK_SIZE_32 = 4 * 8;
@@ -933,8 +932,9 @@ struct matmul_general_8x12 {
prefetch_2x(inptr7);
int x = (kmax - k0);
for (; x > 3; x -= 4) {
transpose_8x4_1_s(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += PACK_SIZE_32;
}
for (; x > 0; x--) {
@@ -1004,8 +1004,8 @@ struct matmul_general_8x12 {
}
}

static void sgemm_8x12_pack_A_t(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0;
int ksize8 = (ksize << 3);
int ksize4 = (ksize << 2);
@@ -1028,20 +1028,17 @@ struct matmul_general_8x12 {
auto outptr = outptr_base;
for (; x + 8 <= xmax; x += 8) {
auto outptr_interleave = outptr;
interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize8;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize4;
}
if (x < xmax) {
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
}
outptr_base += 4 * 8;
outptr_base4 += 4 * 4;
@@ -1071,8 +1068,8 @@ struct matmul_general_8x12 {
}
}

static void sgemm_8x12_pack_B_n(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
int ksize = kmax - k0;
int ksize12 = ksize * 12;
int ksize4 = (ksize << 2);
@@ -1095,20 +1092,17 @@ struct matmul_general_8x12 {
auto outptr = outptr_base;
for (; x + 12 <= xmax; x += 12) {
auto outptr_interleave = outptr;
interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize12;
}
outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
auto outptr_interleave = outptr;
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
outptr += ksize4;
}
if (x < xmax) {
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
}
outptr_base += 12 * 4;
outptr_base4 += 4 * 4;
@@ -1138,8 +1132,8 @@ struct matmul_general_8x12 {
}
}

static void sgemm_8x12_pack_B_t(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) {
float* outptr = out;
const float* inptr = in;
float zerobuff[12];
@@ -1172,9 +1166,9 @@ struct matmul_general_8x12 {
prefetch_2x(inptr11);
int x = (kmax - k0);
for (; x > 3; x -= 4) {
transpose_12x4_1_s(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr);
outptr += 48;
}
for (; x > 0; x--) {


dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h (+34 -37)

@@ -43,8 +43,9 @@ struct matmul_general_8x12_a53 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -575,15 +576,14 @@ struct matmul_general_8x12_a53 {
"6:\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10",
"x11", "x12", "x13", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
}
@@ -615,9 +615,9 @@ struct matmul_general_8x12_a53 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -856,13 +856,12 @@ struct matmul_general_8x12_a53 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20",
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23",
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9",
"x10", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -893,9 +892,9 @@ struct matmul_general_8x12_a53 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -1133,14 +1132,12 @@ struct matmul_general_8x12_a53 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x2", "x3", "x8", "x9", "x10", "x20", "x21",
"x22", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x2", "x3", "x8", "x9", "x10", "x20", "x21", "x22", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -1171,9 +1168,9 @@ struct matmul_general_8x12_a53 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -1312,12 +1309,12 @@ struct matmul_general_8x12_a53 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain),
[m_remain] "+r"(m_remain)
:
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2",
"x3", "x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3",
"x10", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE


dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h (+34 -37)

@@ -43,8 +43,9 @@ struct matmul_general_8x12_a55 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -525,15 +526,14 @@ struct matmul_general_8x12_a55 {
"6:\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10",
"x11", "x12", "x13", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
}
@@ -565,9 +565,9 @@ struct matmul_general_8x12_a55 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -742,13 +742,12 @@ struct matmul_general_8x12_a55 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20",
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23",
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x10", "cc",
"memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -779,9 +778,9 @@ struct matmul_general_8x12_a55 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -972,14 +971,12 @@ struct matmul_general_8x12_a55 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x2", "x3", "x10", "x20", "x21", "x22", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x2", "x3", "x10", "x20", "x21", "x22", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -1010,9 +1007,9 @@ struct matmul_general_8x12_a55 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
@@ -1151,12 +1148,12 @@ struct matmul_general_8x12_a55 {
"6:\n" STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain),
[m_remain] "+r"(m_remain)
:
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2",
"x3", "x10", "cc", "memory");
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3",
"x10", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE


dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h (+28 -27)

@@ -44,8 +44,9 @@ struct matmul_mk4_8x12 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
@@ -307,10 +308,10 @@ struct matmul_mk4_8x12 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [output1] "+r"(output1)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "cc", "memory");
}

// Overview of register layout:
@@ -340,9 +341,9 @@ struct matmul_mk4_8x12 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
@@ -500,8 +501,8 @@ struct matmul_mk4_8x12 {
[output0] "+r"(output0), [output1] "+r"(output1),
[n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "cc", "memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "cc", "memory");

#undef LOAD_C
#undef STORE_C
@@ -531,8 +532,9 @@ struct matmul_mk4_8x12 {
//
// Accumulator

static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
@@ -669,9 +671,9 @@ struct matmul_mk4_8x12 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"cc", "memory");
}

// Overview of register layout:
@@ -697,9 +699,9 @@ struct matmul_mk4_8x12 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
@@ -818,15 +820,15 @@ struct matmul_mk4_8x12 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory");

#undef LOAD_C
#undef STORE_C
}

static void sgemm_8x12_pack_A(
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be a multiple of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
constexpr int PACK_SIZE_32 = 4 * 8;
@@ -855,8 +857,8 @@ struct matmul_mk4_8x12 {
}
}

static void sgemm_8x12_pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
float tmpbuff[16] = {0.0f};

@@ -886,8 +888,7 @@ struct matmul_mk4_8x12 {
outptr += ksize4;
}
if (x < xmax) {
std::memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE);
auto outptr_interleave = outptr;
const float* tmp_ptr = &tmpbuff[0];
transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave);


dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h (+23 -22)

@@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a53 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
@@ -553,11 +554,11 @@ struct matmul_mk4_8x12_a53 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [output1] "+r"(output1)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10",
"x11", "x12", "x13", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory");
}

// Overview of register layout:
@@ -587,9 +588,9 @@ struct matmul_mk4_8x12_a53 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
@@ -831,8 +832,8 @@ struct matmul_mk4_8x12_a53 {
[output0] "+r"(output0), [output1] "+r"(output1),
[n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "x8", "x9", "x10", "cc", "memory");

#undef LOAD_C
#undef STORE_C
@@ -862,8 +863,9 @@ struct matmul_mk4_8x12_a53 {
//
// Accumulator

static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
@@ -1098,9 +1100,9 @@ struct matmul_mk4_8x12_a53 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x8", "x9", "x10", "x11", "x12", "cc", "memory");
}

// Overview of register layout:
@@ -1126,9 +1128,9 @@ struct matmul_mk4_8x12_a53 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
@@ -1246,8 +1248,7 @@ struct matmul_mk4_8x12_a53 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory");

#undef LOAD_C
#undef STORE_C


dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h (+23 -22)

@@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a55 {
// +--+ --- - +--------+--------+--------+
//
// Accumulator
static void kern_8x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
@@ -519,11 +520,11 @@ struct matmul_mk4_8x12_a55 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [output1] "+r"(output1)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10",
"x11", "x12", "x13", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory");
}

// Overview of register layout:
@@ -553,9 +554,9 @@ struct matmul_mk4_8x12_a55 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_8x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
float* output0 = output;
@@ -749,8 +750,8 @@ struct matmul_mk4_8x12_a55 {
[output0] "+r"(output0), [output1] "+r"(output1),
[n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "x8", "x9", "x10", "cc", "memory");

#undef LOAD_C
#undef STORE_C
@@ -780,8 +781,9 @@ struct matmul_mk4_8x12_a55 {
//
// Accumulator

static void kern_4x12(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
@@ -997,9 +999,9 @@ struct matmul_mk4_8x12_a55 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1",
"x8", "x9", "x10", "x11", "x12", "cc", "memory");
}

// Overview of register layout:
@@ -1025,9 +1027,9 @@ struct matmul_mk4_8x12_a55 {
// +--+ --- - +--------+
//
// Accumulator
static void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int n_remain) {
MEGDNN_MARK_USED_VAR(LDC);
const float* a_ptr = packA;
const float* b_ptr = packB;
@@ -1146,8 +1148,7 @@ struct matmul_mk4_8x12_a55 {
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk),
[output0] "+r"(output0), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory");

#undef LOAD_C
#undef STORE_C


dnn/src/aarch64/matrix_mul/fp32/strategy.cpp (+83 -86)

@@ -10,6 +10,7 @@
* implied.
*/

#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h"
@@ -17,44 +18,40 @@
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h"
#include "src/common/utils.h"


using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16);

void sgemm_4x16::pack_A(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose_A) const {
if (transpose_A) {
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
}
}

void sgemm_4x16::pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose_B) const {
if (transpose_B) {
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void sgemm_4x16::kern(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k, const float*, float*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
@@ -71,9 +68,9 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_4x16::kern_4x16(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K16;
}
@@ -92,32 +89,30 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M,

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12);

void sgemm_8x12::pack_A(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose_A) const {
if (transpose_A) {
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
} else {
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
}
}

void sgemm_8x12::pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose_B) const {
if (transpose_B) {
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

template <typename gemm_class>
static inline void sgemm_8x12_helper(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
@@ -138,8 +133,9 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB,
}

for (; n < N; n += 4) {
gemm_class::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -150,16 +146,17 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
gemm_class::kern_4x12(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
gemm_class::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -167,56 +164,55 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB,
}
}
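
In outline, sgemm_8x12_helper above walks the output in 8-row panels of 12-column
tiles, then falls back to 4-row and 4-column edge kernels. A runnable sketch of just
that dispatch, printing tile positions instead of calling the assembly (tile sizes
taken from the code shown; everything else is illustrative):

#include <cstddef>
#include <cstdio>

static void row_panel_example(const char* rows, size_t m, size_t N) {
    size_t n = 0;
    for (; n + 12 <= N; n += 12)  // full 12-column tiles
        std::printf("kern_%sx12 at (%zu, %zu)\n", rows, m, n);
    for (; n < N; n += 4)         // 4-column tail tiles, edge-masked
        std::printf("kern_%sx4  at (%zu, %zu)\n", rows, m, n);
}

void show_8x12_blocking(size_t M, size_t N) {
    size_t m = 0;
    for (; m + 8 <= M; m += 8)    // full 8-row panels
        row_panel_example("8", m, N);
    for (; m < M; m += 4)         // 4-row tail panels, edge-masked
        row_panel_example("4", m, N);
}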

void sgemm_8x12::kern(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k, const float*, float*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
#if !MGB_ENABLE_CPUINFO
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, is_first_k);
#else
auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__
arch = cpuinfo_uarch_unknown;
#endif
if (arch == cpuinfo_uarch_cortex_a53) {
sgemm_8x12_helper<matmul_general_8x12_a53>(
packA, packB, M, N, K, C, LDC, is_first_k);
} else if (arch == cpuinfo_uarch_cortex_a55) {
sgemm_8x12_helper<matmul_general_8x12_a55>(
packA, packB, M, N, K, C, LDC, is_first_k);
} else {
sgemm_8x12_helper<matmul_general_8x12>(
packA, packB, M, N, K, C, LDC, is_first_k);
}
#endif
}
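
sgemm_8x12::kern above selects an assembly variant per CPU micro-architecture at
runtime; the a53/a55 kernels compute the same tiles but are scheduled for those
in-order pipelines. A condensed model of the selection, using the cpuinfo calls
visible in the code (the returned strings name the kernel structs from this diff):

#include <cpuinfo.h>

const char* pick_sgemm_8x12_variant() {
    if (!cpuinfo_initialize())          // cpuinfo must be initialized first
        return "matmul_general_8x12";
    switch (cpuinfo_get_current_core()->uarch) {
        case cpuinfo_uarch_cortex_a53:  // in-order dual-issue core
            return "matmul_general_8x12_a53";
        case cpuinfo_uarch_cortex_a55:  // its in-order successor
            return "matmul_general_8x12_a55";
        default:                        // generic out-of-order cores
            return "matmul_general_8x12";
    }
}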

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12);

void sgemm_mk4_8x12::pack_A(
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose_A) const {
megdnn_assert(!transpose_A, "mk4 float matmul does not support transpose A");
matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax);
}

void sgemm_mk4_8x12::pack_B(
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose_B) const {
megdnn_assert(!transpose_B, "mk4 float matmul does not support transpose B");
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}

template <typename gemm_name>
static inline void sgemm_mk4_8x12_helper(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;
@@ -237,8 +233,9 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB,
}

for (; n < N; n += 4) {
gemm_name::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
@@ -254,41 +251,41 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB,
cur_packB += K12;
}
for (; n < N; n += 4) {
gemm_name::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
packA += K4;
}
}
void sgemm_mk4_8x12::kern(
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k, const float*, float*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4");
#if !MGB_ENABLE_CPUINFO
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, is_first_k);
#else
auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__
arch = cpuinfo_uarch_unknown;
#endif
if (arch == cpuinfo_uarch_cortex_a53) {
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(
packA, packB, M, N, K, C, LDC, is_first_k);
} else if (arch == cpuinfo_uarch_cortex_a55) {
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(
packA, packB, M, N, K, C, LDC, is_first_k);
} else {
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(
packA, packB, M, N, K, C, LDC, is_first_k);
}
#endif
}


dnn/src/aarch64/matrix_mul/fp32/strategy.h (+5 -8)

@@ -15,17 +15,14 @@
namespace megdnn {
namespace aarch64 {
namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12);

MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, sgemm_4x16);

MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, sgemm_mk4_8x12);

MEGDNN_REG_GEMM_STRATEGY_NOPACK(
float, float, float, 4, 16, 1, false, true, sgemm_nopack_4x16);

} // namespace matmul
} // namespace aarch64


dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp (+19 -23)

@@ -20,8 +20,8 @@ using namespace aarch64::matmul;

namespace {

void kern_4x1(
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) {
LDB *= sizeof(float);
asm volatile(
"subs %w[K], %w[K], #4\n"
@@ -64,8 +64,7 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc",
"memory");
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", "memory");
}

// Overview of register layout:
@@ -89,8 +88,8 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
// +--------+ - - - - -+--------+
// Accumulator

void kern_4x4(
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) {
//! Each iteration consumes 16 floats from B, but the loads only advance the
//! pointer by 12 * 4 bytes, so we subtract 12 here.
LDB = (LDB - 12) * sizeof(float);
@@ -165,8 +164,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v18", "v19", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
"v19", "cc", "memory");
}
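
The LDB adjustment at the top of kern_4x4 (and the -24/-56 variants in kern_4x8 and
kern_4x16 below) pre-biases the row stride by the distance the loads have already
advanced b_ptr. The arithmetic, as a sketch with names of my own:

#include <cstddef>

// Byte step from the current b_ptr position to the start of the next block,
// given that the loads already walked `advanced` floats past the block start
// (advanced is 12, 24 or 56 in the kernels here).
constexpr std::size_t next_block_step_bytes(
        std::size_t ldb_floats, std::size_t advanced) {
    return (ldb_floats - advanced) * sizeof(float);
}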

// Overview of register layout:
@@ -195,8 +194,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
// +--------+ - - - - -+--------+
// Accumulator

void kern_4x8(
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) {
//! Each iteration consumes 32 floats from B, but the loads only advance the
//! pointer by 24 * 4 bytes, so we subtract 24 here.
LDB = (LDB - 24) * sizeof(float);
@@ -304,9 +303,9 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17",
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
"v27", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc",
"memory");
}

// Overview of register layout:
@@ -342,8 +341,7 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K,
// +--------+
// Accumulator

void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, float* output) {
//! Each iteration consumes 64 floats from B, but the loads only advance the
//! pointer by 56 * 4 bytes, so we subtract 56 here.
LDB = (LDB - 56) * sizeof(float);
@@ -565,20 +563,18 @@ void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
}

} // namespace

MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16);

void sgemm_nopack_4x16::kern(
const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC,
size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const {
constexpr static size_t MB = 4;
constexpr static size_t KB = 4;
constexpr static size_t NB = 16;


dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h (+99 -96)

@@ -46,8 +46,9 @@ namespace matmul_12x8x1 {
* Accumulator
*/

static void kern_12x8(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB;

@@ -155,15 +156,13 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K,
"stp q25, q26, [x9]\n"
"stp q27, q28, [x10]\n"
"stp q29, q30, [x11]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output)
:
: "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "x1",
"x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
"cc", "memory");
: "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29", "v30", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
@@ -196,8 +195,9 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K,
* Accumulator
*/

static void kern_8x8(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB;

@@ -276,13 +276,12 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K,
"stp q17, q18, [x5]\n"
"stp q19, q20, [x6]\n"
"stp q21, q22, [x7]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output)
:
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1",
"x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory");
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", "x2", "x3", "x4",
"x5", "x6", "x7", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
@@ -311,9 +310,9 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K,
* Accumulator
*/

static void kern_4x8(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t m_remain) {
const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB;

@@ -388,14 +387,13 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [m_remain] "+r"(m_remain)
:
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"cc", "memory");
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "cc",
"memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
@@ -432,9 +430,9 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K,
* Accumulator
*/

static void kern_12x4(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t n_remain) {
const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB;

@@ -573,18 +571,16 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7), [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9),
[outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), [x0] "+r"(x0),
[n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "cc", "memory");
: "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -618,9 +614,9 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K,
* Accumulator
*/

static void kern_8x4(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t n_remain) {
const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB;

@@ -734,16 +730,14 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain)
:
: "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"cc", "memory");
: "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc",
"memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -773,9 +767,9 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K,
* Accumulator
*/

static void kern_4x4(const int16_t* packA, const int16_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, size_t m_remain,
size_t n_remain) {
static void kern_4x4(
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t m_remain, size_t n_remain) {
const int16_t* a_ptr = packA;
const int16_t* b_ptr = packB;

@@ -874,11 +868,10 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
[m_remain] "+r"(m_remain), [x1] "+r"(x1),
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [m_remain] "+r"(m_remain), [x1] "+r"(x1),
[n_remain] "+r"(n_remain)
:
: "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory");
@@ -889,9 +882,9 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K,
#undef STORE_C
}

static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s16_12x8x1_pack_A_n(
int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int16_t zerobuff[4];
std::memset(zerobuff, 0, sizeof(int16_t) * 4);

@@ -925,15 +918,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,

int K = kmax - k0;
for (; K > 3; K -= 4) {
interleave_12x1_4_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, inptr8, inptr9, inptr10,
inptr11, outptr);
interleave_12x1_4_h(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr);
}

if (K > 0) {
interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, inptr8, inptr9, inptr10, inptr11,
outptr, 1, K);
interleave_12(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr, 1, K);
}
}

@@ -949,13 +942,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,

int K = kmax - k0;
for (; K > 7; K -= 8) {
interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x1_8_h(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}

if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 1, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 1, K);
}
}

@@ -975,9 +970,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -992,9 +989,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -1007,9 +1006,8 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr,
}
}
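// A minimal standalone sketch of the zero-padding idiom used throughout these
// pack routines (illustrative names and assumed shapes, not MegDNN API;
// MEGDNN_FALLTHRU is replaced by the standard [[fallthrough]]): when fewer
// than 4 rows remain, out-of-range row pointers are redirected to a shared
// zero buffer, and each deeper case falls through so every row below it is
// zeroed as well.
#include <cstdint>

static const int16_t zerobuff_sketch[4] = {0};

inline void pad_rows(
        const int16_t*& inptr1, const int16_t*& inptr2, const int16_t*& inptr3,
        int y, int ymax) {
    if (y + 3 >= ymax) {
        switch (y + 3 - ymax) {
            case 2:
                inptr1 = zerobuff_sketch;
                [[fallthrough]];
            case 1:
                inptr2 = zerobuff_sketch;
                [[fallthrough]];
            case 0:
                inptr3 = zerobuff_sketch;
                break;
            default:
                break;  // the real code asserts here
        }
    }
}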

static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in,
int ldin, int x0, int xmax,
int k0, int kmax) {
static void gemm_s16_12x8x1_transpose_pack_A_n(
int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
const int ksize = kmax - k0;
const int ksize4 = ksize * 4;
const int ksize8 = ksize4 * 2;
@@ -1054,8 +1052,8 @@ static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in,
}
}

static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s16_12x8x1_pack_B_n(
int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
const int ksize = kmax - k0;
const int ksize4 = ksize * 4;
const int ksize8 = ksize4 * 2;
@@ -1090,10 +1088,9 @@ static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin,
}
}

static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr,
const int16_t* inptr, int ldin,
int y0, int ymax, int k0,
int kmax) {
static void gemm_s16_12x8x1_transpose_pack_B_n(
int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int16_t zerobuff[4];
std::memset(zerobuff, 0, sizeof(int16_t) * 4);

@@ -1110,13 +1107,15 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr,

int K = kmax - k0;
for (; K > 7; K -= 8) {
interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x1_8_h(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}

if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 1, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 1, K);
}
}

@@ -1136,9 +1135,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -1153,9 +1154,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;


+ 35
- 36
dnn/src/aarch64/matrix_mul/int16/strategy.cpp

@@ -22,39 +22,37 @@ using namespace aarch64::matmul;
///////////////////////// gemm_s16_12x8x1 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1);

void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s16_12x8x1::pack_A(
dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin,
y0, ymax, k0, kmax);
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(
outptr, inptr, ldin, y0, ymax, k0, kmax);
} else {
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax,
k0, kmax);
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(
outptr, inptr, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool transpose) const {
void gemm_s16_12x8x1::pack_B(
dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0,
xmax, k0, kmax);
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(
out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int16 &&
C_dtype.enumv() == DTypeEnum::Int32),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s16_12x8x1::kern(
const dt_int16* packA, const dt_int16* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int16 &&
C_dtype.enumv() == DTypeEnum::Int32),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
@@ -72,15 +70,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
size_t n = 0;
const dt_int16* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_12x8x1::kern_12x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -92,15 +90,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
const dt_int16* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_12x8x1::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -112,16 +110,17 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB,
const dt_int16* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4));
matmul_12x8x1::kern_4x8(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_12x8x1::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
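// A compilable sketch of the dispatch pattern in gemm_s16_12x8x1::kern above
// (stub kernels and simplified bookkeeping; not the real interface): M is
// walked in 12-row tiles (the 8- and 4-row fallbacks are omitted here), and
// for each row tile N is walked in 8-wide tiles with a 4-wide remainder path.
#include <algorithm>
#include <cstddef>

namespace sketch {
constexpr size_t A_INTERLEAVE = 12, B_INTERLEAVE = 8;
inline void kern_12x8() {}                 // stands in for the asm micro-kernel
inline void kern_12x4(size_t /*n_remain*/) {}

inline void dispatch(size_t M, size_t N) {
    for (size_t m = 0; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE)
            kern_12x8();                            // full 12x8 output tile
        for (; n < N; n += 4)
            kern_12x4(std::min<size_t>(N - n, 4));  // column remainder
    }
}
}  // namespace sketch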


+ 4
- 4
dnn/src/aarch64/matrix_mul/int16/strategy.h

@@ -16,11 +16,11 @@ namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true,
gemm_s16_12x8x1);
MEGDNN_REG_GEMM_STRATEGY(
dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, gemm_s16_12x8x1);

MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false,
true, gemm_nopack_s16_8x8);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
dt_int16, dt_int32, dt_int32, 8, 8, 1, false, true, gemm_nopack_s16_8x8);

} // namespace matmul
} // namespace aarch64
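// Note: judging by the kernel names above (kern_12x8, kern_12x4, ...), the
// 12, 8, 1 triple registered here is the micro-tile shape: a 12x8 output
// block with K consumed one element per step.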


+ 23
- 21
dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp

@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

@@ -20,8 +20,9 @@ using namespace aarch64::matmul;

namespace {

void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
void kern_8x1(
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! Each iteration loads 32 numbers from B while the position only advances by
//! 24 * 2 bytes, so we subtract 24 here.
LDB *= sizeof(dt_int16);
@@ -91,9 +92,8 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
}

// Overview of register layout:
@@ -120,8 +120,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
// | v31[0-7]| |v23[0-3]|
// +---------+ +--------+
// Accumulator
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
void kern_8x4(
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! Each iteration loads 32 numbers from B while the position only advances by
//! 24 * 2 bytes, so we subtract 24 here.
LDB = (LDB - 24) * sizeof(dt_int16);
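//! (Reading of the arithmetic: the in-loop post-index loads advance b_ptr by
//! 24 * sizeof(dt_int16) bytes, so adding (LDB - 24) * sizeof(dt_int16)
//! afterwards gives a net advance of exactly LDB int16 elements per row.)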
@@ -349,9 +350,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
"memory");
}

// Overview of register layout:
@@ -382,8 +383,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
// | v7[0-7]| |v30[0-3]|v31[0-3]|
// +--------+ +--------+--------+
// Accumulator
void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
void kern_8x8(
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
dt_int32* output) {
//! Each iteration loads 64 numbers from B while the position only advances by
//! 48 * 2 bytes, so we subtract 48 here.
LDB = (LDB - 48) * sizeof(dt_int16);
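//! (Same adjustment as in kern_8x4: 48 elements are consumed by the
//! post-index loads, so the net row advance is again LDB int16 elements.)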
@@ -693,20 +695,20 @@ void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K,
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"cc", "memory");
}

} // anonymous namespace

MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8);

void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B,
size_t LDB, dt_int32* C, size_t LDC, size_t M,
size_t K, size_t N, const dt_int32*, void*,
bool trA, bool trB) const {
void gemm_nopack_s16_8x8::kern(
const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C,
size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA,
bool trB) const {
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 8;


+ 121
- 103
dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h

@@ -36,9 +36,9 @@ namespace matmul_s4_4x4x16 {
* Accumulator
*/

static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void s4_kern_8x8_remain(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8;
LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packA;
@@ -170,7 +170,7 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
"dup v5.8b,v20.b[5]\n"
"dup v6.8b,v20.b[6]\n"
"dup v7.8b,v20.b[7]\n"
"ld1 {v17.8b}, [%[b_ptr]], 8\n"

"dup v8.8b,v20.b[8]\n"
@@ -318,16 +318,16 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
STORE_C

:
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
[a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain),
[n_remain] "+r"(
n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
:
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0",
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");

#undef LOAD_LINE
#undef LOAD_C
@@ -335,14 +335,14 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}

static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void s4_kern_8x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8;
LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
// clang-format off
// clang-format off

#define LOAD_C_8 \
"ld1 {v24.8h}, [x0], #16\n" \
@@ -363,9 +363,9 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"st1 {v28.8h}, [x4], #16\n" \
"st1 {v29.8h}, [x5], #16\n" \
"st1 {v30.8h}, [x6], #16\n" \
"st1 {v31.8h}, [x7], #16\n" \
"st1 {v31.8h}, [x7], #16\n"

// clang-format on
// clang-format on
register int16_t* outptr asm("x0") = output;
asm volatile(
"add x1, x0, %x[LDC]\n"
@@ -395,8 +395,8 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"PRFM PLDL1KEEP, [%[a_ptr], #512]\n"
"PRFM PLDL1KEEP, [%[b_ptr], #512]\n"
"1:\n"
// "ld1 {v20.16b}, [%[a_ptr]],#16\n"
// "ld1 {v21.16b}, [%[a_ptr]],#16\n"
// "ld1 {v20.16b}, [%[a_ptr]],#16\n"
// "ld1 {v21.16b}, [%[a_ptr]],#16\n"
"dup v0.8b,v20.b[0]\n"
"ld1 {v22.16b}, [%[a_ptr]],#16\n"
"dup v1.8b,v20.b[1]\n"
@@ -409,7 +409,6 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"dup v5.8b,v20.b[5]\n"
"dup v6.8b,v20.b[6]\n"
"dup v7.8b,v20.b[7]\n"

"dup v8.8b,v20.b[8]\n"
"smlal v24.8h, v0.8b, v16.8b\n"
@@ -560,26 +559,26 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K,
STORE_C_8

:
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
[a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain),
[n_remain] "+r"(
n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1)
:
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0",
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");

#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
#undef STORE_C
}
//packa
static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
// packa
static void gemm_s4x4x16_8x8x8_transpose_pack(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[8];
int8_t tmpbuff0[8];
int8_t tmpbuff1[8];
@@ -617,22 +616,23 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
prefetch_2x(inptr5);
prefetch_2x(inptr6);
prefetch_2x(inptr7);
int K = (kmax - k0)/2;
int K = (kmax - k0) / 2;
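//! (assumed: two s4 values are stored per byte, so the byte count along K is
//! half of kmax - k0)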
//! read 4 * 16 in each row
for (; K > 3; K -= 4) {
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_4x8_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}

if (K > 0) {
std::memcpy(tmpbuff0,inptr0,K);
std::memcpy(tmpbuff1,inptr1,K);
std::memcpy(tmpbuff2,inptr2,K);
std::memcpy(tmpbuff3,inptr3,K);
std::memcpy(tmpbuff4,inptr4,K);
std::memcpy(tmpbuff5,inptr5,K);
std::memcpy(tmpbuff6,inptr6,K);
std::memcpy(tmpbuff7,inptr7,K);
std::memcpy(tmpbuff0, inptr0, K);
std::memcpy(tmpbuff1, inptr1, K);
std::memcpy(tmpbuff2, inptr2, K);
std::memcpy(tmpbuff3, inptr3, K);
std::memcpy(tmpbuff4, inptr4, K);
std::memcpy(tmpbuff5, inptr5, K);
std::memcpy(tmpbuff6, inptr6, K);
std::memcpy(tmpbuff7, inptr7, K);
inptr0 = tmpbuff0;
inptr1 = tmpbuff1;
inptr2 = tmpbuff2;
@@ -641,8 +641,9 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
inptr5 = tmpbuff5;
inptr6 = tmpbuff6;
inptr7 = tmpbuff7;
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_4x8_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}
}
for (; y < ymax; y += 8) {
@@ -655,23 +656,29 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
const int8_t* inptr6 = inptr5 + ldin;
const int8_t* inptr7 = inptr6 + ldin;

int K = (kmax - k0)/2;
int K = (kmax - k0) / 2;
//! read 4 * 16 in each row
for (; K > 3; K -= 4) {
if (y + 7 >= ymax) {
switch (y + 7 - ymax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -679,24 +686,31 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
megdnn_assert(0);
}
}
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_4x8_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}
if (K > 0) {
if (y + 7 >= ymax) {
switch (y + 7 - ymax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -705,14 +719,14 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
}
}

std::memcpy(tmpbuff0,inptr0,K);
std::memcpy(tmpbuff1,inptr1,K);
std::memcpy(tmpbuff2,inptr2,K);
std::memcpy(tmpbuff3,inptr3,K);
std::memcpy(tmpbuff4,inptr4,K);
std::memcpy(tmpbuff5,inptr5,K);
std::memcpy(tmpbuff6,inptr6,K);
std::memcpy(tmpbuff7,inptr7,K);
std::memcpy(tmpbuff0, inptr0, K);
std::memcpy(tmpbuff1, inptr1, K);
std::memcpy(tmpbuff2, inptr2, K);
std::memcpy(tmpbuff3, inptr3, K);
std::memcpy(tmpbuff4, inptr4, K);
std::memcpy(tmpbuff5, inptr5, K);
std::memcpy(tmpbuff6, inptr6, K);
std::memcpy(tmpbuff7, inptr7, K);
inptr0 = tmpbuff0;
inptr1 = tmpbuff1;
inptr2 = tmpbuff2;
@@ -721,14 +735,15 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in
inptr5 = tmpbuff5;
inptr6 = tmpbuff6;
inptr7 = tmpbuff7;
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4,
inptr5, inptr6, inptr7, outptr);
transpose_4x8_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}
}
}
//packb
static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
// packb
static void gemm_s4x4x16_8x8x8_interleave_pack(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[8];
int8_t tmpbuff0[8];
int8_t tmpbuff1[8];
@@ -748,7 +763,7 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
std::memset(tmpbuff6, 0, sizeof(int8_t) * 8);
std::memset(tmpbuff7, 0, sizeof(int8_t) * 8);
const int ksize = kmax - k0;
const int ksize8 = round_up(ksize, 8) * 8; //pack to int8 *8 packto s4 *4
const int ksize8 = round_up(ksize, 8) * 8;  // pack to int8: *8; pack to s4: *4
int8_t* outptr = out;
int8_t* outptr_interleave = nullptr;

@@ -776,21 +791,22 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
int8_t* outptr_inner = outptr;
for (; x + 3 < xmax; x += 4) {
outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x4_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr_inner += ksize8;
}

if (x < xmax) {
int remainx = xmax - x;
std::memcpy(tmpbuff0,inptr0,remainx);
std::memcpy(tmpbuff1,inptr1,remainx);
std::memcpy(tmpbuff2,inptr2,remainx);
std::memcpy(tmpbuff3,inptr3,remainx);
std::memcpy(tmpbuff4,inptr4,remainx);
std::memcpy(tmpbuff5,inptr5,remainx);
std::memcpy(tmpbuff6,inptr6,remainx);
std::memcpy(tmpbuff7,inptr7,remainx);
std::memcpy(tmpbuff0, inptr0, remainx);
std::memcpy(tmpbuff1, inptr1, remainx);
std::memcpy(tmpbuff2, inptr2, remainx);
std::memcpy(tmpbuff3, inptr3, remainx);
std::memcpy(tmpbuff4, inptr4, remainx);
std::memcpy(tmpbuff5, inptr5, remainx);
std::memcpy(tmpbuff6, inptr6, remainx);
std::memcpy(tmpbuff7, inptr7, remainx);
inptr0 = tmpbuff0;
inptr1 = tmpbuff1;
inptr2 = tmpbuff2;
@@ -801,8 +817,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
inptr7 = tmpbuff7;

outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x4_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr_inner += ksize8;
}
outptr += 64;
@@ -847,8 +864,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
break;
}
outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x4_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr_inner += ksize8;
}
if (x < xmax) {
@@ -880,14 +898,14 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
}
int remainx = xmax - x;
outptr_interleave = outptr_inner;
std::memcpy(tmpbuff0,inptr0,remainx);
std::memcpy(tmpbuff1,inptr1,remainx);
std::memcpy(tmpbuff2,inptr2,remainx);
std::memcpy(tmpbuff3,inptr3,remainx);
std::memcpy(tmpbuff4,inptr4,remainx);
std::memcpy(tmpbuff5,inptr5,remainx);
std::memcpy(tmpbuff6,inptr6,remainx);
std::memcpy(tmpbuff7,inptr7,remainx);
std::memcpy(tmpbuff0, inptr0, remainx);
std::memcpy(tmpbuff1, inptr1, remainx);
std::memcpy(tmpbuff2, inptr2, remainx);
std::memcpy(tmpbuff3, inptr3, remainx);
std::memcpy(tmpbuff4, inptr4, remainx);
std::memcpy(tmpbuff5, inptr5, remainx);
std::memcpy(tmpbuff6, inptr6, remainx);
std::memcpy(tmpbuff7, inptr7, remainx);
inptr0 = tmpbuff0;
inptr1 = tmpbuff1;
inptr2 = tmpbuff2;
@@ -898,16 +916,16 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in,
inptr7 = tmpbuff7;

outptr_interleave = outptr_inner;
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x4_1_b_with_shift(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr_inner += ksize8;
}
}
}

} // namespace matmul_4x4x16
} // namespace matmul_s4_4x4x16
} // namespace aarch64
} // namespace megdnn


// vim: syntax=cpp.doxygen
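// A standalone rendering (assumed convention, not the library's code) of the
// s4 storage that the "_with_shift" pack helpers above rely on: two signed
// 4-bit values per byte, sign-extended by shifting the nibble up to bit 7 and
// arithmetically shifting back down.
#include <cstdint>

inline int8_t s4_low(int8_t b) {
    return static_cast<int8_t>(static_cast<int8_t>(b << 4) >> 4);  // low nibble, sign-extended
}
inline int8_t s4_high(int8_t b) {
    return static_cast<int8_t>(b >> 4);  // high nibble, sign-extended
}
inline int8_t s4_pack(int8_t lo, int8_t hi) {
    return static_cast<int8_t>((hi << 4) | (lo & 0xF));  // repack two s4 values
}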

+ 33
- 34
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp

@@ -10,9 +10,9 @@
* implied.
*/

#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h"
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
@@ -23,39 +23,38 @@ using namespace aarch64::matmul;

// ===========================gemm_s4x4x16_s4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8);
void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s4x4x16_s4_8x8x8::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0,
kmax);
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(
out, in, ldin, y0, ymax, k0, kmax);
} else {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0,
kmax);
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(
out, in, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
void gemm_s4x4x16_s4_8x8x8::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0,
kmax);
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(
out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0,
kmax);
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(
out, in, ldin, x0, xmax, k0, kmax);
}
}

void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::QuantizedS4 &&
C_dtype.enumv() == DTypeEnum::QuantizedS16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s4x4x16_s4_8x8x8::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::QuantizedS4 &&
C_dtype.enumv() == DTypeEnum::QuantizedS16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
@@ -72,16 +71,17 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE, B_INTERLEAVE);
matmul_s4_4x4x16::s4_kern_8x8(
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
B_INTERLEAVE);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
matmul_s4_4x4x16::s4_kern_8x8_remain(
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE;
cur_packB += K8;
}
@@ -94,10 +94,10 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) {
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
matmul_s4_4x4x16::s4_kern_8x8_remain(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE;
cur_packB += K8;
}
@@ -105,5 +105,4 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
}
}


// vim: syntax=cpp.doxygen

+ 2
- 2
dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h

@@ -17,8 +17,8 @@ namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true,
gemm_s4x4x16_s4_8x8x8);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s4x4x16_s4_8x8x8);

} // namespace matmul
} // namespace aarch64


+ 57
- 39
dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h

@@ -51,8 +51,9 @@ namespace matmul_4x4x16 {
* Accumulator
*/

static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 16;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -472,9 +473,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
);
}

static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
int m_remain, int n_remain) {
static void kern_4x4_remain(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
megdnn_assert(K > 0);
K /= 16;
const int8_t* a_ptr = packA;
@@ -655,16 +656,14 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,

STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[output] "+r"(output), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output),
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -672,8 +671,9 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}

static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0, int kmax) {
static void gemm_s8_4x4_pack_A_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

@@ -716,9 +716,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -734,9 +736,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -749,8 +753,8 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
}
}

static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8_4x4_pack_B_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
@@ -777,19 +781,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (remain >= 0) {
switch (remain) {
case 7:
inptr0 = zerobuff; MEGDNN_FALLTHRU
inptr0 = zerobuff;
MEGDNN_FALLTHRU
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -798,9 +809,9 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
}
}

transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3,
inptr4, inptr5, inptr6, inptr7,
outptr_inner);
transpose_4x16_1_b_helper(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_inner);
outptr_inner += ksize4;
}

@@ -808,19 +819,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (remain >= 0) {
switch (remain) {
case 7:
inptr0 = zerobuff; MEGDNN_FALLTHRU
inptr0 = zerobuff;
MEGDNN_FALLTHRU
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;


+ 163
- 113
dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h

@@ -42,8 +42,9 @@ namespace matmul_8x8x8 {
* Accumulator
*/

static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x8(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 8;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -272,14 +273,13 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"stp q18, q19, [x5]\n"
"stp q20, q21, [x6]\n"
"stp q22, q23, [x7]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[output] "+r"(output)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v26", "v27", "x1",
"x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v26", "v27", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"cc", "memory");
}

/**
@@ -309,9 +309,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
* Accumulator
*/

static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
size_t n_remain) {
static void kern_8x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t n_remain) {
K /= 8;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -520,16 +520,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -559,9 +557,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
* Accumulator
*/

static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k,
size_t m_remain) {
static void kern_4x8(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t m_remain) {
K /= 8;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -724,14 +722,13 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
[outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [m_remain] "+r"(m_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -762,9 +759,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
* Accumulator
*/

static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, size_t m_remain,
size_t n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, size_t m_remain, size_t n_remain) {
K /= 8;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -922,11 +919,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
[x1] "+r"(x1), [m_remain] "+r"(m_remain),
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [x1] "+r"(x1), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc",
@@ -938,8 +934,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}

static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
int y0, int ymax, int k0, int kmax) {
static void gemm_s8_8x8_pack_A_n(
int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

@@ -965,13 +962,15 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,

int K = kmax - k0;
for (; K > 15; K -= 16) {
interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x8_2_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}

if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 8, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 8, K);
}
}

@@ -991,9 +990,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -1009,9 +1010,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -1024,9 +1027,8 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
}
}

static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
int ldin, int x0, int xmax, int k0,
int kmax) {
static void gemm_s8_8x8_transpose_pack_A_n(
int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
@@ -1063,17 +1065,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1081,8 +1089,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
megdnn_assert(0);
}
}
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
transpose_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += ksize8;
}

@@ -1091,17 +1100,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1110,8 +1125,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
}
}

transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, 4);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, 4);
outptr += ksize4;
}

@@ -1119,17 +1135,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1138,8 +1160,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
}
}

transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, xmax - x);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, xmax - x);
}

outptr_base += 8 * 8;
@@ -1147,8 +1170,8 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
}
}

static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8_8x8_pack_B_n(
int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
@@ -1186,17 +1209,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1205,8 +1234,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
}
}
outptr_interleave = outptr;
interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr += ksize8;
}

@@ -1215,17 +1245,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1235,8 +1271,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
}

outptr_interleave = outptr;
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr_interleave, 4, 4);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave, 4, 4);
outptr += ksize4;
}

@@ -1244,17 +1281,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1264,8 +1307,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
}

outptr_interleave = outptr;
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr_interleave, 4, xmax - x);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave, 4, xmax - x);
}

outptr_base += 8 * 8;
@@ -1273,9 +1317,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
}
}

static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8_8x8_transpose_pack_B_n(
int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
constexpr int interleave4 = 32;
@@ -1303,14 +1347,16 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,

int K = kmax - k0;
for (; K > 7; K -= 8) {
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
transpose_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += interleave8;
}

if (K > 0) {
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 8, K);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 8, K);
outptr += interleave8;
}
}
@@ -1331,9 +1377,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -1350,9 +1398,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;


+ 47
- 41
dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h

@@ -50,8 +50,9 @@ namespace matmul_mk4_4x4x16 {
* Accumulator
*/

static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, bool is_first_k) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output,
bool is_first_k) {
K = div_ceil(K, 16);
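// (assumed: the unrolled loop below consumes 16 values along K per pass,
// hence the ceiling division)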
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -366,17 +367,18 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"6:\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[k] "+r"(K), [output] "+r"(output)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"cc", "memory");
}

static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, bool is_first_k, size_t remain_n) {
static void kern_4x4_remain(
const int8_t* packA, const int8_t* packB, int K, int32_t* output,
bool is_first_k, size_t remain_n) {
K = div_ceil(K, 16);
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -718,26 +720,27 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K,

"7:\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k),
[k] "+r"(K), [output] "+r"(output)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [remain_n] "+r"(remain_n),
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"cc", "memory");
}

static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_mk4_s8_4x4_pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
    //! pack from {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)}
int8_t zerobuff[4][64];
std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4);
    megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0,
                  "mk4 matmul with m is not a multiple of 4");
    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0,
                  "mk4 matmul with k is not a multiple of 4");
    megdnn_assert(
            ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0,
            "mk4 matmul with m is not a multiple of 4");
    megdnn_assert(
            kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0,
            "mk4 matmul with k is not a multiple of 4");
size_t roundk = round_up(kmax - k0, 16);
size_t out_offset = roundk * 4;
int y = y0;
@@ -754,8 +757,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
prefetch_2x(inptr3);
int K = kmax - k0;
for (; K > 15; K -= 16) {
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
out_offset);
transpose_interleave_4x4_4_b(
inptr0, inptr1, inptr2, inptr3, output, out_offset);
output += 64;
}
if (K > 0) {
@@ -767,8 +770,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
inptr1 = zerobuff[1];
inptr2 = zerobuff[2];
inptr3 = zerobuff[3];
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output,
out_offset);
transpose_interleave_4x4_4_b(
inptr0, inptr1, inptr2, inptr3, output, out_offset);
output += 64;
}
}
@@ -790,21 +793,21 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
}
}

static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_mk4_s8_4x4_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int32_t zerobuff[4];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
const int ICB = (ksize) / 4;
const int ksize4 = round_up<int>(ICB, 4) * 4;
int32_t* outptr = reinterpret_cast<int32_t*>(out);
    megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0,
                  "mk4 matmul with k is not a multiple of 4");
    megdnn_assert(
            kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0,
            "mk4 matmul with k is not a multiple of 4");

int k = k0 / 4;
for (; k + 3 < ICB; k += 4) {
const int32_t* inptr0 =
reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr1 =
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
const int32_t* inptr2 =
@@ -829,8 +832,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
outptr += 4 * 4;
}
if (k < ICB) {
const int32_t* inptr0 =
reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0);
const int32_t* inptr1 =
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0);
const int32_t* inptr2 =
@@ -844,9 +846,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
if (k + 3 >= ICB) {
switch (k + 3 - ICB) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -861,9 +865,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
if (k + 3 >= ICB) {
switch (k + 3 - ICB) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -882,7 +888,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
}
}

} // namespace matmul_4x4x16
} // namespace matmul_mk4_4x4x16
} // namespace aarch64
} // namespace megdnn
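
For reference, the layout change described by the pack-A comment above, {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)}, written as a plain scalar loop. This is a sketch under the assumption that IC is already a multiple of 16 (the real routine zero-pads the tail); the function name is made up:

#include <cstdint>

static void mk4_pack_a_reference(int8_t* out, const int8_t* in, int OC, int IC) {
    // in : [OC/4][IC/4][4 (ic)][4 (oc)], out: [OC/4][IC/16][4 (oc)][16 (ic)]
    for (int ocb = 0; ocb < OC / 4; ++ocb)
        for (int icb = 0; icb < IC / 16; ++icb)
            for (int oc = 0; oc < 4; ++oc)
                for (int ic = 0; ic < 16; ++ic) {
                    int ic4 = icb * 4 + ic / 4;  // source block of 4 channels
                    int src = ((ocb * (IC / 4) + ic4) * 4 + ic % 4) * 4 + oc;
                    int dst = ((ocb * (IC / 16) + icb) * 4 + oc) * 16 + ic;
                    out[dst] = in[src];
                }
}
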



+77 -82  dnn/src/aarch64/matrix_mul/int8/strategy.cpp

@@ -24,20 +24,19 @@ using namespace aarch64::matmul;
///////////////////////// gemm_s8_4x4 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4);

void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8_4x4::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
} else {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
void gemm_s8_4x4::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
@@ -45,16 +44,16 @@ void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
}
}

void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8_4x4::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
@@ -72,16 +71,15 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE;
cur_packB += K4;
}

for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC,
is_first_k, 4,
std::min<size_t>(N - n, 4));
matmul_4x4x16::kern_4x4_remain(
packA, cur_packB, K, output, LDC, is_first_k, 4,
std::min<size_t>(N - n, 4));
output += B_INTERLEAVE;
cur_packB += K4;
}
@@ -107,33 +105,32 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
///////////////////////// gemm_mk4_s8_4x4 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4);

void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
    megdnn_assert(!transpose,
                  "the gemm_mk4_s8_4x4 strategy does not support transposed A");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0,
kmax);
void gemm_mk4_s8_4x4::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
    megdnn_assert(
            !transpose, "the gemm_mk4_s8_4x4 strategy does not support transposed A");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax);
}

void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
    megdnn_assert(!transpose,
                  "the gemm_mk4_s8_4x4 strategy does not support transposed B");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0,
kmax);
void gemm_mk4_s8_4x4::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
    megdnn_assert(
            !transpose, "the gemm_mk4_s8_4x4 strategy does not support transposed B");
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}

void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_mk4_s8_4x4::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
@@ -151,57 +148,54 @@ void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output,
is_first_k);
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, is_first_k);
output += B_INTERLEAVE * 4;
cur_packB += K4;
}

if (n < N) {
matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output,
is_first_k, N - n);
matmul_mk4_4x4x16::kern_4x4_remain(
packA, cur_packB, K, output, is_first_k, N - n);
}

packA += K4;
}
}


///////////////////////// gemm_s8_8x8 ////////////////////////////////////
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8);

void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8_8x8::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
ymax, k0, kmax);
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(
outptr, inptr, ldin, y0, ymax, k0, kmax);
} else {
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
void gemm_s8_8x8::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
k0, kmax);
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8_8x8::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
@@ -220,15 +214,15 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
matmul_8x8x8::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -240,16 +234,17 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
matmul_8x8x8::kern_4x8(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_8x8x8::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
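
All of the kern hunks in this file share one tiling scheme. Stripped of packing and the asm kernels, gemm_s8_8x8::kern reduces to the skeleton below; the outer 8-row/4-row strip loops are not visible in the hunks and are an assumption matching the kernel names:

#include <cstddef>

static void kern_loop_skeleton(size_t M, size_t N) {
    constexpr size_t A_INTERLEAVE = 8, B_INTERLEAVE = 8;
    size_t m = 0;
    for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            // kern_8x8: full 8x8 tile
        }
        for (; n < N; n += 4) {
            // kern_8x4 with n_remain = std::min<size_t>(N - n, 4)
        }
    }
    for (; m < M; m += 4) {
        size_t n = 0;
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            // kern_4x8 with m_remain = std::min<size_t>(M - m, 4)
        }
        for (; n < N; n += 4) {
            // kern_4x4 with both remainders clamped to 4
        }
    }
}
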


+6 -6  dnn/src/aarch64/matrix_mul/int8/strategy.h

@@ -16,14 +16,14 @@ namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true,
gemm_s8_4x4);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4);

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false,
gemm_mk4_s8_4x4);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, gemm_mk4_s8_4x4);

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true,
gemm_s8_8x8);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, gemm_s8_8x8);

} // namespace matmul
} // namespace aarch64
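
A hedged reading of the registration macro's arguments, inferred only from how they line up with the kernel names in this diff (the macro definition is not shown here, and the two boolean flags cannot be interpreted from these hunks):

// MEGDNN_REG_GEMM_STRATEGY(
//         dt_int8,      // packed A/B element type
//         dt_int32,     // C (output) element type
//         dt_int32,     // accumulator type
//         4, 4, 16,     // m/n/k block sizes, matching matmul_4x4x16
//         false, true,  // flags; meaning not visible in this diff
//         gemm_s8_4x4); // name of the generated strategy
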


+62 -59  dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h

@@ -52,8 +52,9 @@ namespace matmul_8x12x4 {

#if 1
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -410,8 +411,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
}
#else
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -612,18 +614,17 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
"stp q15, q23, [%[outptr7]]\n"
"str q31, [%[outptr7], #32]\n"

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0),
[a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0),
[b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC),
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1),
[a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1),
[b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7)
:
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory");
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "cc", "memory");
}

#endif
@@ -653,8 +654,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
//
// Accumulator
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int m_remain) {
static void kern_4x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int m_remain) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -796,15 +798,15 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,

"4:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k),
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain),
[LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0),
[b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a),
[b2a] "=w"(b2a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
[outptr3] "=r"(outptr3), [x0] "=r"(x0)
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [a0] "=w"(a0),
[a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2),
[b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "=r"(x0)
:
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "memory", "cc");
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "memory", "cc");

#undef LOAD_LINE
#undef LOAD_C
@@ -840,8 +842,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
//
// Accumulator
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int n_remain) {
static void kern_8x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int n_remain) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -1004,12 +1007,11 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0),
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a),
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "=r"(x0)
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4),
[outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7),
[x0] "=r"(x0)
:
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory",
"cc");
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc");

#undef LOAD_LINE
#undef LOAD_C
@@ -1041,9 +1043,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
//
// Accumulator
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 4;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
@@ -1172,10 +1174,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"4:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0),
[a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0),
[b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
[outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1)
:
: "v4", "v5", "v6", "v7", "memory", "cc");
@@ -1186,9 +1187,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}
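
The operand lists being re-wrapped throughout these asm hunks follow GCC extended-asm syntax. A minimal standalone example with the same constraint kinds ("+r" read-write register, "r" input-only, named operands, and a "cc" clobber):

#include <cstdint>

static inline int64_t add_acc(int64_t acc, int64_t b) {
#if defined(__aarch64__)
    asm volatile(
            "add %[acc], %[acc], %[b]\n"
            : [acc] "+r"(acc)  // read-write: both input and output
            : [b] "r"(b)       // input-only register operand
            : "cc");           // condition-flags clobber (conservative here)
    return acc;
#else
    return acc + b;  // portable fallback off aarch64
#endif
}
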

static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8_8x12_pack_A_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

@@ -1215,13 +1216,15 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int K = kmax - k0;
//! read 8 * 4 in each row
for (; K > 15; K -= 16) {
interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x4_4_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}

if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, K);
}
}
for (; y < ymax; y += 4) {
@@ -1274,8 +1277,8 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
}
}

static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8_8x12_pack_A_t(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
@@ -1361,8 +1364,8 @@ static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin,
}
}

static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8_8x12_pack_B_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
@@ -1448,9 +1451,9 @@ static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
}
}

static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8_8x12_pack_B_t(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

@@ -1485,15 +1488,15 @@ static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr,
int K = kmax - k0;
//! read 12 * 4 in each row
for (; K > 15; K -= 16) {
interleave_12x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, inptr8, inptr9, inptr10,
inptr11, outptr);
interleave_12x4_4_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr);
}

if (K > 0) {
interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, inptr8, inptr9, inptr10, inptr11,
outptr, 4, K);
interleave_12(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
inptr8, inptr9, inptr10, inptr11, outptr, 4, K);
}
}
for (; y < ymax; y += 4) {


+30 -32  dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h

@@ -40,8 +40,9 @@ namespace matmul_mk4_8x12x4 {
// Accumulator

MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_8x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -397,8 +398,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
// Accumulator

MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k) {
static void kern_4x12(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -514,13 +516,12 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
"stp q16, q17, [%[outptr0], #128]\n"
"stp q18, q19, [%[outptr0], #160]\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k),
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a),
[b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a),
[b1a] "=w"(b1a), [b2a] "=w"(b2a)
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1),
[b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a)
:
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
"v17", "v18", "v19", "memory", "cc");
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
"v19", "memory", "cc");
}

// Overview of register layout:
@@ -544,8 +545,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
// Accumulator

MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int n_remain) {
static void kern_8x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int n_remain) {
K /= 4;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -689,11 +691,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0),
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a),
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1),
[x0] "=r"(x0)
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [x0] "=r"(x0)
:
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory",
"cc");
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc");

#undef LOAD_LINE
#undef LOAD_C
@@ -720,8 +720,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
// Accumulator

MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int32_t* output, int LDC, bool is_first_k, int n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
bool is_first_k, int n_remain) {
K /= 4;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
@@ -834,10 +835,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,

"4:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
[LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k),
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a),
[x0] "=r"(x0)
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a),
[b0] "=w"(b0), [b0a] "=w"(b0a), [x0] "=r"(x0)
:
: "v4", "v5", "v6", "v7", "memory", "cc");

@@ -847,13 +847,11 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}

static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0,
"mk4 matmul with m is not times of 4");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0,
"mk4 matmul with k is not times of 4");
static void gemm_mk4_s8_8x12_pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, "mk4 matmul with m is not times of 4");
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 matmul with k is not times of 4");
int y = y0;
int start_y = y0 / 4;
for (; y + 7 < ymax; y += 8, start_y += 2) {
@@ -869,15 +867,15 @@ static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr,
interleave_2x4_4_b(inptr0, inptr1, outptr);
}
}
for (; y + 3 < ymax; y += 4, start_y ++) {
for (; y + 3 < ymax; y += 4, start_y++) {
int K = kmax - k0;
const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2);
std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4);
}
}
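
The single memcpy in the 4-row tail above works because the MK4 layout keeps 4 int8 channels contiguous per k step: one 4-row block is a dense run of K * 4 bytes, and (k0 << 2) is simply the element offset k0 * 4. As an isolated sketch with illustrative names:

#include <cstdint>
#include <cstring>

static void copy_one_mk4_row_block(
        int8_t* out, const int8_t* in, int ldin, int start_y, int k0, int kmax) {
    const int K = kmax - k0;                              // k steps to copy
    const int8_t* src = in + start_y * ldin + (k0 << 2);  // 4 bytes per k step
    std::memcpy(out, src, sizeof(int8_t) * K * 4);        // whole block at once
}
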

static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_mk4_s8_8x12_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
const int ksize = kmax - k0;
const int ksize12 = ksize * 12;
const int ksize4 = ksize * 4;


+59 -59  dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp

@@ -12,10 +12,10 @@
#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#if MGB_ENABLE_DOT
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace aarch64;
@@ -24,20 +24,19 @@ using namespace aarch64::matmul;
/* ====================== gemm_s8_8x12 ===========================*/
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12);

void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8_8x12::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, kmax);
} else {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax, bool transpose) const {
void gemm_s8_8x12::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else {
@@ -45,16 +44,16 @@ void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
}
}

void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int32* C, size_t LDC,
bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8_8x12::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());

MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
@@ -75,15 +74,15 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_8x12x4::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -95,16 +94,17 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4));
matmul_8x12x4::kern_4x12(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_8x12x4::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -115,32 +115,32 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
/* ====================== gemm_mk4_s8_8x12 ===========================*/
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12);

void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0,
kmax);
void gemm_mk4_s8_8x12::pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
megdnn_assert(
!transpose, "matrix mul mk4 with transposed matrix A is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax);
}

void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported");
void gemm_mk4_s8_8x12::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
megdnn_assert(
!transpose, "matrix mul mk4 with transposed matrix B is not supported");
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}

void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int32* C,
size_t LDC, bool is_first_k, const dt_int32*,
dt_int32*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_mk4_s8_8x12::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int32) ||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
C_dtype.enumv() == DTypeEnum::QuantizedS32)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());

MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
@@ -161,15 +161,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
output += (B_INTERLEAVE << 2);
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_mk4_8x12x4::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 16;
cur_packB += K4;
}
@@ -181,15 +181,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB,
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k);
output += (B_INTERLEAVE << 2);
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
matmul_mk4_8x12x4::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 16;
cur_packB += K4;
}


+5 -5  dnn/src/aarch64/matrix_mul/int8_dot/strategy.h

@@ -16,14 +16,14 @@ namespace megdnn {
namespace aarch64 {
namespace matmul {

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_s8_8x12);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12);

MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true,
gemm_mk4_s8_8x12);
MEGDNN_REG_GEMM_STRATEGY(
dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_mk4_s8_8x12);

} // namespace aarch64
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+54 -38  dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h

@@ -34,9 +34,9 @@ namespace matmul_4x4x16 {
*
* Accumulator
*/
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 16;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -230,16 +230,14 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
// Store back into memory
STORE_C

: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain)
:
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2",
"v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31");
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", "v3",
"v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v26", "v27", "v28", "v29", "v30", "v31");

#undef LOAD_LINE
#undef LOAD_C
@@ -247,9 +245,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}

static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8x8x16_4x4_pack_A_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

@@ -292,9 +290,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -309,9 +309,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -324,8 +326,8 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
}
}

static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8x8x16_4x4_pack_B_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
@@ -362,19 +364,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (remain >= 0) {
switch (remain) {
case 7:
inptr0 = zerobuff; MEGDNN_FALLTHRU
inptr0 = zerobuff;
MEGDNN_FALLTHRU
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -383,9 +392,9 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
}
}

transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3,
inptr4, inptr5, inptr6, inptr7,
outptr_inner);
transpose_4x16_1_b_helper(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_inner);
outptr_inner += ksize4;
}

@@ -393,19 +402,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (remain >= 0) {
switch (remain) {
case 7:
inptr0 = zerobuff; MEGDNN_FALLTHRU
inptr0 = zerobuff;
MEGDNN_FALLTHRU
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;


+159 -111  dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h

@@ -42,8 +42,9 @@ namespace matmul_8x8x8 {
*
* Accumulator
*/
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k) {
static void kern_8x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k) {
K /= 8;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -217,13 +218,12 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"bne 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3",
"x4", "x5", "x6", "x7", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", "x4", "x5",
"x6", "x7", "cc", "memory");
#undef LOAD_LINE
#undef LOAD_C
#undef STORE_LINE
@@ -258,9 +258,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
* Accumulator
*/

static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k,
size_t n_remain) {
static void kern_8x4(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, size_t n_remain) {
K /= 8;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -471,16 +471,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0),
[n_remain] "+r"(n_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -514,9 +512,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
* Accumulator
*/

static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k,
size_t m_remain) {
static void kern_4x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, size_t m_remain) {
K /= 8;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -646,11 +644,10 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 2b\n"

"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
[m_remain] "+r"(m_remain)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
[x0] "+r"(x0), [m_remain] "+r"(m_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc",
"memory");
@@ -686,9 +683,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
*
* Accumulator
*/
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, size_t m_remain,
size_t n_remain) {
static void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, size_t m_remain, size_t n_remain) {
K /= 8;
const int8_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -853,11 +850,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"3:\n" STORE_C
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr),
[K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC),
[x0] "+r"(x0), [m_remain] "+r"(m_remain),
[n_remain] "+r"(n_remain)
[x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1",
"cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", "cc",
"memory");

#undef LOAD_LINE
#undef LOAD_C
@@ -865,9 +861,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C
}

static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax) {
static void gemm_s8x8x16_8x8_pack_A_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

@@ -893,13 +889,15 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,

int K = kmax - k0;
for (; K > 15; K -= 16) {
interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
interleave_8x8_2_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
}

if (K > 0) {
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 8, K);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 8, K);
}
}
for (; y < ymax; y += 4) {
@@ -918,9 +916,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -936,9 +936,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -951,9 +953,8 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr,
}
}

static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax,
int k0, int kmax) {
static void gemm_s8x8x16_8x8_transpose_pack_A_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);

@@ -991,17 +992,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1009,8 +1016,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
megdnn_assert(0);
}
}
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
transpose_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += ksize8;
}

@@ -1019,17 +1027,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1038,8 +1052,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
}
}

transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, 4);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, 4);
outptr += ksize4;
}

@@ -1047,17 +1062,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1066,8 +1087,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
}
}

transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 4, xmax - x);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 4, xmax - x);
}

outptr_base += 8 * 8;
@@ -1075,8 +1097,8 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in,
}
}

static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax) {
static void gemm_s8x8x16_8x8_pack_B_n(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
const int ksize = kmax - k0;
@@ -1113,17 +1135,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1132,8 +1160,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
}
}
outptr_interleave = outptr;
interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr_interleave);
interleave_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave);
outptr += ksize8;
}

@@ -1142,17 +1171,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1162,8 +1197,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
}

outptr_interleave = outptr;
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr_interleave, 4, 4);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave, 4, 4);
outptr += ksize4;
}

@@ -1171,17 +1207,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
if (k + 7 >= kmax) {
switch (k + 7 - kmax) {
case 6:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 5:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 4:
inptr3 = zerobuff; MEGDNN_FALLTHRU
inptr3 = zerobuff;
MEGDNN_FALLTHRU
case 3:
inptr4 = zerobuff; MEGDNN_FALLTHRU
inptr4 = zerobuff;
MEGDNN_FALLTHRU
case 2:
inptr5 = zerobuff; MEGDNN_FALLTHRU
inptr5 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr6 = zerobuff; MEGDNN_FALLTHRU
inptr6 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr7 = zerobuff;
break;
@@ -1191,8 +1233,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
}

outptr_interleave = outptr;
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr_interleave, 4, xmax - x);
interleave_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr_interleave, 4, xmax - x);
}

outptr_base += 8 * 8;
@@ -1200,10 +1243,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
}
}

static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr,
const dt_int8* inptr, int ldin,
int y0, int ymax, int k0,
int kmax) {
static void gemm_s8x8x16_8x8_transpose_pack_B_n(
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
int8_t zerobuff[16];
std::memset(zerobuff, 0, sizeof(int8_t) * 16);
constexpr int interleave4 = 32;
@@ -1231,14 +1273,16 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr,

int K = kmax - k0;
for (; K > 7; K -= 8) {
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
inptr6, inptr7, outptr);
transpose_8x8_1_b(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr);
outptr += interleave8;
}

if (K > 0) {
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
inptr7, outptr, 8, K);
transpose_8(
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
outptr, 8, K);
outptr += interleave8;
}
}
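
transpose_8x8_1_b serves the transposed pack: the same 8x8 byte tile is written column-first so K becomes the contiguous axis of the packed panel. A scalar rendition under the same assumptions as the interleave sketch above:

    #include <cstdint>

    // Scalar model of an 8x8 byte transpose-pack: read row-wise,
    // write column-wise, then advance each source row by 8.
    static inline void transpose_8x8_1_b_ref(const int8_t* (&row)[8],
                                             int8_t*& out) {
        for (int c = 0; c < 8; ++c)
            for (int r = 0; r < 8; ++r)
                *out++ = row[r][c];
        for (auto& p : row)
            p += 8;
    }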
@@ -1259,9 +1303,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;
@@ -1278,9 +1324,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr,
if (y + 3 >= ymax) {
switch (y + 3 - ymax) {
case 2:
inptr1 = zerobuff; MEGDNN_FALLTHRU
inptr1 = zerobuff;
MEGDNN_FALLTHRU
case 1:
inptr2 = zerobuff; MEGDNN_FALLTHRU
inptr2 = zerobuff;
MEGDNN_FALLTHRU
case 0:
inptr3 = zerobuff;
break;


+ 34  - 41  dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h


@@ -40,11 +40,9 @@ namespace matmul_mk4_16x12x4_a53 {
* Accumulator
*/
// clang-format on
static __attribute__((noinline)) void kern_16x12(const int16_t* packA,
const int8_t* packB, int K,
int16_t* output, int LDC,
bool is_first_k,
int remain_n) {
static __attribute__((noinline)) void kern_16x12(
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int remain_n) {
K /= 4;
const int16_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -521,15 +519,15 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA,
"6:\n" STORE_C

"101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc",
"memory");

#undef STORE_C
#undef STORE_LINE
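
The operand lists being re-wrapped here are GCC extended asm: "+r" binds a C variable to a general register for read and write, the empty second colon section means no read-only inputs, and the final clobber list ("v0" ... "cc", "memory") tells the compiler exactly which registers, flags, and memory the kernel may destroy. A minimal self-contained example of the same machinery, unrelated to the kernel (AArch64 assumed):

    #include <cstdint>

    // Demonstrates a read-write ("+r") and a read-only ("r") operand plus
    // a clobber list, in the style of the kernels above.
    static inline int64_t add_asm(int64_t a, int64_t b) {
    #if defined(__aarch64__)
        asm volatile("add %[a], %[a], %[b]\n"
                     : [a] "+r"(a)   // read-write register operand
                     : [b] "r"(b)    // read-only register operand
                     : "cc");        // flags listed as clobbered (conservative)
        return a;
    #else
        return a + b;  // portable fallback for non-AArch64 builds
    #endif
    }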
@@ -554,10 +552,9 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA,
* Accumulator
*/
// clang-format on
static __attribute__((noinline)) void kern_8x12(const int16_t* packA,
const int8_t* packB, int K,
int16_t* output, int LDC,
bool is_first_k, int remain_n) {
static __attribute__((noinline)) void kern_8x12(
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int remain_n) {
K /= 4;
const int16_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -858,14 +855,13 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA,
"6:\n" STORE_C

"101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3",
"x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory");

#undef STORE_C
#undef STORE_LINE
@@ -890,10 +886,9 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA,
* Accumulator
*/
// clang-format on
static __attribute__((noinline)) void kern_4x12(const int16_t* packA,
const int8_t* packB, int K,
int16_t* output, int LDC,
bool is_first_k, int remain_n) {
static __attribute__((noinline)) void kern_4x12(
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int remain_n) {
K /= 4;
const int16_t* a_ptr = packA;
const int8_t* b_ptr = packB;
@@ -1162,22 +1157,21 @@ static __attribute__((noinline)) void kern_4x12(const int16_t* packA,
"6:\n" STORE_C

"101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC),
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc",
"memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3",
"x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory");

#undef STORE_C
#undef STORE_LINE
}

static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr,
const dt_int8* inptr, int ldin,
int m0, int mmax, int k0, int kmax) {
static void gemm_s8x8x16_mk4_16x12_pack_A(
dt_int16* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0,
int kmax) {
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
constexpr int pack_m = 16;
@@ -1224,9 +1218,8 @@ static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr,
}
}
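
The % 4 asserts exist because these packers consume the MK4 layout: the matrix is stored as (M/4, K/4, 4(K), 4(M)) blocks, as the pack comments in kernel_mk4_8x8x8.h describe, so both dimensions must be multiples of 4. A hedged addressing sketch, with names that are illustrative rather than megdnn API:

    // Assumption: MK4 layout is (M/4, K/4, 4(k), 4(m)). Returns the linear
    // offset of element (m, k) in an MK4-packed matrix with K columns.
    static inline int mk4_offset(int m, int k, int K) {
        int mb = m / 4, mi = m % 4;  // 4-row block index / row within block
        int kb = k / 4, ki = k % 4;  // 4-col block index / col within block
        return ((mb * (K / 4) + kb) * 4 + ki) * 4 + mi;
    }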

static void gemm_s8x8x16_mk4_16x12_pack_B(dt_int8* out, const dt_int8* in,
int ldin, int n0, int nmax, int k0,
int kmax) {
static void gemm_s8x8x16_mk4_16x12_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");

constexpr int pack_n = 12;


+ 18  - 22  dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h

@@ -43,8 +43,9 @@ namespace matmul_mk4_4x4x8_a72 {
*/

// clang-format on
static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool, int remain_n) {
static inline void kern_4x4(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, bool,
int remain_n) {
K = div_ceil(K, 8);
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
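
The K bookkeeping above prepares a software-pipelined loop: K first becomes the number of 8-deep blocks, oddk records whether that count is odd, and the main loop then runs (K + 1) / 2 - 1 times processing two blocks per pass, leaving one or two blocks for the epilogue where the final stores are issued. A scalar sketch of that control flow; process_block() is a hypothetical stand-in for one 8-deep multiply-accumulate step:

    void process_block();  // hypothetical: one 8-deep K block of MACs

    // Control-flow sketch of the oddk main loop (blocks = div_ceil(K, 8)).
    void k_loop(int blocks) {
        int oddk = blocks & 1;
        int main_iters = (blocks + 1) / 2 - 1;
        for (int i = 0; i < main_iters; ++i) {
            process_block();  // unrolled x2: loads for one block can
            process_block();  // overlap the arithmetic of the other
        }
        if (oddk) {
            process_block();  // odd tail: one block, then store C
        } else {
            process_block();  // even tail: two blocks, then store C
            process_block();
        }
    }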
@@ -261,15 +262,14 @@ static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
"7:\n" STORE_C

"101:\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[remain_n] "+r"(remain_n)
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [oddk] "+r"(oddk),
[LDC] "+r"(LDC), [outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "cc", "memory");
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc",
"memory");

#undef STORE_C
#undef STORE_LINE
@@ -282,26 +282,23 @@ static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) {
vst1_s8(outptr + 3 * 8, in0.val[3]);
}

static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2,
dt_int8* outptr) {
static inline void interleve_8x4_b(
const dt_int8* inptr, const dt_int8* inptr2, dt_int8* outptr) {
int8x16_t in0 = vld1q_s8(inptr);
int8x16_t in1 = vld1q_s8(inptr2);
int32x4x2_t in_x2 = {
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
}

static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) {
int8x16_t in0 = vld1q_s8(inptr);
int8x16_t in1 = vdupq_n_s8(0);
int32x4x2_t in_x2 = {
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
}
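
vst1q_s8 and vst2q_s32 are what make these helpers interleave two rows at 4-byte granularity: vst2q_s32 stores its two q-registers element-interleaved, so the output alternates one 32-bit lane of row 0 with one of row 1. A scalar model of the resulting byte order (illustrative only, not megdnn code):

    #include <cstdint>
    #include <cstring>

    // Scalar model of interleve_8x4_b: two 16-byte rows viewed as four
    // 4-byte lanes each; output order is r0[0:4] r1[0:4] r0[4:8] r1[4:8] ...
    static inline void interleave_rows_4b_ref(const int8_t* r0,
                                              const int8_t* r1, int8_t* out) {
        for (int lane = 0; lane < 4; ++lane) {
            std::memcpy(out + lane * 8, r0 + lane * 4, 4);
            std::memcpy(out + lane * 8 + 4, r1 + lane * 4, 4);
        }
    }

The _pad variant follows by substituting a zeroed second row (vdupq_n_s8(0)), exactly as in the code above.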

static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in,
int ldin, int m0, int mmax, int k0,
int kmax) {
static void gemm_s8x8x16_mk4_4x4x8_pack_A(
dt_int8* out, const dt_int8* in, int ldin, int m0, int mmax, int k0, int kmax) {
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
constexpr int pack_m = 4;
@@ -330,9 +327,8 @@ static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in,
}
}

static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in,
int ldin, int n0, int nmax, int k0,
int kmax) {
static void gemm_s8x8x16_mk4_4x4x8_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");

constexpr int pack_n = 4;


+ 48  - 55  dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h

@@ -18,7 +18,6 @@ namespace megdnn {
namespace aarch64 {
namespace matmul_mk4_8x8x8 {


/**
* Overview of register layout:
*
@@ -39,18 +38,18 @@ namespace matmul_mk4_8x8x8 {
* | v16 | | v28 |
* | v17 | | v29 |
* | v16 | | v30 |
* | v17 | | v31 |
* | v17 | | v31 |
* +--------+ - - - - +---------------------------------+
*
* Accumulator
*/
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_8x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8;
LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packB;//packA;
const int8_t* b_ptr = packA;//packB;
const int8_t* a_ptr = packB; // packA;
const int8_t* b_ptr = packA; // packB;
// clang-format off
#define LOAD_C_8 \
"ld1 {v0.8h}, [x0], #16\n" \
@@ -291,17 +290,17 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
// clang-format on
// clang-format on
}

static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_8x8_remain(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8;
LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packB;
const int8_t* b_ptr = packA;
// clang-format off
// clang-format off
register int16_t* outptr asm("x0") = output;
asm volatile(
"add x1, x0, %x[LDC]\n"
@@ -476,7 +475,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
"cbnz %w[K], 1b\n"

"cmp %w[is_first_k], #1\n"
"beq 2f\n"
"beq 2f\n"
"cmp %x[m_remain], #8 \n"
"beq 8f \n"
"cmp %x[m_remain], #4 \n"
@@ -633,7 +632,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
"zip2 v15.2d, v30.2d, v31.2d \n"
"add v6.8h, v6.8h, v13.8h \n"
"add v7.8h, v7.8h, v15.8h \n"
//save to memory
// save to memory
"cmp %x[m_remain], #8 \n"
"beq 4f \n"
"cmp %x[m_remain], #4 \n"
@@ -766,31 +765,27 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K,
"b 1000f \n"

"1000: \n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k),
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain)
:
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr),
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC),
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain),
[ n_remain ] "+r"(n_remain)
:
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
"v29", "v30", "v31");
// clang-format on
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0",
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
// clang-format on

#undef LOAD_C_8
#undef STORE_C_8
}
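
These kernels pin outptr to a named register with register ... asm("x0") so the hand-written STORE_C macros can address it as x0 directly instead of going through an operand placeholder. A compact, self-contained illustration of the idiom (AArch64 assumed; the function itself is hypothetical):

    #include <cstdint>

    // Pin a local to x0 so the asm body may name the register literally.
    static inline void store_zero2_i16(int16_t* output) {
    #if defined(__aarch64__)
        register int16_t* outptr asm("x0") = output;
        asm volatile("str wzr, [x0]\n"  // one 32-bit store = two int16 zeros
                     : "+r"(outptr)
                     :
                     : "memory");
    #else
        output[0] = 0;  // portable fallback
        output[1] = 0;
    #endif
    }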


static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x8(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8;
LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packB;//packA;
const int8_t* b_ptr = packA;//packB;
const int8_t* a_ptr = packB; // packA;
const int8_t* b_ptr = packA; // packB;
// clang-format off
#define LOAD_C_4 \
"ld1 {v0.8h}, [x0], #16\n" \
@@ -1018,14 +1013,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
#undef LOAD_C_4
#undef STORE_C_4
}
static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K,
int16_t* output, int LDC, bool is_first_k, int m_remain,
int n_remain) {
static void kern_4x8_remain(
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
K /= 8;
LDC = LDC * sizeof(int16_t);
const int8_t* a_ptr = packB;//packA;
const int8_t* b_ptr = packA;//packB;
// clang-format off
const int8_t* a_ptr = packB; // packA;
const int8_t* b_ptr = packA; // packB;
// clang-format off
register int16_t* outptr asm("x0") = output;
asm volatile(

@@ -1324,13 +1319,12 @@ static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K,
#undef STORE_C_4
}


//! pack to icxoc
//! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7))
//! if M K is not times of 8,pack 0 instead
static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
const dt_int8* inptr, int ldin,
int m0, int mmax, int k0, int kmax) {
//! if M or K is not a multiple of 8, pack 0 instead
static void gemm_s8x8x16_mk4_8x8x8_pack_A(
dt_int8* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0,
int kmax) {
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4");
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
constexpr int pack_m = 8;
@@ -1349,8 +1343,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
prefetch_2x(inptr0);
prefetch_2x(inptr1);
int k_idx = k0;
for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
interleave_8x8_mk4_b(inptr0,inptr1,outptr);
for (; k_idx + 7 < kmax; k_idx += pack_k) {
interleave_8x8_mk4_b(inptr0, inptr1, outptr);
}

if (k_idx < kmax) {
@@ -1368,9 +1362,9 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
prefetch_2x(inptr0);
prefetch_2x(inptr1);
int k_idx = k0;
for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
for (; k_idx + 7 < kmax; k_idx += pack_k) {
inptr1 = zerobuff;
interleave_8x8_mk4_b(inptr0,inptr1,outptr);
interleave_8x8_mk4_b(inptr0, inptr1, outptr);
}

if (k_idx < kmax) {
@@ -1383,9 +1377,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
}
//! pack to nxic
//! (K/4,N,4) pack to K/8,N,8(ic0~7); if K is not a multiple of 8, pack 0 instead.
static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
int ldin, int n0, int nmax, int k0,
int kmax) {
static void gemm_s8x8x16_mk4_8x8x8_pack_B(
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) {
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");

constexpr int pack_n = 8;
@@ -1394,14 +1387,14 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
int8_t tmpbuff0[pack_n * pack_size] = {0};
int8_t tmpbuff1[pack_n * pack_size] = {0};
int8_t zerobuff[pack_n * pack_size] = {0};
const int ksize = round_up<int>((kmax - k0),8);
const int ksize = round_up<int>((kmax - k0), 8);
const int nsize = nmax - n0;
const int n_end = nsize / pack_n * pack_n + n0;
const int remain_n = nsize % pack_n;
int output_stride = ksize * pack_n;
int8_t* outptr_base = out;
int k_idx = k0;
for ( ; k_idx + 7 < kmax; k_idx += pack_k) {
for (; k_idx + 7 < kmax; k_idx += pack_k) {
const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
const int8_t* inptr1 = inptr0 + ldin;
prefetch_3x(inptr0);
@@ -1410,7 +1403,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
auto outptr = outptr_base;
for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
transpose_8x8_mk4_b(inptr0, inptr1, outptr);
outptr += output_stride;
outptr += output_stride;
}
if (remain_n > 0) {
memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size);
@@ -1422,8 +1415,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
}
outptr_base += pack_n * pack_k;
}
if(k_idx < kmax){
if (k_idx < kmax) {
const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
const int8_t* inptr1 = nullptr;
prefetch_3x(inptr0);
@@ -1444,7 +1437,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
}
}

} // namespace matmul_mk4_16x12x4_a53
} // namespace matmul_mk4_8x8x8
} // namespace aarch64
} // namespace megdnn
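
Several hunks in this file only move stray // clang-format on markers back into place. The toggles matter here because the asm bodies and macro tables are hand-columned; the guarded pattern is simply (illustrative):

    // clang-format off
    static const int kHandAligned[4] = {
        1,   2,
        30,  40,  // alignment survives because formatting is disabled
    };
    // clang-format on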



+ 128  - 142  dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp

@@ -10,13 +10,13 @@
* implied.
*/

#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
@@ -28,39 +28,35 @@ using namespace aarch64::matmul;
// ===========================gemm_s8x8x16_8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8);

void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8x8x16_8x8::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0,
ymax, k0, kmax);
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(
out, in, ldin, y0, ymax, k0, kmax);
} else {
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
void gemm_s8x8x16_8x8::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0,
xmax, k0, kmax);
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(
out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8x8x16_8x8::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
@@ -79,15 +75,15 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k);
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
matmul_8x8x8::kern_8x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -99,16 +95,17 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
matmul_8x8x8::kern_4x8(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K8;
}

for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
matmul_8x8x8::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -119,39 +116,33 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
// ===========================gemm_s8x8x16_4x4==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4);

void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
void gemm_s8x8x16_4x4::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, kmax);
} else {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
}
}

void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
void gemm_s8x8x16_4x4::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}

void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
(A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
void gemm_s8x8x16_4x4::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
@@ -169,16 +160,17 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE, B_INTERLEAVE);
matmul_4x4x16::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
B_INTERLEAVE);
output += B_INTERLEAVE;
cur_packB += K4;
}

for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
matmul_4x4x16::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE,
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE;
cur_packB += K4;
}
@@ -191,10 +183,10 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
matmul_4x4x16::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, A_INTERLEAVE),
std::min<size_t>(N - n, B_INTERLEAVE));
output += B_INTERLEAVE;
cur_packB += K4;
}
@@ -205,28 +197,26 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
// ===========================gemm_s8x8x16_mk4_16x12==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53);

void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in,
int ldin, int y0, int ymax, int k0,
int kmax, bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0,
ymax, k0, kmax);
void gemm_s8x8x16_mk4_16x12_a53::pack_A(
dt_int16* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(
out, in, ldin, y0, ymax, k0, kmax);
}

void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0,
xmax, k0, kmax);
void gemm_s8x8x16_mk4_16x12_a53::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(
out, in, ldin, x0, xmax, k0, kmax);
}

void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
const dt_int8* packB, size_t M, size_t N,
size_t K, dt_int16* C, size_t LDC,
bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
void gemm_s8x8x16_mk4_16x12_a53::kern(
const dt_int16* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
@@ -246,14 +236,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
matmul_mk4_16x12x4_a53::kern_16x12(
packA, cur_packB, K, output, LDC, is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
matmul_mk4_16x12x4_a53::kern_16x12(
packA, cur_packB, K, output, LDC, is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
@@ -265,14 +255,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
matmul_mk4_16x12x4_a53::kern_8x12(
packA, cur_packB, K, output, LDC, is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
matmul_mk4_16x12x4_a53::kern_8x12(
packA, cur_packB, K, output, LDC, is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
@@ -286,14 +276,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
matmul_mk4_16x12x4_a53::kern_4x12(
packA, cur_packB, K, output, LDC, is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
matmul_mk4_16x12x4_a53::kern_4x12(
packA, cur_packB, K, output, LDC, is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
@@ -303,27 +293,26 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
// ===========================gemm_s8x8x16_mk4_4x4_a72==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72);

void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax,
k0, kmax);
void gemm_s8x8x16_mk4_4x4_a72::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(
out, in, ldin, y0, ymax, k0, kmax);
}

void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax,
k0, kmax);
void gemm_s8x8x16_mk4_4x4_a72::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(
out, in, ldin, x0, xmax, k0, kmax);
}

void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k,
const dt_int16*, dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
void gemm_s8x8x16_mk4_4x4_a72::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
@@ -343,14 +332,14 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,

const int8_t* cur_packB = packB;
for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) {
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
matmul_mk4_4x4x8_a72::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * packed_k;
}
if (remain_n > 0) {
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
matmul_mk4_4x4x8_a72::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * packed_k;
}
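
Every kern() in this file follows the same tiling discipline visible above: full pack_n-wide column panels go through the fast path, then a single remainder call handles the ragged right edge via remain_n. Distilled into a hypothetical driver (names are illustrative):

    #include <cstddef>

    // Walk full pack_n-wide panels, then one remainder panel (< pack_n).
    template <typename KernMain, typename KernTail>
    void drive_columns(std::size_t N, std::size_t pack_n,
                       KernMain kern_block, KernTail kern_remain) {
        std::size_t n = 0;
        for (; n + pack_n <= N; n += pack_n)
            kern_block(n);          // full-width panel at column n
        if (n < N)
            kern_remain(n, N - n);  // remainder width goes in remain_n
    }

For the a72 strategy above, pack_n is 4 and both calls land on kern_4x4, with remain_n carrying the remainder width.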
@@ -361,27 +350,24 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
// ===========================gemm_s8x8x16_mk4_8x8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8);

void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in,
int ldin, int y0, int ymax, int k0,
int kmax, bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0,
ymax, k0, kmax);
void gemm_s8x8x16_mk4_8x8x8::pack_A(
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, ymax, k0, kmax);
}

void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0,
xmax, k0, kmax);
void gemm_s8x8x16_mk4_8x8x8::pack_B(
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}

void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
void gemm_s8x8x16_mk4_8x8x8::kern(
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const {
megdnn_assert(
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
@@ -402,14 +388,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC,
is_first_k, pack_m, pack_n);
matmul_mk4_8x8x8::kern_8x8(
packA, cur_packB, K, output, LDC, is_first_k, pack_m, pack_n);
output += pack_n * pack_size;
cur_packB += KSIZE8;
}
if (remain_n > 0) {
matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, pack_m, remain_n);
matmul_mk4_8x8x8::kern_8x8_remain(
packA, cur_packB, K, output, LDC, is_first_k, pack_m, remain_n);
output += remain_n * pack_size;
cur_packB += KSIZE8;
}
@@ -421,14 +407,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC,
is_first_k, 4, pack_n);
matmul_mk4_8x8x8::kern_4x8(
packA, cur_packB, K, output, LDC, is_first_k, 4, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC,
is_first_k, 4, remain_n);
matmul_mk4_8x8x8::kern_4x8_remain(
packA, cur_packB, K, output, LDC, is_first_k, 4, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
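
Taken together, a strategy is used pack-first: both operands are packed once, then the kernel consumes the packed panels. A hedged end-to-end sketch under assumed buffer management (megdnn's GemmInterleaved driver owns the real workspace layout; run_gemm and the buffer names are stand-ins):

    // Sketch: invoke a pack-A/pack-B/kern strategy once, with the
    // signatures shown in this file; packA_buf/packB_buf are
    // caller-provided workspaces of sufficient size.
    void run_gemm(const gemm_s8x8x16_mk4_8x8x8& strat, const dt_int8* A,
                  const dt_int8* B, dt_int16* C, int M, int N, int K,
                  int LDA, int LDB, size_t LDC, dt_int8* packA_buf,
                  dt_int8* packB_buf) {
        strat.pack_A(packA_buf, A, LDA, 0, M, 0, K, /*transpose=*/false);
        strat.pack_B(packB_buf, B, LDB, 0, N, 0, K, /*transpose=*/false);
        strat.kern(packA_buf, packB_buf, M, N, K, C, LDC,
                   /*is_first_k=*/true, nullptr, nullptr);
    }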


Some files were not shown because too many files changed in this diff
