@@ -23,9 +23,9 @@ | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations" | #pragma GCC diagnostic ignored "-Wdeprecated-declarations" | ||||
#pragma GCC diagnostic ignored "-Wsign-compare" | #pragma GCC diagnostic ignored "-Wsign-compare" | ||||
#include <hip/hip_runtime_api.h> | |||||
#include <hip/hip_runtime.h> | |||||
#include <hip/hip_fp16.h> | #include <hip/hip_fp16.h> | ||||
#include <hip/hip_runtime.h> | |||||
#include <hip/hip_runtime_api.h> | |||||
#pragma GCC diagnostic pop | #pragma GCC diagnostic pop | ||||
#if !defined(__HIP_PLATFORM_HCC__) | #if !defined(__HIP_PLATFORM_HCC__) | ||||
@@ -11,10 +11,10 @@ | |||||
#pragma once | #pragma once | ||||
#include "megdnn/thin/function.h" | |||||
#include "megcore_cdefs.h" | |||||
#include <cstddef> | #include <cstddef> | ||||
#include <memory> | #include <memory> | ||||
#include "megcore_cdefs.h" | |||||
#include "megdnn/thin/function.h" | |||||
#include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
@@ -26,36 +26,35 @@ namespace megcore { | |||||
* the caller thread immediately. | * the caller thread immediately. | ||||
*/ | */ | ||||
class CPUDispatcher { | class CPUDispatcher { | ||||
public: | |||||
using Task = megdnn::thin_function<void()>; | |||||
using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>; | |||||
virtual ~CPUDispatcher() noexcept; | |||||
/*! | |||||
* \brief dispatch a task on the computing thread | |||||
* \param task the task that would be moved away | |||||
*/ | |||||
virtual void dispatch(Task&& task) = 0; | |||||
/*! | |||||
* \brief dispatch a multithreading task on the computing thread | |||||
* \param task the task would be moved away | |||||
* \param parallelism the parallelism of the task. | |||||
*/ | |||||
virtual void dispatch(MultiThreadingTask&& task, | |||||
size_t parallelism) = 0; | |||||
/*! | |||||
* \brief synchronize the calling thread with the computing thread | |||||
*/ | |||||
virtual void sync() = 0; | |||||
/*! | |||||
* \brief the computing thread number. | |||||
*/ | |||||
virtual size_t nr_threads() = 0; | |||||
public: | |||||
using Task = megdnn::thin_function<void()>; | |||||
using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>; | |||||
virtual ~CPUDispatcher() noexcept; | |||||
/*! | |||||
* \brief dispatch a task on the computing thread | |||||
* \param task the task that would be moved away | |||||
*/ | |||||
virtual void dispatch(Task&& task) = 0; | |||||
/*! | |||||
* \brief dispatch a multithreading task on the computing thread | |||||
* \param task the task would be moved away | |||||
* \param parallelism the parallelism of the task. | |||||
*/ | |||||
virtual void dispatch(MultiThreadingTask&& task, size_t parallelism) = 0; | |||||
/*! | |||||
* \brief synchronize the calling thread with the computing thread | |||||
*/ | |||||
virtual void sync() = 0; | |||||
/*! | |||||
* \brief the computing thread number. | |||||
*/ | |||||
virtual size_t nr_threads() = 0; | |||||
}; | }; | ||||
} // namespace megcore | |||||
} // namespace megcore | |||||
using MegcoreCPUDispatcher = megcore::CPUDispatcher; | using MegcoreCPUDispatcher = megcore::CPUDispatcher; | ||||
@@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher; | |||||
* \brief Layer 1: device handle | * \brief Layer 1: device handle | ||||
*/ | */ | ||||
struct megcoreDeviceContext; | struct megcoreDeviceContext; | ||||
typedef struct megcoreDeviceContext *megcoreDeviceHandle_t; | |||||
typedef struct megcoreDeviceContext* megcoreDeviceHandle_t; | |||||
megcoreStatus_t megcoreCreateDeviceHandle( | megcoreStatus_t megcoreCreateDeviceHandle( | ||||
megcoreDeviceHandle_t *handle, | |||||
megcorePlatform_t platform, | |||||
int deviceID = -1, | |||||
megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID = -1, | |||||
unsigned int flags = 0); | unsigned int flags = 0); | ||||
megcoreStatus_t megcoreDestroyDeviceHandle( | |||||
megcoreDeviceHandle_t handle); | |||||
megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle, | |||||
megcorePlatform_t *platform); | |||||
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, | |||||
int *deviceID); | |||||
megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle, | |||||
size_t *memAlignmentInBytes); | |||||
megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle); | |||||
megcoreStatus_t megcoreGetPlatform( | |||||
megcoreDeviceHandle_t handle, megcorePlatform_t* platform); | |||||
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID); | |||||
megcoreStatus_t megcoreGetMemAlignment( | |||||
megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes); | |||||
megcoreStatus_t megcoreGetDeviceFlags( | megcoreStatus_t megcoreGetDeviceFlags( | ||||
megcoreDeviceHandle_t handle, | |||||
unsigned int *flags); | |||||
megcoreDeviceHandle_t handle, unsigned int* flags); | |||||
megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); | megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); | ||||
megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle); | megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle); | ||||
megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle, | |||||
void **devPtr, size_t sizeInBytes); | |||||
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, | |||||
void *devPtr); | |||||
megcoreStatus_t megcoreMalloc( | |||||
megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes); | |||||
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr); | |||||
/** | /** | ||||
* \brief Layer 2: computing handle | * \brief Layer 2: computing handle | ||||
*/ | */ | ||||
struct megcoreComputingContext; | struct megcoreComputingContext; | ||||
typedef struct megcoreComputingContext *megcoreComputingHandle_t; | |||||
typedef struct megcoreComputingContext* megcoreComputingHandle_t; | |||||
megcoreStatus_t megcoreCreateComputingHandle( | megcoreStatus_t megcoreCreateComputingHandle( | ||||
megcoreComputingHandle_t *compHandle, | |||||
megcoreDeviceHandle_t devHandle, | |||||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||||
unsigned int flags = 0); | unsigned int flags = 0); | ||||
megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( | megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( | ||||
megcoreComputingHandle_t *compHandle, | |||||
megcoreDeviceHandle_t devHandle, | |||||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||||
const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher, | const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher, | ||||
unsigned int flags = 0); | unsigned int flags = 0); | ||||
megcoreStatus_t megcoreDestroyComputingHandle( | |||||
megcoreComputingHandle_t handle); | |||||
megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle); | |||||
megcoreStatus_t megcoreGetDeviceHandle( | megcoreStatus_t megcoreGetDeviceHandle( | ||||
megcoreComputingHandle_t compHandle, | |||||
megcoreDeviceHandle_t *devHandle); | |||||
megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle); | |||||
megcoreStatus_t megcoreGetComputingFlags( | megcoreStatus_t megcoreGetComputingFlags( | ||||
megcoreComputingHandle_t handle, | |||||
unsigned int *flags); | |||||
megcoreComputingHandle_t handle, unsigned int* flags); | |||||
MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); | MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); | ||||
megcoreStatus_t megcoreMemcpy( | megcoreStatus_t megcoreMemcpy( | ||||
megcoreComputingHandle_t handle, | |||||
void *dst, const void *src, size_t sizeInBytes, | |||||
megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes, | |||||
megcoreMemcpyKind_t kind); | megcoreMemcpyKind_t kind); | ||||
megcoreStatus_t megcoreMemset( | megcoreStatus_t megcoreMemset( | ||||
megcoreComputingHandle_t handle, | |||||
void *dst, int value, size_t sizeInBytes); | |||||
megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes); | |||||
megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); | megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); | ||||
/** | /** | ||||
* \brief Miscellaneous | * \brief Miscellaneous | ||||
*/ | */ | ||||
const char *megcoreGetErrorName(megcoreStatus_t status); | |||||
const char* megcoreGetErrorName(megcoreStatus_t status); | |||||
#include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext( | |||||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
unsigned int flags, const AtlasContext& ctx); | unsigned int flags, const AtlasContext& ctx); | ||||
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, | |||||
AtlasContext* ctx); | |||||
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, AtlasContext* ctx); | |||||
namespace atlas { | namespace atlas { | ||||
//! convert acl error code to error string | //! convert acl error code to error string | ||||
@@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream( | |||||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
unsigned int flags, aclrtStream stream) { | unsigned int flags, aclrtStream stream) { | ||||
megcore::AtlasContext ctx{stream}; | megcore::AtlasContext ctx{stream}; | ||||
return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle, | |||||
flags, ctx); | |||||
return megcore::createComputingHandleWithAtlasContext( | |||||
compHandle, devHandle, flags, ctx); | |||||
} | } | ||||
inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle, | |||||
aclrtStream* stream) { | |||||
inline megcoreStatus_t megcoreGetACLStream( | |||||
megcoreComputingHandle_t handle, aclrtStream* stream) { | |||||
megcore::AtlasContext ctx; | megcore::AtlasContext ctx; | ||||
auto ret = megcore::getAtlasContext(handle, &ctx); | auto ret = megcore::getAtlasContext(handle, &ctx); | ||||
*stream = ctx.stream; | *stream = ctx.stream; | ||||
@@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext( | |||||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
unsigned int flags, const CambriconContext& ctx); | unsigned int flags, const CambriconContext& ctx); | ||||
megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle, | |||||
CambriconContext* ctx); | |||||
megcoreStatus_t getCambriconContext( | |||||
megcoreComputingHandle_t handle, CambriconContext* ctx); | |||||
} // namespace megcore | } // namespace megcore | ||||
@@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue( | |||||
#include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
@@ -40,7 +40,6 @@ typedef enum { | |||||
megcoreErrorInternalError = 5, | megcoreErrorInternalError = 5, | ||||
} megcoreStatus_t; | } megcoreStatus_t; | ||||
/** | /** | ||||
* \brief Memcpy kind | * \brief Memcpy kind | ||||
*/ | */ | ||||
@@ -70,6 +69,6 @@ struct AsyncErrorInfo { | |||||
char msg[228]; | char msg[228]; | ||||
int msg_args[4]; | int msg_args[4]; | ||||
}; | }; | ||||
} // namespace megcore | |||||
} // namespace megcore | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext( | |||||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
unsigned int flags, const CudaContext& ctx); | unsigned int flags, const CudaContext& ctx); | ||||
megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, | |||||
CudaContext* ctx); | |||||
megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx); | |||||
} // namespace megcore | } // namespace megcore | ||||
@@ -43,8 +42,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream( | |||||
unsigned int flags, cudaStream_t stream) { | unsigned int flags, cudaStream_t stream) { | ||||
megcore::CudaContext ctx; | megcore::CudaContext ctx; | ||||
ctx.stream = stream; | ctx.stream = stream; | ||||
return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle, | |||||
flags, ctx); | |||||
return megcore::createComputingHandleWithCUDAContext( | |||||
compHandle, devHandle, flags, ctx); | |||||
} | } | ||||
static inline megcoreStatus_t megcoreGetCUDAStream( | static inline megcoreStatus_t megcoreGetCUDAStream( | ||||
@@ -23,7 +23,9 @@ struct ROCMContext { | |||||
hipStream_t stream = nullptr; | hipStream_t stream = nullptr; | ||||
static std::atomic_bool sm_miopen_algo_search; | static std::atomic_bool sm_miopen_algo_search; | ||||
static inline bool enable_miopen_algo_search() { return sm_miopen_algo_search.load(); } | |||||
static inline bool enable_miopen_algo_search() { | |||||
return sm_miopen_algo_search.load(); | |||||
} | |||||
static inline void enable_miopen_algo_search(bool enable_algo_search) { | static inline void enable_miopen_algo_search(bool enable_algo_search) { | ||||
sm_miopen_algo_search.store(enable_algo_search); | sm_miopen_algo_search.store(enable_algo_search); | ||||
} | } | ||||
@@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext( | |||||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
unsigned int flags, const ROCMContext& ctx); | unsigned int flags, const ROCMContext& ctx); | ||||
megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, | |||||
ROCMContext* ctx); | |||||
megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, ROCMContext* ctx); | |||||
// Set MIOpen algo search enabled or disabled | // Set MIOpen algo search enabled or disabled | ||||
megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true); | megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true); | ||||
@@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream( | |||||
unsigned int flags, hipStream_t stream) { | unsigned int flags, hipStream_t stream) { | ||||
megcore::ROCMContext ctx; | megcore::ROCMContext ctx; | ||||
ctx.stream = stream; | ctx.stream = stream; | ||||
return megcore::createComputingHandleWithROCMContext(compHandle, devHandle, | |||||
flags, ctx); | |||||
return megcore::createComputingHandleWithROCMContext( | |||||
compHandle, devHandle, flags, ctx); | |||||
} | } | ||||
static inline megcoreStatus_t megcoreGetROCMStream( | static inline megcoreStatus_t megcoreGetROCMStream( | ||||
@@ -10,7 +10,7 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/version.h" | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "megdnn/version.h" | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -14,20 +14,20 @@ | |||||
#include "megdnn/config/config.h" | #include "megdnn/config/config.h" | ||||
#if defined(__GNUC__) || defined(__clang__) | #if defined(__GNUC__) || defined(__clang__) | ||||
#if !defined (__clang__) | |||||
// gcc specific | |||||
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) | |||||
#if GCC_VERSION < 40800 | |||||
#error "GCC version should be at least 4.8.0." | |||||
#endif // GCC_VERSION < 40800 | |||||
#endif // !defined(__clang__) | |||||
#if !defined(__clang__) | |||||
// gcc specific | |||||
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) | |||||
#if GCC_VERSION < 40800 | |||||
#error "GCC version should be at least 4.8.0." | |||||
#endif // GCC_VERSION < 40800 | |||||
#endif // !defined(__clang__) | |||||
#ifndef megdnn_trap | |||||
#define megdnn_trap() __builtin_trap() | |||||
#endif | |||||
#ifndef megdnn_trap | |||||
#define megdnn_trap() __builtin_trap() | |||||
#endif | |||||
#define megdnn_likely(v) __builtin_expect(bool(v), 1) | |||||
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0) | |||||
#define megdnn_likely(v) __builtin_expect(bool(v), 1) | |||||
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0) | |||||
#if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG) | #if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG) | ||||
//! Thumb2 limit code length | //! Thumb2 limit code length | ||||
@@ -36,123 +36,122 @@ | |||||
#define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__)) | #define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__)) | ||||
#endif | #endif | ||||
#define MEGDNN_DEPRECATED __attribute__((deprecated)) | |||||
#define MEGDNN_PACKED __attribute__((packed)) | |||||
#define MEGDNN_CONSTEXPR constexpr | |||||
#define MEGDNN_NOEXCEPT noexcept | |||||
#define MEGDNN_STATIC_ASSERT static_assert | |||||
#define MEGDNN_FINAL final | |||||
#define MEGDNN_NORETURN __attribute__((noreturn)) | |||||
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) | |||||
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||||
#if defined(__clang_major__) && (__clang_major__ >= 7) | |||||
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||||
#else | |||||
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] | |||||
#endif | |||||
#define MEGDNN_NOINLINE __attribute__((noinline)) | |||||
#define megdnn_isatty(x) isatty(x) | |||||
#define MEGDNN_DEPRECATED __attribute__((deprecated)) | |||||
#define MEGDNN_PACKED __attribute__((packed)) | |||||
#define MEGDNN_CONSTEXPR constexpr | |||||
#define MEGDNN_NOEXCEPT noexcept | |||||
#define MEGDNN_STATIC_ASSERT static_assert | |||||
#define MEGDNN_FINAL final | |||||
#define MEGDNN_NORETURN __attribute__((noreturn)) | |||||
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) | |||||
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||||
#if defined(__clang_major__) && (__clang_major__ >= 7) | |||||
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||||
#else | |||||
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] | |||||
#endif | |||||
#define MEGDNN_NOINLINE __attribute__((noinline)) | |||||
#define megdnn_isatty(x) isatty(x) | |||||
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) | #elif defined(__INTEL_COMPILER) || defined(_MSC_VER) | ||||
#ifndef megdnn_trap | #ifndef megdnn_trap | ||||
#define megdnn_trap() __debugbreak() | #define megdnn_trap() __debugbreak() | ||||
#endif | #endif | ||||
#define megdnn_likely(v) (bool(v)) | |||||
#define megdnn_likely(v) (bool(v)) | |||||
#define megdnn_unlikely(v) (bool(v)) | #define megdnn_unlikely(v) (bool(v)) | ||||
#define MEGDNN_DEPRECATED | #define MEGDNN_DEPRECATED | ||||
#define MEGDNN_PACKED | #define MEGDNN_PACKED | ||||
#define MEGDNN_CONSTEXPR constexpr | |||||
#define MEGDNN_NOEXCEPT noexcept | |||||
#define MEGDNN_CONSTEXPR constexpr | |||||
#define MEGDNN_NOEXCEPT noexcept | |||||
#define MEGDNN_STATIC_ASSERT static_assert | #define MEGDNN_STATIC_ASSERT static_assert | ||||
#define MEGDNN_FINAL final | |||||
#define MEGDNN_FINAL final | |||||
#if defined(_MSC_VER) | #if defined(_MSC_VER) | ||||
#define MEGDNN_NORETURN __declspec(noreturn) | |||||
#define MEGDNN_NOINLINE __declspec(noinline) | |||||
#define MEGDNN_NORETURN __declspec(noreturn) | |||||
#define MEGDNN_NOINLINE __declspec(noinline) | |||||
#else | #else | ||||
#define MEGDNN_NORETURN | |||||
#define MEGDNN_FORCE_NOINLINE | |||||
#endif // _MSC_VER | |||||
#define MEGDNN_NORETURN | |||||
#define MEGDNN_FORCE_NOINLINE | |||||
#endif // _MSC_VER | |||||
#define MEGDNN_WARN_UNUSED_RESULT | #define MEGDNN_WARN_UNUSED_RESULT | ||||
#define megdnn_isatty(x) _isatty(x) | #define megdnn_isatty(x) _isatty(x) | ||||
#else | #else | ||||
#error "unknown compiler" | |||||
#endif // __GNUC__ | |||||
#error "unknown compiler" | |||||
#endif // __GNUC__ | |||||
// __cpp_exceptions and __cpp_rtti is referred from | // __cpp_exceptions and __cpp_rtti is referred from | ||||
// https://isocpp.org/std/standing-documentssd-6-sg10-feature-test-recommendations | // https://isocpp.org/std/standing-documentssd-6-sg10-feature-test-recommendations | ||||
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, | |||||
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, | |||||
// similar for __GXX_RTTI | // similar for __GXX_RTTI | ||||
// _CPPUNWIND and _CPPRTTI is used by MSVC, see | // _CPPUNWIND and _CPPRTTI is used by MSVC, see | ||||
// https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macrosview=vs-2019 | // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macrosview=vs-2019 | ||||
#ifndef MEGDNN_ENABLE_EXCEPTIONS | #ifndef MEGDNN_ENABLE_EXCEPTIONS | ||||
#if __cpp_exceptions || __EXCEPTIONS || \ | |||||
(defined(_MSC_VER) && defined(_CPPUNWIND)) | |||||
#define MEGDNN_ENABLE_EXCEPTIONS 1 | |||||
#else | |||||
#define MEGDNN_ENABLE_EXCEPTIONS 0 | |||||
#endif | |||||
#if __cpp_exceptions || __EXCEPTIONS || (defined(_MSC_VER) && defined(_CPPUNWIND)) | |||||
#define MEGDNN_ENABLE_EXCEPTIONS 1 | |||||
#else | |||||
#define MEGDNN_ENABLE_EXCEPTIONS 0 | |||||
#endif | |||||
#endif | #endif | ||||
#ifndef MEGDNN_ENABLE_RTTI | #ifndef MEGDNN_ENABLE_RTTI | ||||
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) | |||||
#define MEGDNN_ENABLE_RTTI 1 | |||||
#else | |||||
#define MEGDNN_ENABLE_RTTI 0 | |||||
#endif | |||||
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) | |||||
#define MEGDNN_ENABLE_RTTI 1 | |||||
#else | |||||
#define MEGDNN_ENABLE_RTTI 0 | |||||
#endif | |||||
#endif | #endif | ||||
#ifdef __CUDACC__ | #ifdef __CUDACC__ | ||||
#define MEGDNN_CC_CUDA 1 | |||||
#undef MEGDNN_CONSTEXPR | |||||
#define MEGDNN_CONSTEXPR const | |||||
#define MEGDNN_CC_CUDA 1 | |||||
#undef MEGDNN_CONSTEXPR | |||||
#define MEGDNN_CONSTEXPR const | |||||
#if defined(__CUDACC_VER_MAJOR__) | #if defined(__CUDACC_VER_MAJOR__) | ||||
#if __CUDACC_VER_MAJOR__ >= 9 | #if __CUDACC_VER_MAJOR__ >= 9 | ||||
#undef MEGDNN_STATIC_ASSERT | |||||
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); | |||||
#undef MEGDNN_STATIC_ASSERT | |||||
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); | |||||
#else | #else | ||||
#undef MEGDNN_STATIC_ASSERT | |||||
#define MEGDNN_STATIC_ASSERT(cond, msg) | |||||
#undef MEGDNN_STATIC_ASSERT | |||||
#define MEGDNN_STATIC_ASSERT(cond, msg) | |||||
#endif | #endif | ||||
#endif | #endif | ||||
#define nullptr NULL | |||||
#undef MEGDNN_FINAL | |||||
#define MEGDNN_FINAL | |||||
#define nullptr NULL | |||||
#undef MEGDNN_FINAL | |||||
#define MEGDNN_FINAL | |||||
#elif defined(__HIPCC__) | #elif defined(__HIPCC__) | ||||
#define MEGDNN_CC_CUDA 1 | |||||
#define MEGDNN_CC_CUDA 1 | |||||
#else | #else | ||||
#define MEGDNN_CC_HOST 1 | |||||
#endif // __CUDACC__ | |||||
#define MEGDNN_CC_HOST 1 | |||||
#endif // __CUDACC__ | |||||
// MEGDNN_HOST and MEGDNN_DEVICE | // MEGDNN_HOST and MEGDNN_DEVICE | ||||
#if MEGDNN_CC_CUDA | #if MEGDNN_CC_CUDA | ||||
#define MEGDNN_HOST __host__ | |||||
#define MEGDNN_DEVICE __device__ | |||||
#define MEGDNN_HOST __host__ | |||||
#define MEGDNN_DEVICE __device__ | |||||
#else | #else | ||||
#define MEGDNN_HOST | |||||
#define MEGDNN_DEVICE | |||||
#define MEGDNN_HOST | |||||
#define MEGDNN_DEVICE | |||||
#endif | #endif | ||||
#if MEGDNN_CC_CUDA | #if MEGDNN_CC_CUDA | ||||
#define MEGDNN_FORCE_INLINE __forceinline__ | |||||
#define MEGDNN_FORCE_INLINE __forceinline__ | |||||
#else | #else | ||||
#if __GNUC__ || __has_attribute(always_inline) | #if __GNUC__ || __has_attribute(always_inline) | ||||
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) | |||||
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) | |||||
#else | #else | ||||
#define MEGDNN_FORCE_INLINE inline | |||||
#define MEGDNN_FORCE_INLINE inline | |||||
#endif | #endif | ||||
#endif | #endif | ||||
#if defined(_MSC_VER) || defined(WIN32) | #if defined(_MSC_VER) || defined(WIN32) | ||||
#define ATTR_ALIGNED(v) __declspec(align(v)) | |||||
#define ATTR_ALIGNED(v) __declspec(align(v)) | |||||
#else | #else | ||||
#define ATTR_ALIGNED(v) __attribute__((aligned(v))) | |||||
#define ATTR_ALIGNED(v) __attribute__((aligned(v))) | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -16,10 +16,10 @@ | |||||
#include "megdnn/internal/defs.h" | #include "megdnn/internal/defs.h" | ||||
#if MEGDNN_CC_HOST | #if MEGDNN_CC_HOST | ||||
#include <cstdarg> | |||||
#include <string> | #include <string> | ||||
#include <type_traits> | #include <type_traits> | ||||
#include <vector> | #include <vector> | ||||
#include <cstdarg> | |||||
#include "megdnn/thin/small_vector.h" | #include "megdnn/thin/small_vector.h" | ||||
#endif // MEGDNN_CC_HOST | #endif // MEGDNN_CC_HOST | ||||
@@ -35,8 +35,7 @@ class ErrorHandler { | |||||
protected: | protected: | ||||
MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; | MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; | ||||
MEGDNN_NORETURN virtual void do_on_tensor_reshape_error( | |||||
const std::string& msg) { | |||||
MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(const std::string& msg) { | |||||
on_megdnn_error(msg); | on_megdnn_error(msg); | ||||
} | } | ||||
@@ -70,8 +69,9 @@ public: | |||||
#if MEGDNN_CC_HOST | #if MEGDNN_CC_HOST | ||||
enum class LogLevel { DEBUG, INFO, WARN, ERROR }; | enum class LogLevel { DEBUG, INFO, WARN, ERROR }; | ||||
typedef void (*LogHandler)(LogLevel level, const char* file, const char* func, | |||||
int line, const char* fmt, va_list ap); | |||||
typedef void (*LogHandler)( | |||||
LogLevel level, const char* file, const char* func, int line, const char* fmt, | |||||
va_list ap); | |||||
/*! | /*! | ||||
* \brief set the callback to receive all log messages | * \brief set the callback to receive all log messages | ||||
@@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape { | |||||
ptrdiff_t low_elem, low_byte; | ptrdiff_t low_elem, low_byte; | ||||
size_t high_elem, high_byte; | size_t high_elem, high_byte; | ||||
Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, | |||||
size_t high_byte) | |||||
Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte) | |||||
: low_elem(low_elem), | : low_elem(low_elem), | ||||
low_byte(low_byte), | low_byte(low_byte), | ||||
high_elem(high_elem), | high_elem(high_elem), | ||||
@@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape { | |||||
TensorLayout(const TensorShape& shape, DType dtype, Format format); | TensorLayout(const TensorShape& shape, DType dtype, Format format); | ||||
//! creating layout with user-specified shape and stride. | //! creating layout with user-specified shape and stride. | ||||
TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||||
DType dtype); | |||||
TensorLayout( | |||||
const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||||
DType dtype); | |||||
TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||||
DType dtype, Format format); | |||||
TensorLayout( | |||||
const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype, | |||||
Format format); | |||||
/* =================== inplace modifiers =================== */ | /* =================== inplace modifiers =================== */ | ||||
@@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape { | |||||
* | * | ||||
* \throw TensorReshapeError if no stride exists for target shape. | * \throw TensorReshapeError if no stride exists for target shape. | ||||
*/ | */ | ||||
TensorLayout reshape(const TensorShape& shape) const | |||||
MEGDNN_WARN_UNUSED_RESULT; | |||||
TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||||
/*! | /*! | ||||
* \brief try to reshape to another view; return whether these two shapes | * \brief try to reshape to another view; return whether these two shapes | ||||
@@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape { | |||||
* \return true iff there exists target stride so this layout can be | * \return true iff there exists target stride so this layout can be | ||||
* converted to target shape and the elements can match. | * converted to target shape and the elements can match. | ||||
*/ | */ | ||||
bool try_reshape(TensorLayout& output, | |||||
const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||||
bool try_reshape(TensorLayout& output, const TensorShape& shape) const | |||||
MEGDNN_WARN_UNUSED_RESULT; | |||||
/*! | /*! | ||||
* \brief Broadcast on dims with shape == 1 to match target *shape*. | * \brief Broadcast on dims with shape == 1 to match target *shape*. | ||||
* \throw TensorReshapeError if could not be satisfied | * \throw TensorReshapeError if could not be satisfied | ||||
*/ | */ | ||||
TensorLayout broadcast(const TensorShape& shape) const | |||||
MEGDNN_WARN_UNUSED_RESULT; | |||||
TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||||
/*! | /*! | ||||
* \brief Collapse consecutive axes with contiguous layout together | * \brief Collapse consecutive axes with contiguous layout together | ||||
@@ -441,8 +440,7 @@ struct Workspace { | |||||
Workspace() : raw_ptr(NULL), size(0) {} | Workspace() : raw_ptr(NULL), size(0) {} | ||||
Workspace(dt_byte* raw_ptr_, size_t size_) | |||||
: raw_ptr(raw_ptr_), size(size_) {} | |||||
Workspace(dt_byte* raw_ptr_, size_t size_) : raw_ptr(raw_ptr_), size(size_) {} | |||||
template <typename T> | template <typename T> | ||||
T* ptr(size_t offset_in_bytes = 0) const { | T* ptr(size_t offset_in_bytes = 0) const { | ||||
@@ -467,9 +465,8 @@ public: | |||||
* \param shape requested output shape | * \param shape requested output shape | ||||
* \param user_data extra user data passed in DynOutMallocPolicyCall | * \param user_data extra user data passed in DynOutMallocPolicyCall | ||||
*/ | */ | ||||
virtual TensorND alloc_output(size_t id, DType dtype, | |||||
const TensorShape& shape, | |||||
void* user_data) = 0; | |||||
virtual TensorND alloc_output( | |||||
size_t id, DType dtype, const TensorShape& shape, void* user_data) = 0; | |||||
/*! | /*! | ||||
* \brief allocate workspace memory | * \brief allocate workspace memory | ||||
@@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall { | |||||
*/ | */ | ||||
template <typename T = void, typename elem = T> | template <typename T = void, typename elem = T> | ||||
T* alloc_workspace(size_t nr_elem) { | T* alloc_workspace(size_t nr_elem) { | ||||
using real_elem = | |||||
typename std::conditional<std::is_same<elem, void>::value, | |||||
uint8_t, elem>::type; | |||||
return static_cast<T*>(policy->alloc_workspace( | |||||
nr_elem * sizeof(real_elem), user_data)); | |||||
using real_elem = typename std::conditional< | |||||
std::is_same<elem, void>::value, uint8_t, elem>::type; | |||||
return static_cast<T*>( | |||||
policy->alloc_workspace(nr_elem * sizeof(real_elem), user_data)); | |||||
} | } | ||||
void free_workspace(void* ptr) { | |||||
return policy->free_workspace(ptr, user_data); | |||||
} | |||||
void free_workspace(void* ptr) { return policy->free_workspace(ptr, user_data); } | |||||
}; | }; | ||||
template <typename T> | template <typename T> | ||||
class EnumClassBit { | class EnumClassBit { | ||||
std::underlying_type_t<T> m_val; | std::underlying_type_t<T> m_val; | ||||
@@ -528,8 +521,7 @@ class EnumClassBit { | |||||
constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} | constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} | ||||
public: | public: | ||||
constexpr EnumClassBit(T v) | |||||
: m_val(static_cast<std::underlying_type_t<T>>(v)) {} | |||||
constexpr EnumClassBit(T v) : m_val(static_cast<std::underlying_type_t<T>>(v)) {} | |||||
constexpr operator T() const { return static_cast<T>(m_val); } | constexpr operator T() const { return static_cast<T>(m_val); } | ||||
@@ -542,7 +534,7 @@ public: | |||||
DEF_OPR(&) | DEF_OPR(&) | ||||
DEF_OPR(|) | DEF_OPR(|) | ||||
DEF_OPR (^) | |||||
DEF_OPR(^) | |||||
constexpr EnumClassBit operator~() const { return ~m_val; } | constexpr EnumClassBit operator~() const { return ~m_val; } | ||||
@@ -553,14 +545,13 @@ public: | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ | |||||
return ::megdnn::EnumClassBit<cls>(x) \ | |||||
op ::megdnn::EnumClassBit<cls>(y); \ | |||||
} \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ | |||||
::megdnn::EnumClassBit<cls> x, cls y) { \ | |||||
return x op ::megdnn::EnumClassBit<cls>(y); \ | |||||
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ | |||||
return ::megdnn::EnumClassBit<cls>(x) op ::megdnn::EnumClassBit<cls>(y); \ | |||||
} \ | |||||
inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ | |||||
::megdnn::EnumClassBit<cls> x, cls y) { \ | |||||
return x op ::megdnn::EnumClassBit<cls>(y); \ | |||||
} | } | ||||
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ | #define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ | ||||
@@ -14,14 +14,14 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if MGB_ENABLE_GETENV | #if MGB_ENABLE_GETENV | ||||
#define MGB_GETENV ::std::getenv | |||||
#define MGB_GETENV ::std::getenv | |||||
#else | #else | ||||
#define MGB_GETENV(_name) static_cast<char*>(nullptr) | |||||
#define MGB_GETENV(_name) static_cast<char*>(nullptr) | |||||
#endif | #endif | ||||
#ifdef WIN32 | #ifdef WIN32 | ||||
#define unsetenv(_name) _putenv_s(_name, ""); | |||||
#define setenv(name,value,overwrite) _putenv_s(name,value) | |||||
#define unsetenv(_name) _putenv_s(_name, ""); | |||||
#define setenv(name, value, overwrite) _putenv_s(name, value) | |||||
#endif | #endif | ||||
namespace megdnn { | namespace megdnn { | ||||
@@ -32,8 +32,7 @@ namespace megdnn { | |||||
*/ | */ | ||||
template <class Opr, typename... Args> | template <class Opr, typename... Args> | ||||
bool has_available_algo(Opr* opr, Args&&... args) { | bool has_available_algo(Opr* opr, Args&&... args) { | ||||
const typename Opr::AlgoBase::SizeArgs size_args( | |||||
opr, std::forward<Args>(args)...); | |||||
const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); | |||||
for (auto i : Opr::algo_pack().all_algos) { | for (auto i : Opr::algo_pack().all_algos) { | ||||
if (i->is_available(size_args)) { | if (i->is_available(size_args)) { | ||||
return true; | return true; | ||||
@@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... args) { | |||||
return false; | return false; | ||||
} | } | ||||
} | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -17,11 +17,11 @@ | |||||
#include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
std::unique_ptr<Handle> make_cuda_handle_with_stream(cudaStream_t stream, | |||||
int device_id = -1); | |||||
cudaStream_t get_cuda_stream(Handle *handle); | |||||
std::unique_ptr<Handle> make_cuda_handle_with_stream( | |||||
cudaStream_t stream, int device_id = -1); | |||||
cudaStream_t get_cuda_stream(Handle* handle); | |||||
} // namespace megdnn | |||||
} // namespace megdnn | |||||
#include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -3,17 +3,22 @@ | |||||
* | * | ||||
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | ||||
* | * | ||||
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation | |||||
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, | |||||
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the | |||||
* Software is furnished to do so, subject to the following conditions: | |||||
* | |||||
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | |||||
* | |||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | |||||
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |||||
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |||||
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||||
* Permission is hereby granted, free of charge, to any person obtaining a copy of this | |||||
* software and associated documentation files (the "Software"), to deal in the Software | |||||
* without restriction, including without limitation the rights to use, copy, modify, | |||||
* merge, publish, distribute, sublicense, and/or sell copies of the Software, and to | |||||
* permit persons to whom the Software is furnished to do so, subject to the following | |||||
* conditions: | |||||
* | |||||
* The above copyright notice and this permission notice shall be included in all copies | |||||
* or substantial portions of the Software. | |||||
* | |||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | |||||
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A | |||||
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT | |||||
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF | |||||
* CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE | |||||
* OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||||
* | * | ||||
* Version 1.11.0 | * Version 1.11.0 | ||||
* \file | * \file | ||||
@@ -41,8 +46,8 @@ | |||||
#undef HALF_NOEXCEPT | #undef HALF_NOEXCEPT | ||||
#undef HALF_NOTHROW | #undef HALF_NOTHROW | ||||
#ifdef HALF_POP_WARNINGS | #ifdef HALF_POP_WARNINGS | ||||
#pragma warning(pop) | |||||
#undef HALF_POP_WARNINGS | |||||
#pragma warning(pop) | |||||
#undef HALF_POP_WARNINGS | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -3,17 +3,22 @@ | |||||
* | * | ||||
* Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | ||||
* | * | ||||
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation | |||||
* files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, | |||||
* modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the | |||||
* Software is furnished to do so, subject to the following conditions: | |||||
* Permission is hereby granted, free of charge, to any person obtaining a copy of this | |||||
* software and associated documentation files (the "Software"), to deal in the Software | |||||
* without restriction, including without limitation the rights to use, copy, modify, | |||||
* merge, publish, distribute, sublicense, and/or sell copies of the Software, and to | |||||
* permit persons to whom the Software is furnished to do so, subject to the following | |||||
* conditions: | |||||
* | * | ||||
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | |||||
* The above copyright notice and this permission notice shall be included in all copies | |||||
* or substantial portions of the Software. | |||||
* | * | ||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | |||||
* WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |||||
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |||||
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | |||||
* INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A | |||||
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT | |||||
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF | |||||
* CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE | |||||
* OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||||
* | * | ||||
* Version 1.11.0 | * Version 1.11.0 | ||||
* \file | * \file | ||||
@@ -39,166 +44,164 @@ | |||||
#include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
/// Combined gcc version number. | /// Combined gcc version number. | ||||
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) | |||||
#define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) | |||||
//check C++11 language features | |||||
#if defined(__clang__) //clang | |||||
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
#endif | |||||
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||||
#define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
#endif | |||||
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||||
#define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
#endif | |||||
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||||
#define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||||
#endif | |||||
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
#define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
#endif | |||||
/*#elif defined(__INTEL_COMPILER) //Intel C++ | |||||
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? | |||||
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
#endif | |||||
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? | |||||
#define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
#endif | |||||
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? | |||||
#define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
#endif | |||||
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? | |||||
#define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
#endif*/ | |||||
#elif defined(__GNUC__) //gcc | |||||
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |||||
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||||
#define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||||
#define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||||
#define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||||
#endif | |||||
#if !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
#define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
#endif | |||||
#endif | |||||
#elif defined(_MSC_VER) //Visual C++ | |||||
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
#endif | |||||
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
#define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
#endif | |||||
#define HALF_POP_WARNINGS 1 | |||||
#pragma warning(push) | |||||
//! 4521 and 4522 is multiple copy/assigment operator specified | |||||
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned | |||||
// check C++11 language features | |||||
#if defined(__clang__) // clang | |||||
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
#endif | |||||
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||||
#define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
#endif | |||||
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||||
#define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
#endif | |||||
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||||
#define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||||
#endif | |||||
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ | |||||
!defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
#define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
#endif | |||||
/*#elif defined(__INTEL_COMPILER) | |||||
//Intel C++ #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
???????? #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >= | |||||
1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define | |||||
HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 && | |||||
!defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define | |||||
HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 && | |||||
!defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define | |||||
HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/ | |||||
#elif defined(__GNUC__) // gcc | |||||
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |||||
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||||
#define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||||
#define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||||
#define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||||
#endif | |||||
#if !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
#define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
#endif | |||||
#endif | |||||
#elif defined(_MSC_VER) // Visual C++ | |||||
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
#endif | |||||
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
#define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
#endif | |||||
#define HALF_POP_WARNINGS 1 | |||||
#pragma warning(push) | |||||
//! 4521 and 4522 is multiple copy/assigment operator specified | |||||
#pragma warning(disable : 4099 4127 4146 4521 4522) // struct vs class, constant in if, | |||||
// negative unsigned | |||||
#endif | #endif | ||||
//check C++11 library features | |||||
// check C++11 library features | |||||
#include <utility> | #include <utility> | ||||
#if defined(_LIBCPP_VERSION) //libc++ | |||||
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||||
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||||
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_CSTDINT | |||||
#define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_CMATH | |||||
#define HALF_ENABLE_CPP11_CMATH 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_HASH | |||||
#define HALF_ENABLE_CPP11_HASH 1 | |||||
#endif | |||||
#endif | |||||
#elif defined(__GLIBCXX__) //libstdc++ | |||||
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||||
#ifdef __clang__ | |||||
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) | |||||
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
#endif | |||||
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||||
#define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
#endif | |||||
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) | |||||
#define HALF_ENABLE_CPP11_CMATH 1 | |||||
#endif | |||||
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) | |||||
#define HALF_ENABLE_CPP11_HASH 1 | |||||
#endif | |||||
#else | |||||
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||||
#define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) | |||||
#define HALF_ENABLE_CPP11_CMATH 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) | |||||
#define HALF_ENABLE_CPP11_HASH 1 | |||||
#endif | |||||
#endif | |||||
#endif | |||||
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ | |||||
#if _CPPLIB_VER >= 520 | |||||
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||||
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_CSTDINT | |||||
#define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_HASH | |||||
#define HALF_ENABLE_CPP11_HASH 1 | |||||
#endif | |||||
#endif | |||||
#if _CPPLIB_VER >= 610 | |||||
#ifndef HALF_ENABLE_CPP11_CMATH | |||||
#define HALF_ENABLE_CPP11_CMATH 1 | |||||
#endif | |||||
#endif | |||||
#if defined(_LIBCPP_VERSION) // libc++ | |||||
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||||
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||||
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_CSTDINT | |||||
#define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_CMATH | |||||
#define HALF_ENABLE_CPP11_CMATH 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_HASH | |||||
#define HALF_ENABLE_CPP11_HASH 1 | |||||
#endif | |||||
#endif | |||||
#elif defined(__GLIBCXX__) // libstdc++ | |||||
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||||
#ifdef __clang__ | |||||
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) | |||||
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
#endif | |||||
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||||
#define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
#endif | |||||
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) | |||||
#define HALF_ENABLE_CPP11_CMATH 1 | |||||
#endif | |||||
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) | |||||
#define HALF_ENABLE_CPP11_HASH 1 | |||||
#endif | |||||
#else | |||||
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||||
#define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) | |||||
#define HALF_ENABLE_CPP11_CMATH 1 | |||||
#endif | |||||
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) | |||||
#define HALF_ENABLE_CPP11_HASH 1 | |||||
#endif | |||||
#endif | |||||
#endif | |||||
#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ | |||||
#if _CPPLIB_VER >= 520 | |||||
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||||
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_CSTDINT | |||||
#define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
#endif | |||||
#ifndef HALF_ENABLE_CPP11_HASH | |||||
#define HALF_ENABLE_CPP11_HASH 1 | |||||
#endif | |||||
#endif | |||||
#if _CPPLIB_VER >= 610 | |||||
#ifndef HALF_ENABLE_CPP11_CMATH | |||||
#define HALF_ENABLE_CPP11_CMATH 1 | |||||
#endif | |||||
#endif | |||||
#endif | #endif | ||||
#undef HALF_GNUC_VERSION | #undef HALF_GNUC_VERSION | ||||
//support constexpr | |||||
// support constexpr | |||||
#if HALF_ENABLE_CPP11_CONSTEXPR | #if HALF_ENABLE_CPP11_CONSTEXPR | ||||
#define HALF_CONSTEXPR constexpr | |||||
#define HALF_CONSTEXPR_CONST constexpr | |||||
#define HALF_CONSTEXPR constexpr | |||||
#define HALF_CONSTEXPR_CONST constexpr | |||||
#else | #else | ||||
#define HALF_CONSTEXPR | |||||
#define HALF_CONSTEXPR_CONST const | |||||
#define HALF_CONSTEXPR | |||||
#define HALF_CONSTEXPR_CONST const | |||||
#endif | #endif | ||||
//support noexcept | |||||
// support noexcept | |||||
#if HALF_ENABLE_CPP11_NOEXCEPT | #if HALF_ENABLE_CPP11_NOEXCEPT | ||||
#define HALF_NOEXCEPT noexcept | |||||
#define HALF_NOTHROW noexcept | |||||
#define HALF_NOEXCEPT noexcept | |||||
#define HALF_NOTHROW noexcept | |||||
#else | #else | ||||
#define HALF_NOEXCEPT | |||||
#define HALF_NOTHROW throw() | |||||
#define HALF_NOEXCEPT | |||||
#define HALF_NOTHROW throw() | |||||
#endif | #endif | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <limits> | |||||
#include <climits> | #include <climits> | ||||
#include <cmath> | #include <cmath> | ||||
#include <cstring> | #include <cstring> | ||||
#include <ostream> | |||||
#include <istream> | #include <istream> | ||||
#include <limits> | |||||
#include <ostream> | |||||
#if HALF_ENABLE_CPP11_TYPE_TRAITS | #if HALF_ENABLE_CPP11_TYPE_TRAITS | ||||
#include <type_traits> | |||||
#include <type_traits> | |||||
#endif | #endif | ||||
#if HALF_ENABLE_CPP11_CSTDINT | #if HALF_ENABLE_CPP11_CSTDINT | ||||
#include <cstdint> | |||||
#include <cstdint> | |||||
#endif | #endif | ||||
#if HALF_ENABLE_CPP11_HASH | #if HALF_ENABLE_CPP11_HASH | ||||
#include <functional> | |||||
#include <functional> | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -12,8 +12,8 @@ | |||||
#pragma once | #pragma once | ||||
#include "megcore.h" | #include "megcore.h" | ||||
#include "megdnn/config/config.h" | |||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include "megdnn/config/config.h" | |||||
#include <functional> | #include <functional> | ||||
#include <memory> | #include <memory> | ||||
@@ -24,150 +24,147 @@ namespace megdnn { | |||||
class OperatorBase; | class OperatorBase; | ||||
class Handle { | class Handle { | ||||
public: | |||||
enum class HandleType { | |||||
NAIVE = 0, | |||||
FALLBACK = 1, | |||||
X86 = 2, | |||||
ARM_COMMON = 3, | |||||
ARMV7 = 4, | |||||
AARCH64 = 5, | |||||
CUDA = 6, | |||||
ROCM = 11, | |||||
ATLAS = 13, | |||||
CAMBRICON = 12, | |||||
}; | |||||
//! Device vendor | |||||
enum class HandleVendorType : uint32_t { | |||||
NOT_SPEC = 0, | |||||
MALI = 1, | |||||
ADRENO = 2, | |||||
CUDA = 3, | |||||
INTEL = 4, | |||||
POWERVR = 5, | |||||
AMD = 6, | |||||
}; | |||||
protected: | |||||
Handle(megcoreComputingHandle_t computing_handle, HandleType type); | |||||
public: | |||||
/** | |||||
* \brief Create a MegDNN handle from a MegCore Computing handle. | |||||
* | |||||
* \param[in] computing_handle MegCore computing handle. Please note | |||||
* that computing_handle would not be released when this Handle is | |||||
* destructed | |||||
* \param[in] debug_level | |||||
* Applicable for CPU computing handle. | |||||
* 0 means taking the fastest possible code path; it may contains | |||||
* platform-specific instructions such as SSE for x86_64 or NEON for | |||||
* armv7v7. | |||||
* 1 means taking the fastest possible code path without | |||||
* platform-specific instructions in C++ code. Note that the compiled | |||||
* binary file still contains platform-specific codes. | |||||
* 2 means taking the naive code path. Performance is severely | |||||
* hampered, but it is less error-prone since the internal | |||||
* implementation is rather straightforward. | |||||
* | |||||
* **Debug level 1 and 2 should not be used in productions.** | |||||
*/ | |||||
static std::unique_ptr<Handle> make( | |||||
megcoreComputingHandle_t computing_handle, | |||||
int debug_level = 0); | |||||
public: | |||||
enum class HandleType { | |||||
NAIVE = 0, | |||||
FALLBACK = 1, | |||||
X86 = 2, | |||||
ARM_COMMON = 3, | |||||
ARMV7 = 4, | |||||
AARCH64 = 5, | |||||
CUDA = 6, | |||||
ROCM = 11, | |||||
ATLAS = 13, | |||||
CAMBRICON = 12, | |||||
}; | |||||
//! Device vendor | |||||
enum class HandleVendorType : uint32_t { | |||||
NOT_SPEC = 0, | |||||
MALI = 1, | |||||
ADRENO = 2, | |||||
CUDA = 3, | |||||
INTEL = 4, | |||||
POWERVR = 5, | |||||
AMD = 6, | |||||
}; | |||||
protected: | |||||
Handle(megcoreComputingHandle_t computing_handle, HandleType type); | |||||
public: | |||||
/** | |||||
* \brief Create a MegDNN handle from a MegCore Computing handle. | |||||
* | |||||
* \param[in] computing_handle MegCore computing handle. Please note | |||||
* that computing_handle would not be released when this Handle is | |||||
* destructed | |||||
* \param[in] debug_level | |||||
* Applicable for CPU computing handle. | |||||
* 0 means taking the fastest possible code path; it may contains | |||||
* platform-specific instructions such as SSE for x86_64 or NEON for | |||||
* armv7v7. | |||||
* 1 means taking the fastest possible code path without | |||||
* platform-specific instructions in C++ code. Note that the compiled | |||||
* binary file still contains platform-specific codes. | |||||
* 2 means taking the naive code path. Performance is severely | |||||
* hampered, but it is less error-prone since the internal | |||||
* implementation is rather straightforward. | |||||
* | |||||
* **Debug level 1 and 2 should not be used in productions.** | |||||
*/ | |||||
static std::unique_ptr<Handle> make( | |||||
megcoreComputingHandle_t computing_handle, int debug_level = 0); | |||||
#if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
static std::unique_ptr<Handle> make_cuda_handle( | |||||
megcoreComputingHandle_t computing_handle); | |||||
template <typename opr> | |||||
std::unique_ptr<opr> create_cuda_operator(); | |||||
static std::unique_ptr<Handle> make_cuda_handle( | |||||
megcoreComputingHandle_t computing_handle); | |||||
template <typename opr> | |||||
std::unique_ptr<opr> create_cuda_operator(); | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_ROCM | #if MEGDNN_WITH_ROCM | ||||
static std::unique_ptr<Handle> make_rocm_handle( | |||||
megcoreComputingHandle_t computing_handle); | |||||
template <typename opr> | |||||
std::unique_ptr<opr> create_rocm_operator(); | |||||
static std::unique_ptr<Handle> make_rocm_handle( | |||||
megcoreComputingHandle_t computing_handle); | |||||
template <typename opr> | |||||
std::unique_ptr<opr> create_rocm_operator(); | |||||
#endif | #endif | ||||
virtual ~Handle(); | |||||
/*! | |||||
* \brief Get the underlying megcore computing handle. | |||||
*/ | |||||
megcoreComputingHandle_t megcore_computing_handle() const { | |||||
return m_computing_handle; | |||||
} | |||||
/*! | |||||
* \brief set a callback function to be invoked when this handle is | |||||
* destructed, so associated resources can be released (e.g. | |||||
* computing handle) | |||||
* | |||||
* This function can be called at most once. | |||||
*/ | |||||
void set_destructor(const thin_function<void()> &d); | |||||
/*! | |||||
* \brief set a callback to be invoked when an operator is destructed | |||||
* \param[in,out] cb the callback function; it would be set to the | |||||
* previous callback function | |||||
*/ | |||||
void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) { | |||||
cb.swap(m_on_opr_destructed); | |||||
} | |||||
void on_opr_destructed(OperatorBase* opr); | |||||
/** | |||||
* \brief Create operator of Opr type. | |||||
*/ | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> create_operator(); | |||||
/* | |||||
* ============================================================= | |||||
* Users should call functions below to query memory requirement. | |||||
* ============================================================= | |||||
*/ | |||||
/** | |||||
* \brief The internal data pointer of TensorND should be aligned to | |||||
* alignment_requirement() in bytes. | |||||
*/ | |||||
virtual size_t alignment_requirement() const; | |||||
//! get alignment in bytes for rows of image 2D tensor format | |||||
virtual size_t image2d_pitch_alignment() const; | |||||
//! get vendor type | |||||
virtual HandleVendorType vendor_type() const; | |||||
HandleType type() const { | |||||
return m_handle_type; | |||||
} | |||||
/** | |||||
* \brief Check is the layout satisfy cross device copy constraint. | |||||
* 1. The handle of the src and the dst is the same kind | |||||
* 2. The dst is continguous. | |||||
*/ | |||||
virtual bool check_cross_dev_copy_constraint(const TensorLayout &src); | |||||
private: | |||||
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; | |||||
volatile uint32_t m_alive_magic = ALIVE_MAGIC; | |||||
megcoreComputingHandle_t m_computing_handle; | |||||
const HandleType m_handle_type; | |||||
thin_function<void()> m_destructor; | |||||
thin_function<void(OperatorBase*)> m_on_opr_destructed; | |||||
Handle() = delete; | |||||
Handle(const Handle &rhs) = delete; | |||||
Handle &operator=(const Handle &rhs) = delete; | |||||
virtual ~Handle(); | |||||
/*! | |||||
* \brief Get the underlying megcore computing handle. | |||||
*/ | |||||
megcoreComputingHandle_t megcore_computing_handle() const { | |||||
return m_computing_handle; | |||||
} | |||||
/*! | |||||
* \brief set a callback function to be invoked when this handle is | |||||
* destructed, so associated resources can be released (e.g. | |||||
* computing handle) | |||||
* | |||||
* This function can be called at most once. | |||||
*/ | |||||
void set_destructor(const thin_function<void()>& d); | |||||
/*! | |||||
* \brief set a callback to be invoked when an operator is destructed | |||||
* \param[in,out] cb the callback function; it would be set to the | |||||
* previous callback function | |||||
*/ | |||||
void set_opr_destruct_callback(thin_function<void(OperatorBase*)>& cb) { | |||||
cb.swap(m_on_opr_destructed); | |||||
} | |||||
void on_opr_destructed(OperatorBase* opr); | |||||
/** | |||||
* \brief Create operator of Opr type. | |||||
*/ | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> create_operator(); | |||||
/* | |||||
* ============================================================= | |||||
* Users should call functions below to query memory requirement. | |||||
* ============================================================= | |||||
*/ | |||||
/** | |||||
* \brief The internal data pointer of TensorND should be aligned to | |||||
* alignment_requirement() in bytes. | |||||
*/ | |||||
virtual size_t alignment_requirement() const; | |||||
//! get alignment in bytes for rows of image 2D tensor format | |||||
virtual size_t image2d_pitch_alignment() const; | |||||
//! get vendor type | |||||
virtual HandleVendorType vendor_type() const; | |||||
HandleType type() const { return m_handle_type; } | |||||
/** | |||||
* \brief Check is the layout satisfy cross device copy constraint. | |||||
* 1. The handle of the src and the dst is the same kind | |||||
* 2. The dst is continguous. | |||||
*/ | |||||
virtual bool check_cross_dev_copy_constraint(const TensorLayout& src); | |||||
private: | |||||
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; | |||||
volatile uint32_t m_alive_magic = ALIVE_MAGIC; | |||||
megcoreComputingHandle_t m_computing_handle; | |||||
const HandleType m_handle_type; | |||||
thin_function<void()> m_destructor; | |||||
thin_function<void(OperatorBase*)> m_on_opr_destructed; | |||||
Handle() = delete; | |||||
Handle(const Handle& rhs) = delete; | |||||
Handle& operator=(const Handle& rhs) = delete; | |||||
}; | }; | ||||
} // namespace megdnn | |||||
} // namespace megdnn | |||||
#include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
@@ -49,8 +49,9 @@ public: | |||||
mutable std::string m_input; | mutable std::string m_input; | ||||
public: | public: | ||||
Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr, | |||||
size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0) | |||||
Key(Handle* opr_handle, Algorithm::OprType opr_type, | |||||
const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size, | |||||
const void* param_ptr = nullptr, size_t param_size = 0) | |||||
: m_handle{opr_handle}, | : m_handle{opr_handle}, | ||||
m_opr_type{static_cast<uint32_t>(opr_type)}, | m_opr_type{static_cast<uint32_t>(opr_type)}, | ||||
m_inp_layouts_ptr{inp_layouts_ptr}, | m_inp_layouts_ptr{inp_layouts_ptr}, | ||||
@@ -16,20 +16,19 @@ | |||||
* \brief iterate through small (usually used) ndim values | * \brief iterate through small (usually used) ndim values | ||||
*/ | */ | ||||
#define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ | #define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ | ||||
cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__) | |||||
cb(1, ##__VA_ARGS__) cb(2, ##__VA_ARGS__) cb(3, ##__VA_ARGS__) | |||||
/*! | /*! | ||||
* \brief iterate through large (rarely used) ndim values | * \brief iterate through large (rarely used) ndim values | ||||
*/ | */ | ||||
#define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ | #define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ | ||||
cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \ | |||||
cb(7, ##__VA_ARGS__) | |||||
cb(4, ##__VA_ARGS__) cb(5, ##__VA_ARGS__) cb(6, ##__VA_ARGS__) cb(7, ##__VA_ARGS__) | |||||
/*! | /*! | ||||
* \brief iterate through all ndim values | * \brief iterate through all ndim values | ||||
*/ | */ | ||||
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ | |||||
MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \ | |||||
MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__) | |||||
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ | |||||
MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ##__VA_ARGS__) \ | |||||
MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ##__VA_ARGS__) | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -11,14 +11,14 @@ | |||||
// intentional no header guard here | // intentional no header guard here | ||||
#include "megdnn/handle.h" | #include "megdnn/handle.h" | ||||
#include "megdnn/oprs/base.h" | |||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "megdnn/opr_result_defs.h" | #include "megdnn/opr_result_defs.h" | ||||
#include "megdnn/oprs/base.h" | |||||
#include "./visibility_prologue.h" | #include "./visibility_prologue.h" | ||||
#include <limits> | |||||
#include <array> | #include <array> | ||||
#include <limits> | |||||
#ifndef _megdnn_in | #ifndef _megdnn_in | ||||
#define _megdnn_in | #define _megdnn_in | ||||
@@ -29,36 +29,37 @@ | |||||
#endif | #endif | ||||
#ifndef _megdnn_tensor_in | #ifndef _megdnn_tensor_in | ||||
#define _megdnn_tensor_in const TensorND & | |||||
#define _megdnn_tensor_in const TensorND& | |||||
#endif | #endif | ||||
#ifndef _megdnn_tensor_out | #ifndef _megdnn_tensor_out | ||||
#define _megdnn_tensor_out const TensorND & | |||||
#define _megdnn_tensor_out const TensorND& | |||||
#endif | #endif | ||||
#ifndef _megdnn_tensor_inout | #ifndef _megdnn_tensor_inout | ||||
#define _megdnn_tensor_inout const TensorND & | |||||
#define _megdnn_tensor_inout const TensorND& | |||||
#endif | #endif | ||||
#ifndef _megdnn_workspace | #ifndef _megdnn_workspace | ||||
#define _megdnn_workspace const Workspace & | |||||
#define _megdnn_workspace const Workspace& | |||||
#endif | #endif | ||||
#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||||
public: \ | |||||
_opr_name(Handle *handle): _base_name(handle) {} \ | |||||
#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||||
public: \ | |||||
_opr_name(Handle* handle) : _base_name(handle) {} | |||||
#define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ | #define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ | ||||
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||||
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||||
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \ | |||||
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||||
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||||
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; | |||||
#define DEF_OPR_PARAM(_pname) \ | |||||
public: \ | |||||
using Param = param::_pname; \ | |||||
Param& param() { return m_param; } \ | |||||
const Param& param() const { return m_param; } \ | |||||
protected: \ | |||||
Param m_param | |||||
#define DEF_OPR_PARAM(_pname) \ | |||||
public: \ | |||||
using Param = param::_pname; \ | |||||
Param& param() { return m_param; } \ | |||||
const Param& param() const { return m_param; } \ | |||||
\ | |||||
protected: \ | |||||
Param m_param | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -20,4 +20,3 @@ | |||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
@@ -16,25 +16,21 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace opr_result { | namespace opr_result { | ||||
struct Checksum { | |||||
uint32_t checksum; | |||||
union { | |||||
int32_t iv; | |||||
float fv; | |||||
} last_val; | |||||
bool operator == (const Checksum &rhs) const { | |||||
return checksum == rhs.checksum && | |||||
last_val.iv == rhs.last_val.iv; | |||||
} | |||||
bool operator != (const Checksum &rhs) const { | |||||
return !operator==(rhs); | |||||
} | |||||
}; | |||||
} // namespace opr_result | |||||
} // namespace megdnn | |||||
struct Checksum { | |||||
uint32_t checksum; | |||||
union { | |||||
int32_t iv; | |||||
float fv; | |||||
} last_val; | |||||
bool operator==(const Checksum& rhs) const { | |||||
return checksum == rhs.checksum && last_val.iv == rhs.last_val.iv; | |||||
} | |||||
bool operator!=(const Checksum& rhs) const { return !operator==(rhs); } | |||||
}; | |||||
} // namespace opr_result | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -12,11 +12,11 @@ | |||||
#include "megdnn/oprs/cv.h" | #include "megdnn/oprs/cv.h" | ||||
#include "megdnn/oprs/general.h" | #include "megdnn/oprs/general.h" | ||||
#include "megdnn/oprs/imgproc.h" | |||||
#include "megdnn/oprs/linalg.h" | |||||
#include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
#include "megdnn/oprs/nn_int.h" | #include "megdnn/oprs/nn_int.h" | ||||
#include "megdnn/oprs/imgproc.h" | |||||
#include "megdnn/oprs/utils.h" | #include "megdnn/oprs/utils.h" | ||||
#include "megdnn/oprs/linalg.h" | |||||
template <typename Opr> | template <typename Opr> | ||||
struct OprArityTrait; | struct OprArityTrait; | ||||
@@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1); | |||||
#undef INST_ARITY | #undef INST_ARITY | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -90,7 +90,7 @@ enum class AlgoDataType : uint32_t { | |||||
INT8X8X16 = 1 << 4, | INT8X8X16 = 1 << 4, | ||||
INT16X16X32 = 1 << 5, | INT16X16X32 = 1 << 5, | ||||
INT4X4X16 = 1 << 6, | INT4X4X16 = 1 << 6, | ||||
QINT4x4x32 = 1 << 7, | |||||
QINT4x4x32 = 1 << 7, | |||||
}; | }; | ||||
/*! | /*! | ||||
@@ -195,16 +195,16 @@ public: | |||||
Handle::HandleType handle_type() const { return m_handle_type; } | Handle::HandleType handle_type() const { return m_handle_type; } | ||||
Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } | Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } | ||||
Info info() const { | |||||
return {desc(), attribute()}; | |||||
} | |||||
Info info() const { return {desc(), attribute()}; } | |||||
template <typename T> | template <typename T> | ||||
static void serialize_write_pod(const T& val, std::string& result) { | static void serialize_write_pod(const T& val, std::string& result) { | ||||
static_assert(std::is_trivially_copyable<T>::value, | |||||
"type should be trivially copyable"); | |||||
static_assert(!std::is_pointer<T>::value, | |||||
"serialize pointer is unsafe in eager execution mode"); | |||||
static_assert( | |||||
std::is_trivially_copyable<T>::value, | |||||
"type should be trivially copyable"); | |||||
static_assert( | |||||
!std::is_pointer<T>::value, | |||||
"serialize pointer is unsafe in eager execution mode"); | |||||
result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | ||||
} | } | ||||
@@ -231,9 +231,8 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
static std::string deserialize_read_pod(const std::string& data, | |||||
size_t offset = 0, | |||||
size_t size = 0) { | |||||
static std::string deserialize_read_pod( | |||||
const std::string& data, size_t offset = 0, size_t size = 0) { | |||||
return std::string(data.data() + offset, size); | return std::string(data.data() + offset, size); | ||||
} | } | ||||
@@ -286,8 +285,8 @@ public: | |||||
* \param layouts origin layouts of the parent opr | * \param layouts origin layouts of the parent opr | ||||
* \param opr parent opr | * \param opr parent opr | ||||
*/ | */ | ||||
virtual std::vector<SearchItem> get_subopr_list(const TensorLayoutArray&, | |||||
const OperatorBase*) const { | |||||
virtual std::vector<SearchItem> get_subopr_list( | |||||
const TensorLayoutArray&, const OperatorBase*) const { | |||||
return {}; | return {}; | ||||
} | } | ||||
@@ -333,9 +332,7 @@ public: | |||||
ExecutionPolicy& execution_policy() { return m_execution_policy; } | ExecutionPolicy& execution_policy() { return m_execution_policy; } | ||||
const ExecutionPolicy& execution_policy() const { | |||||
return m_execution_policy; | |||||
} | |||||
const ExecutionPolicy& execution_policy() const { return m_execution_policy; } | |||||
virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | ||||
@@ -355,8 +352,8 @@ public: | |||||
using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
//! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
const TensorLayout& p1) { | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
const TensorLayout& p0, const TensorLayout& p1) { | |||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms(p0, p1)) { | for (auto&& algo : get_all_algorithms(p0, p1)) { | ||||
ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
@@ -364,8 +361,8 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
const TensorLayout& p1) { | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1) { | |||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | ||||
ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
@@ -382,12 +379,11 @@ public: | |||||
*/ | */ | ||||
AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
return get_algorithm_heuristic(p0, p1, workspace_limit_in_bytes, | |||||
positive_attr, negative_attr) | |||||
return get_algorithm_heuristic( | |||||
p0, p1, workspace_limit_in_bytes, positive_attr, negative_attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
@@ -408,8 +404,7 @@ protected: | |||||
*/ | */ | ||||
virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
}; | }; | ||||
@@ -423,9 +418,8 @@ public: | |||||
using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
//! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2) { | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { | |||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | ||||
ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
@@ -433,9 +427,8 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2) { | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { | |||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | ||||
ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
@@ -451,14 +444,13 @@ public: | |||||
* \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
*/ | */ | ||||
AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, | |||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | |||||
positive_attr, negative_attr) | |||||
return get_algorithm_heuristic( | |||||
p0, p1, p2, workspace_limit_in_bytes, positive_attr, | |||||
negative_attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
@@ -467,11 +459,9 @@ protected: | |||||
//! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2) = 0; | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; | |||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | virtual std::vector<Algorithm*> get_all_algorithms_safe( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2) = 0; | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; | |||||
/** | /** | ||||
* \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
@@ -480,10 +470,8 @@ protected: | |||||
* \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
*/ | */ | ||||
virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, | |||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
}; | }; | ||||
@@ -497,10 +485,9 @@ public: | |||||
using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
//! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2, | |||||
const TensorLayout& p3) { | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3) { | |||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | ||||
ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
@@ -508,10 +495,9 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2, | |||||
const TensorLayout& p3) { | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3) { | |||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | ||||
ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
@@ -527,14 +513,14 @@ public: | |||||
* \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
*/ | */ | ||||
AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | |||||
positive_attr, negative_attr) | |||||
return get_algorithm_heuristic( | |||||
p0, p1, p2, p3, workspace_limit_in_bytes, positive_attr, | |||||
negative_attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
@@ -543,11 +529,11 @@ protected: | |||||
//! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3) = 0; | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3) = 0; | |||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | virtual std::vector<Algorithm*> get_all_algorithms_safe( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3) = 0; | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3) = 0; | |||||
/** | /** | ||||
* \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
@@ -556,10 +542,9 @@ protected: | |||||
* \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
*/ | */ | ||||
virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
}; | }; | ||||
@@ -573,11 +558,9 @@ public: | |||||
using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
//! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2, | |||||
const TensorLayout& p3, | |||||
const TensorLayout& p4) { | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4) { | |||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | ||||
ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
@@ -585,11 +568,9 @@ public: | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
const TensorLayout& p1, | |||||
const TensorLayout& p2, | |||||
const TensorLayout& p3, | |||||
const TensorLayout& p4) { | |||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4) { | |||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | ||||
ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
@@ -605,16 +586,14 @@ public: | |||||
* \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
*/ | */ | ||||
AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, | |||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4, | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
return get_algorithm_heuristic(p0, p1, p2, p3, p4, | |||||
workspace_limit_in_bytes, positive_attr, | |||||
negative_attr) | |||||
return get_algorithm_heuristic( | |||||
p0, p1, p2, p3, p4, workspace_limit_in_bytes, positive_attr, | |||||
negative_attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
@@ -622,14 +601,12 @@ protected: | |||||
~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
//! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
virtual std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4) = 0; | |||||
virtual std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4) = 0; | |||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | virtual std::vector<Algorithm*> get_all_algorithms_safe( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4) = 0; | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4) = 0; | |||||
/** | /** | ||||
* \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
@@ -638,11 +615,9 @@ protected: | |||||
* \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
*/ | */ | ||||
virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, | |||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4, | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
}; | }; | ||||
@@ -657,9 +632,8 @@ public: | |||||
//! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info( | std::vector<AlgorithmInfo> get_all_algorithms_info( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p6, const TensorLayout& p7) { | const TensorLayout& p6, const TensorLayout& p7) { | ||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | ||||
@@ -669,9 +643,8 @@ public: | |||||
} | } | ||||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p6, const TensorLayout& p7) { | const TensorLayout& p6, const TensorLayout& p7) { | ||||
std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | ||||
@@ -687,17 +660,15 @@ public: | |||||
* The selected algorithm should not use workspace more than | * The selected algorithm should not use workspace more than | ||||
*/ | */ | ||||
AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p6, const TensorLayout& p7, | const TensorLayout& p6, const TensorLayout& p7, | ||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | |||||
workspace_limit_in_bytes, positive_attr, | |||||
negative_attr) | |||||
return get_algorithm_heuristic( | |||||
p0, p1, p2, p3, p4, p5, p6, p7, workspace_limit_in_bytes, | |||||
positive_attr, negative_attr) | |||||
->info(); | ->info(); | ||||
} | } | ||||
@@ -705,15 +676,13 @@ protected: | |||||
~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
//! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
virtual std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, const TensorLayout& p5, | |||||
virtual std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p6, const TensorLayout& p7) = 0; | const TensorLayout& p6, const TensorLayout& p7) = 0; | ||||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | virtual std::vector<Algorithm*> get_all_algorithms_safe( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p6, const TensorLayout& p7) = 0; | const TensorLayout& p6, const TensorLayout& p7) = 0; | ||||
/** | /** | ||||
@@ -723,12 +692,10 @@ protected: | |||||
* \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
*/ | */ | ||||
virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
const TensorLayout& p0, const TensorLayout& p1, | |||||
const TensorLayout& p2, const TensorLayout& p3, | |||||
const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
const TensorLayout& p6, const TensorLayout& p7, | const TensorLayout& p6, const TensorLayout& p7, | ||||
size_t workspace_limit_in_bytes = | |||||
std::numeric_limits<size_t>::max(), | |||||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
}; | }; | ||||
@@ -31,15 +31,17 @@ class FlipForward : public FlipBase { | |||||
DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); | DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using Flip = FlipForward; | using Flip = FlipForward; | ||||
@@ -56,15 +58,17 @@ class RotateForward : public RotateBase { | |||||
DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); | DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using Rotate = RotateForward; | using Rotate = RotateForward; | ||||
@@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase { | |||||
DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); | DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using ROICopy = ROICopyForward; | using ROICopy = ROICopyForward; | ||||
@@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase { | |||||
DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); | DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using CvtColor = CvtColorForward; | using CvtColor = CvtColorForward; | ||||
@@ -130,8 +138,9 @@ public: | |||||
using BorderMode = Param::BorderMode; | using BorderMode = Param::BorderMode; | ||||
protected: | protected: | ||||
void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans, | |||||
const TensorLayout& dst); | |||||
void check_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& trans, | |||||
const TensorLayout& dst); | |||||
std::string param_msg() const; | std::string param_msg() const; | ||||
int get_real_coord(int p, int len); | int get_real_coord(int p, int len); | ||||
}; | }; | ||||
@@ -148,15 +157,17 @@ public: | |||||
* \warning src, trans, border_value, dst should be contiguous | * \warning src, trans, border_value, dst should be contiguous | ||||
* The size of trans is N * 2 * 3 | * The size of trans is N * 2 * 3 | ||||
*/ | */ | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& trans, | |||||
const TensorLayout& dst) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& trans, | |||||
const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& trans, | |||||
const TensorLayout& dst, size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using WarpAffine = WarpAffineForward; | using WarpAffine = WarpAffineForward; | ||||
@@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase { | |||||
DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); | DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using GaussianBlur = GaussianBlurForward; | using GaussianBlur = GaussianBlurForward; | ||||
@@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase { | |||||
DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); | DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using Resize = ResizeForward; | using Resize = ResizeForward; | ||||
@@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase { | |||||
DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); | DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||||
const TensorLayout& mat) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& diff, const TensorLayout& mat) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& diff, const TensorLayout& mat, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& diff, const TensorLayout& mat, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
/** | /** | ||||
@@ -251,29 +268,32 @@ public: | |||||
using BorderMode = Param::BorderMode; | using BorderMode = Param::BorderMode; | ||||
protected: | protected: | ||||
void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, | |||||
const TensorLayout& dst); | |||||
void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, | |||||
TensorLayout& dst); | |||||
void check_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& map_xy, | |||||
const TensorLayout& dst); | |||||
void deduce_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); | |||||
}; | }; | ||||
class RemapForward : public RemapBase { | class RemapForward : public RemapBase { | ||||
DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); | DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, | |||||
TensorLayout& dst); | |||||
void deduce_layout( | |||||
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& map_xy, | |||||
const TensorLayout& dst) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& map_xy, | |||||
const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||||
const TensorLayout& dst, size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& map_xy, | |||||
const TensorLayout& dst, size_t workspace_in_bytes); | |||||
}; | }; | ||||
using Remap = RemapForward; | using Remap = RemapForward; | ||||
@@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase { | |||||
DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); | DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& map_xy, const TensorLayout& diff, | |||||
const TensorLayout& grad) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& map_xy, const TensorLayout& diff, | |||||
const TensorLayout& grad, size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& map_xy, const TensorLayout& diff, | |||||
const TensorLayout& grad, size_t workspace_in_bytes); | |||||
}; | }; | ||||
class RemapBackwardMat : public RemapBase { | class RemapBackwardMat : public RemapBase { | ||||
DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); | DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& map_xy, | |||||
const TensorLayout& diff, | |||||
const TensorLayout& grad) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& map_xy, | |||||
const TensorLayout& diff, const TensorLayout& grad) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||||
const TensorLayout& diff, const TensorLayout& grad, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& map_xy, | |||||
const TensorLayout& diff, const TensorLayout& grad, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
class SeparableFilterBase : public OperatorBase { | class SeparableFilterBase : public OperatorBase { | ||||
@@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase { | |||||
DEF_OPR_PARAM(SeparableFilter); | DEF_OPR_PARAM(SeparableFilter); | ||||
protected: | protected: | ||||
void deduce_layout_fwd(const TensorLayout& src, | |||||
const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, TensorLayout& dst); | |||||
void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, | |||||
const TensorLayout& dst); | |||||
void deduce_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, TensorLayout& dst); | |||||
void check_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, const TensorLayout& dst); | |||||
}; | }; | ||||
class SeparableFilterForward : public SeparableFilterBase { | class SeparableFilterForward : public SeparableFilterBase { | ||||
DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); | DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, | |||||
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, TensorLayout& dst); | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, | |||||
const TensorLayout& dst) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in filter_x, | |||||
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout( | |||||
const TensorLayout& src, const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, TensorLayout& dst); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& filter_x, | |||||
const TensorLayout& filter_y, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using SeparableFilter = SeparableFilterForward; | using SeparableFilter = SeparableFilterForward; | ||||
@@ -13,173 +13,162 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
class WarpPerspectiveBase: public OperatorBase { | |||||
class WarpPerspectiveBase : public OperatorBase { | |||||
DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); | DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); | ||||
DEF_OPR_PARAM(WarpPerspective); | DEF_OPR_PARAM(WarpPerspective); | ||||
public: | |||||
using InterpolationMode = Param::InterpolationMode; | |||||
using BorderMode = Param::BorderMode; | |||||
protected: | |||||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||||
const TensorLayout &dst) { | |||||
check_layout_fwd(src, mat, {}, dst); | |||||
} | |||||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||||
const TensorLayout &mat_idx, const TensorLayout &dst); | |||||
std::string param_msg() const; | |||||
int get_real_coord(int p, int len); | |||||
public: | |||||
using InterpolationMode = Param::InterpolationMode; | |||||
using BorderMode = Param::BorderMode; | |||||
protected: | |||||
void check_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | |||||
check_layout_fwd(src, mat, {}, dst); | |||||
} | |||||
void check_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst); | |||||
std::string param_msg() const; | |||||
int get_real_coord(int p, int len); | |||||
}; | }; | ||||
class WarpPerspectiveForward: public WarpPerspectiveBase { | |||||
class WarpPerspectiveForward : public WarpPerspectiveBase { | |||||
DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); | DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); | ||||
public: | |||||
/** | |||||
* \param[in] src (n, channel, in_height, in_width) | |||||
* \param[in] mat (n, 3, 3) | |||||
* \param[out] dst (n, channel, out_height, out_width) | |||||
* | |||||
* \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||||
* | |||||
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||||
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||||
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||||
* | |||||
* src and dst can have different shapes, as long as their n and c agree. | |||||
* src, mat and dst should be contiguous. | |||||
*/ | |||||
void exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_in mat, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
exec(src, mat, {}, dst, workspace); | |||||
} | |||||
/** | |||||
* \p src should have batch size m, and \p mat and \p mat_idx should | |||||
* both have batch size n. Each item in \p mat_idx must be in the range | |||||
* of [0, m-1]. | |||||
* | |||||
* \param mat_idx the indices of input image that each matrix in \p mat | |||||
* should act on. It can also be empty and in such case \p mat | |||||
* should have the same batch size as \p src. | |||||
*/ | |||||
virtual void exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
const TensorLayout &mat, | |||||
const TensorLayout &dst) { | |||||
return get_workspace_in_bytes(src, mat, {}, dst); | |||||
} | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
const TensorLayout &mat, | |||||
const TensorLayout &mat_idx, | |||||
const TensorLayout &dst) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout &src, | |||||
const TensorLayout &mat, | |||||
const TensorLayout &mat_idx, | |||||
const TensorLayout &dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec_allow_nhwc_mat_idx(const TensorLayout &src, | |||||
const TensorLayout &mat, | |||||
const TensorLayout &mat_idx, | |||||
const TensorLayout &dst, | |||||
size_t workspace_in_bytes); | |||||
public: | |||||
/** | |||||
* \param[in] src (n, channel, in_height, in_width) | |||||
* \param[in] mat (n, 3, 3) | |||||
* \param[out] dst (n, channel, out_height, out_width) | |||||
* | |||||
* \see | |||||
* http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||||
* | |||||
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||||
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||||
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||||
* | |||||
* src and dst can have different shapes, as long as their n and c agree. | |||||
* src, mat and dst should be contiguous. | |||||
*/ | |||||
void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
exec(src, mat, {}, dst, workspace); | |||||
} | |||||
/** | |||||
* \p src should have batch size m, and \p mat and \p mat_idx should | |||||
* both have batch size n. Each item in \p mat_idx must be in the range | |||||
* of [0, m-1]. | |||||
* | |||||
* \param mat_idx the indices of input image that each matrix in \p mat | |||||
* should act on. It can also be empty and in such case \p mat | |||||
* should have the same batch size as \p src. | |||||
*/ | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | |||||
return get_workspace_in_bytes(src, mat, {}, dst); | |||||
} | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst) = 0; | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
void check_exec_allow_nhwc_mat_idx( | |||||
const TensorLayout& src, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using WarpPerspective = WarpPerspectiveForward; | using WarpPerspective = WarpPerspectiveForward; | ||||
class WarpPerspectiveBackwardData: public WarpPerspectiveBase { | |||||
class WarpPerspectiveBackwardData : public WarpPerspectiveBase { | |||||
DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); | DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); | ||||
public: | |||||
/** | |||||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||||
* \param[in] diff the backpropagated gradient wrt. dst | |||||
* \param[out] grad the backpropagated gradient wrt. src | |||||
* \param[out] workspace temporary workspace to perform backward | |||||
*/ | |||||
void exec(_megdnn_tensor_in mat, | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) { | |||||
exec(mat, {}, diff, grad, workspace); | |||||
} | |||||
virtual void exec(_megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) = 0; | |||||
size_t get_workspace_in_bytes(const TensorLayout &mat, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad) { | |||||
return get_workspace_in_bytes(mat, {}, diff, grad); | |||||
} | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout &mat, | |||||
const TensorLayout &mat_idx, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout &mat, | |||||
const TensorLayout &mat_idx, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad, | |||||
size_t workspace_in_bytes); | |||||
public: | |||||
/** | |||||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||||
* \param[in] diff the backpropagated gradient wrt. dst | |||||
* \param[out] grad the backpropagated gradient wrt. src | |||||
* \param[out] workspace temporary workspace to perform backward | |||||
*/ | |||||
void exec( | |||||
_megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) { | |||||
exec(mat, {}, diff, grad, workspace); | |||||
} | |||||
virtual void exec( | |||||
_megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& mat, const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return get_workspace_in_bytes(mat, {}, diff, grad); | |||||
} | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& mat, const TensorLayout& mat_idx, | |||||
const TensorLayout& diff, const TensorLayout& grad) = 0; | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& mat, const TensorLayout& mat_idx, | |||||
const TensorLayout& diff, const TensorLayout& grad, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { | |||||
class WarpPerspectiveBackwardMat : public WarpPerspectiveBase { | |||||
DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); | DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); | ||||
public: | |||||
/** | |||||
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||||
* \param[in] diff the backpropagated gradient wrt. dst | |||||
* \param[out] grad the backpropagated gradient wrt. mat | |||||
* \param[out] workspace temporary workspace to perform backward | |||||
*/ | |||||
void exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_in mat, | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) { | |||||
exec(src, mat, {}, diff, grad, workspace); | |||||
} | |||||
virtual void exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_in mat, | |||||
_megdnn_tensor_in mat_idx, | |||||
_megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) = 0; | |||||
size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
const TensorLayout &mat, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad) { | |||||
return get_workspace_in_bytes(src, mat, {}, diff, grad); | |||||
} | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
const TensorLayout &mat, | |||||
const TensorLayout &mat_idx, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout &src, | |||||
const TensorLayout &mat, | |||||
const TensorLayout &mat_idx, | |||||
const TensorLayout &diff, | |||||
const TensorLayout &grad, | |||||
size_t workspace_in_bytes); | |||||
public: | |||||
/** | |||||
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||||
* \param[in] diff the backpropagated gradient wrt. dst | |||||
* \param[out] grad the backpropagated gradient wrt. mat | |||||
* \param[out] workspace temporary workspace to perform backward | |||||
*/ | |||||
void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff, | |||||
_megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||||
exec(src, mat, {}, diff, grad, workspace); | |||||
} | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | |||||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) = 0; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff, | |||||
const TensorLayout& grad) { | |||||
return get_workspace_in_bytes(src, mat, {}, diff, grad); | |||||
} | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& diff, | |||||
const TensorLayout& grad) = 0; | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& mat, | |||||
const TensorLayout& mat_idx, const TensorLayout& diff, | |||||
const TensorLayout& grad, size_t workspace_in_bytes); | |||||
}; | }; | ||||
class DctChannelSelectForward : public OperatorBase { | class DctChannelSelectForward : public OperatorBase { | ||||
@@ -194,37 +183,32 @@ public: | |||||
* \param[dst] DctChannelSelectForward output, default fp32 nchw tensor | * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor | ||||
* \param[out] workspace temporary workspace to perform forward | * \param[out] workspace temporary workspace to perform forward | ||||
*/ | */ | ||||
virtual void exec(_megdnn_tensor_in src, | |||||
_megdnn_tensor_in mask_offset, | |||||
_megdnn_tensor_in mask_val, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, | |||||
const TensorLayout& mask_offset, | |||||
const TensorLayout& mask_val, | |||||
TensorLayout& dst); | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& mask_offset, | |||||
const TensorLayout& mask_val, | |||||
const TensorLayout& dst) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_in mask_offset, | |||||
_megdnn_tensor_in mask_val, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout( | |||||
const TensorLayout& src, const TensorLayout& mask_offset, | |||||
const TensorLayout& mask_val, TensorLayout& dst); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& mask_offset, | |||||
const TensorLayout& mask_val, const TensorLayout& dst) = 0; | |||||
protected: | protected: | ||||
void check_layout_fwd(const TensorLayout& src, | |||||
const TensorLayout& mask_offset, | |||||
const TensorLayout& mask_val, | |||||
const TensorLayout& dst); | |||||
void deduce_layout_fwd(const TensorLayout& src, | |||||
const TensorLayout& mask_offset, | |||||
const TensorLayout& mask_val, | |||||
TensorLayout& dst); | |||||
void check_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& mask_offset, | |||||
const TensorLayout& mask_val, const TensorLayout& dst); | |||||
void deduce_layout_fwd( | |||||
const TensorLayout& src, const TensorLayout& mask_offset, | |||||
const TensorLayout& mask_val, TensorLayout& dst); | |||||
std::string param_msg() const; | std::string param_msg() const; | ||||
}; | }; | ||||
} // namespace megdnn | |||||
} // namespace megdnn | |||||
#include "megdnn/internal/opr_header_epilogue.h" | #include "megdnn/internal/opr_header_epilogue.h" | ||||
@@ -33,22 +33,22 @@ public: | |||||
* op(A) = A if transposeA is false, otherwise op(A) = A^t. | * op(A) = A if transposeA is false, otherwise op(A) = A^t. | ||||
* op(B) = B if transposeB is false, otherwise op(B) = B^t. | * op(B) = B if transposeB is false, otherwise op(B) = B^t. | ||||
*/ | */ | ||||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||||
void deduce_dtype(DType A, DType B, DType &C); | |||||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||||
TensorLayout& C); | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_dtype(DType A, DType B, DType& C); | |||||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||||
static Algorithm::OprType get_opr_type() { | static Algorithm::OprType get_opr_type() { | ||||
return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; | return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; | ||||
} | } | ||||
protected: | protected: | ||||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||||
const TensorLayout& C, size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using BatchedMatrixMul = BatchedMatrixMulForward; | using BatchedMatrixMul = BatchedMatrixMulForward; | ||||
@@ -70,24 +70,24 @@ public: | |||||
* op(A) = A if transposeA is false, otherwise op(A) = A^t. | * op(A) = A if transposeA is false, otherwise op(A) = A^t. | ||||
* op(B) = B if transposeB is false, otherwise op(B) = B^t. | * op(B) = B if transposeB is false, otherwise op(B) = B^t. | ||||
*/ | */ | ||||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_dtype(DType A, DType B, DType& C); | void deduce_dtype(DType A, DType B, DType& C); | ||||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||||
TensorLayout& C); | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) = 0; | |||||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||||
static size_t pack_size (const Param::Format format); | |||||
static size_t pack_size(const Param::Format format); | |||||
static Algorithm::OprType get_opr_type() { | static Algorithm::OprType get_opr_type() { | ||||
return Algorithm::OprType::MATRIX_MUL_FORWARD; | return Algorithm::OprType::MATRIX_MUL_FORWARD; | ||||
} | } | ||||
protected: | protected: | ||||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||||
const TensorLayout& C, size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using MatrixMul = MatrixMulForward; | using MatrixMul = MatrixMulForward; | ||||
@@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase { | |||||
DEF_OPR_PARAM(Empty); | DEF_OPR_PARAM(Empty); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst); | |||||
size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst); | |||||
protected: | protected: | ||||
/*! | /*! | ||||
@@ -116,8 +116,7 @@ protected: | |||||
* | * | ||||
* Note that \p batch and \p n can be null | * Note that \p batch and \p n can be null | ||||
*/ | */ | ||||
static void canonize_params(const TensorLayout& layout, size_t* batch, | |||||
size_t* n); | |||||
static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n); | |||||
/*! | /*! | ||||
* \brief canonize and validate input params for exec() impls | * \brief canonize and validate input params for exec() impls | ||||
@@ -125,11 +124,12 @@ protected: | |||||
* Since get_workspace_in_bytes() would be called, \p batch and \p n can not | * Since get_workspace_in_bytes() would be called, \p batch and \p n can not | ||||
* be null | * be null | ||||
*/ | */ | ||||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
_megdnn_workspace workspace, size_t* batch, size_t* n); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
_megdnn_workspace workspace, size_t* batch, size_t* n); | |||||
virtual size_t get_workspace_in_bytes(size_t batch, size_t n, | |||||
size_t dtype_size) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
size_t batch, size_t n, size_t dtype_size) = 0; | |||||
}; | }; | ||||
//! inter-product of two vectors | //! inter-product of two vectors | ||||
@@ -147,17 +147,17 @@ public: | |||||
* A, B, C must be contiguous. A and B must have the same 1-dimensional | * A, B, C must be contiguous. A and B must have the same 1-dimensional | ||||
* shape and non-negative strides. C must be scalar. | * shape and non-negative strides. C must be scalar. | ||||
*/ | */ | ||||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||||
TensorLayout& C); | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||||
const TensorLayout& B, | |||||
const TensorLayout& C) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||||
const TensorLayout& C, size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
using Dot = DotForward; | using Dot = DotForward; | ||||
@@ -193,23 +193,24 @@ public: | |||||
* if compute_uv is false (default to true). | * if compute_uv is false (default to true). | ||||
* | * | ||||
*/ | */ | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u, | |||||
_megdnn_tensor_out s, _megdnn_tensor_out vt, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& u, | |||||
TensorLayout& s, TensorLayout& vt); | |||||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& u, const TensorLayout& s, | |||||
const TensorLayout& vt); | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s, | |||||
_megdnn_tensor_out vt, _megdnn_workspace workspace) = 0; | |||||
void deduce_layout( | |||||
const TensorLayout& src, TensorLayout& u, TensorLayout& s, | |||||
TensorLayout& vt); | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, | |||||
const TensorLayout& vt); | |||||
protected: | protected: | ||||
static void canonize_params(const TensorLayout& layout, size_t* batch, | |||||
size_t* m, size_t* n); | |||||
virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n, | |||||
size_t dtype_size) = 0; | |||||
void check_exec(const TensorLayout& src, const TensorLayout& u, | |||||
const TensorLayout& s, const TensorLayout& vt, | |||||
size_t workspace_in_bytes); | |||||
static void canonize_params( | |||||
const TensorLayout& layout, size_t* batch, size_t* m, size_t* n); | |||||
virtual size_t get_workspace_in_bytes( | |||||
size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0; | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, | |||||
const TensorLayout& vt, size_t workspace_in_bytes); | |||||
}; | }; | ||||
using SVD = SVDForward; | using SVD = SVDForward; | ||||
@@ -36,7 +36,7 @@ public: | |||||
struct ModeTrait { | struct ModeTrait { | ||||
uint32_t arity = 0; //!< number of inputs needed | uint32_t arity = 0; //!< number of inputs needed | ||||
CheckDtypeFunc check_inp[MAX_ARITY]; | CheckDtypeFunc check_inp[MAX_ARITY]; | ||||
SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||||
SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||||
bool need_specify_out_dtype = | bool need_specify_out_dtype = | ||||
false; //!< the dtype should be setup externally, otherwise | false; //!< the dtype should be setup externally, otherwise | ||||
//!< would be inferred by check_out(dtype, false) | //!< would be inferred by check_out(dtype, false) | ||||
@@ -46,13 +46,10 @@ public: | |||||
static const ModeTrait& from_mode(Mode mode); | static const ModeTrait& from_mode(Mode mode); | ||||
}; | }; | ||||
virtual void exec(_megdnn_in const TensorNDArray& src, | |||||
_megdnn_tensor_out dst) = 0; | |||||
virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0; | |||||
//! get trait of current mode | //! get trait of current mode | ||||
const ModeTrait& mode_trait() const { | |||||
return ModeTrait::from_mode(m_param.mode); | |||||
} | |||||
const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); } | |||||
//! deduce output layout | //! deduce output layout | ||||
void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); | void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); | ||||
@@ -60,8 +57,8 @@ public: | |||||
protected: | protected: | ||||
//! throw exception if incorrect layout; broadcast input shape to | //! throw exception if incorrect layout; broadcast input shape to | ||||
//! output shape | //! output shape | ||||
void check_layout_and_broadcast(const TensorLayoutPtrArray& src, | |||||
const TensorLayout& dst); | |||||
void check_layout_and_broadcast( | |||||
const TensorLayoutPtrArray& src, const TensorLayout& dst); | |||||
}; | }; | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -15,84 +15,97 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
//! base class for random number generators | //! base class for random number generators | ||||
class RNGBase: public OperatorBase { | |||||
class RNGBase : public OperatorBase { | |||||
DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); | DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); | ||||
public: | |||||
virtual void exec(_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; | |||||
protected: | |||||
virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0; | |||||
public: | |||||
virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; | |||||
protected: | |||||
virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0; | |||||
}; | }; | ||||
//! sample from poisson distribution | //! sample from poisson distribution | ||||
class PoissonRNG: public OperatorBase { | |||||
class PoissonRNG : public OperatorBase { | |||||
DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); | DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); | ||||
DEF_OPR_PARAM(PoissonRNG); | DEF_OPR_PARAM(PoissonRNG); | ||||
public: | |||||
virtual void exec(_megdnn_tensor_in lam, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout &lam, | |||||
const TensorLayout &dst) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout &lam, const TensorLayout &dst, | |||||
size_t workspace_in_bytes); | |||||
public: | |||||
virtual void exec( | |||||
_megdnn_tensor_in lam, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& lam, const TensorLayout& dst) = 0; | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& lam, const TensorLayout& dst, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
//! sample from beta distribution | //! sample from beta distribution | ||||
class BetaRNG: public OperatorBase { | |||||
class BetaRNG : public OperatorBase { | |||||
DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); | DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); | ||||
DEF_OPR_PARAM(BetaRNG); | DEF_OPR_PARAM(BetaRNG); | ||||
public: | |||||
virtual void exec(_megdnn_tensor_in alpha, | |||||
_megdnn_tensor_in beta, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout &alpha, | |||||
const TensorLayout &beta, const TensorLayout &dst) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout &alpha, const TensorLayout &beta, | |||||
const TensorLayout &dst, size_t workspace_in_bytes); | |||||
public: | |||||
virtual void exec( | |||||
_megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& alpha, const TensorLayout& beta, | |||||
const TensorLayout& dst) = 0; | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& alpha, const TensorLayout& beta, | |||||
const TensorLayout& dst, size_t workspace_in_bytes); | |||||
}; | }; | ||||
//! sample from gamma distribution | //! sample from gamma distribution | ||||
class GammaRNG: public OperatorBase { | |||||
class GammaRNG : public OperatorBase { | |||||
DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); | DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); | ||||
DEF_OPR_PARAM(GammaRNG); | DEF_OPR_PARAM(GammaRNG); | ||||
public: | |||||
virtual void exec(_megdnn_tensor_in shape, | |||||
_megdnn_tensor_in scale, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout &shape, | |||||
const TensorLayout &scale, const TensorLayout &dst) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout &shape, const TensorLayout &scale, | |||||
const TensorLayout &dst, size_t workspace_in_bytes); | |||||
public: | |||||
virtual void exec( | |||||
_megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& shape, const TensorLayout& scale, | |||||
const TensorLayout& dst) = 0; | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& shape, const TensorLayout& scale, | |||||
const TensorLayout& dst, size_t workspace_in_bytes); | |||||
}; | }; | ||||
//! sample from uniform distribution on the interval (0, 1] | //! sample from uniform distribution on the interval (0, 1] | ||||
class UniformRNG: public RNGBase { | |||||
class UniformRNG : public RNGBase { | |||||
DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); | DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); | ||||
DEF_OPR_PARAM(UniformRNG); | DEF_OPR_PARAM(UniformRNG); | ||||
protected: | |||||
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||||
protected: | |||||
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||||
}; | }; | ||||
//! sample from gaussian distribution | //! sample from gaussian distribution | ||||
class GaussianRNG: public RNGBase { | |||||
class GaussianRNG : public RNGBase { | |||||
DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); | DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); | ||||
DEF_OPR_PARAM(GaussianRNG); | DEF_OPR_PARAM(GaussianRNG); | ||||
protected: | |||||
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||||
protected: | |||||
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||||
}; | }; | ||||
class PermutationRNG: public RNGBase { | |||||
class PermutationRNG : public RNGBase { | |||||
DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); | DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); | ||||
DEF_OPR_PARAM(PermutationRNG); | DEF_OPR_PARAM(PermutationRNG); | ||||
protected: | |||||
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||||
protected: | |||||
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||||
}; | }; | ||||
class ShuffleRNGForward : public OperatorBase { | class ShuffleRNGForward : public OperatorBase { | ||||
@@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase { | |||||
DEF_OPR_PARAM(ShuffleRNG); | DEF_OPR_PARAM(ShuffleRNG); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
_megdnn_tensor_out indices, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& dst, | |||||
TensorLayout& indices); | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
const TensorLayout& dst, | |||||
const TensorLayout& indices) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout( | |||||
const TensorLayout& src, TensorLayout& dst, TensorLayout& indices); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& indices) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& indices, size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& src, const TensorLayout& dst, | |||||
const TensorLayout& indices, size_t workspace_in_bytes); | |||||
}; | }; | ||||
using ShuffleRNG = ShuffleRNGForward; | using ShuffleRNG = ShuffleRNGForward; | ||||
@@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase { | |||||
DEF_OPR_PARAM(ShuffleRNG); | DEF_OPR_PARAM(ShuffleRNG); | ||||
public: | public: | ||||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, | |||||
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||||
const TensorLayout& indices, | |||||
const TensorLayout& grad) = 0; | |||||
virtual void exec( | |||||
_megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& diff, const TensorLayout& indices, | |||||
const TensorLayout& grad) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& diff, const TensorLayout& indices, | |||||
const TensorLayout& grad, size_t workspace_in_bytes); | |||||
void check_exec( | |||||
const TensorLayout& diff, const TensorLayout& indices, | |||||
const TensorLayout& grad, size_t workspace_in_bytes); | |||||
}; | }; | ||||
/*! | /*! | ||||
* \brief sleep for specific time on the computing device; useful for testing | * \brief sleep for specific time on the computing device; useful for testing | ||||
* async problems | * async problems | ||||
*/ | */ | ||||
class SleepForward: public OperatorBase { | |||||
class SleepForward : public OperatorBase { | |||||
DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); | DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); | ||||
DEF_OPR_PARAM(Sleep); | DEF_OPR_PARAM(Sleep); | ||||
public: | |||||
virtual void exec() = 0; | |||||
public: | |||||
virtual void exec() = 0; | |||||
}; | }; | ||||
using Sleep = SleepForward; | using Sleep = SleepForward; | ||||
@@ -149,20 +165,19 @@ using Sleep = SleepForward; | |||||
* | * | ||||
* data must be a one-dimensional contiguous tensor with dtype byte | * data must be a one-dimensional contiguous tensor with dtype byte | ||||
*/ | */ | ||||
class ChecksumForward: public OperatorBase { | |||||
class ChecksumForward : public OperatorBase { | |||||
DEF_OPR_PARAM(Empty); | DEF_OPR_PARAM(Empty); | ||||
DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); | DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); | ||||
public: | |||||
using Result = opr_result::Checksum; | |||||
public: | |||||
using Result = opr_result::Checksum; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0; | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0; | |||||
virtual Result exec(_megdnn_tensor_in data, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout &layout, size_t workspace_in_bytes); | |||||
protected: | |||||
void check_exec(const TensorLayout& layout, size_t workspace_in_bytes); | |||||
}; | }; | ||||
using Checksum = ChecksumForward; | using Checksum = ChecksumForward; | ||||
@@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase { | |||||
DEF_OPR_PARAM(Empty); | DEF_OPR_PARAM(Empty); | ||||
DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); | DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); | ||||
public: | |||||
virtual size_t get_workspace_in_bytes(const TensorLayout& layout1, | |||||
const TensorLayout& layout2) = 0; | |||||
public: | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& layout1, const TensorLayout& layout2) = 0; | |||||
virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||||
_megdnn_workspace workspace) = 0; | |||||
virtual float exec( | |||||
_megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||||
_megdnn_workspace workspace) = 0; | |||||
protected: | |||||
void check_exec(const TensorLayout& layout1, | |||||
const TensorLayout& layout2, size_t workspace_in_bytes); | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& layout1, const TensorLayout& layout2, | |||||
size_t workspace_in_bytes); | |||||
}; | }; | ||||
bool check_bias_share_in_channel(const TensorLayout& bias, | |||||
const param::ConvBias::Format format); | |||||
bool check_bias_share_in_channel( | |||||
const TensorLayout& bias, const param::ConvBias::Format format); | |||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -18,9 +18,9 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
enum class TensorFormat::Type { | enum class TensorFormat::Type { | ||||
DEFAULT = 0, //!< see DefaultTensorFormat | |||||
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||||
LOWBITS_ALIGNED_TO_BYTE = 2, //!< | |||||
DEFAULT = 0, //!< see DefaultTensorFormat | |||||
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||||
LOWBITS_ALIGNED_TO_BYTE = 2, //!< | |||||
}; | }; | ||||
class TensorFormat::ImplBase { | class TensorFormat::ImplBase { | ||||
@@ -33,8 +33,7 @@ public: | |||||
virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | ||||
virtual TensorLayout collapse_contiguous_spec( | |||||
const TensorLayout& layout) const = 0; | |||||
virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0; | |||||
virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; | virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; | ||||
@@ -79,8 +78,7 @@ public: | |||||
*/ | */ | ||||
bool is_contiguous_spec(const TensorLayout& layout) const override; | bool is_contiguous_spec(const TensorLayout& layout) const override; | ||||
TensorLayout collapse_contiguous_spec( | |||||
const TensorLayout& layout) const override; | |||||
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | TensorLayout::Span span_spec(const TensorLayout& layout) const override; | ||||
@@ -88,8 +86,7 @@ public: | |||||
void serialize_append(std::string& result) const override; | void serialize_append(std::string& result) const override; | ||||
static TensorFormat make(); | static TensorFormat make(); | ||||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||||
size_t size); | |||||
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||||
}; | }; | ||||
namespace detail { | namespace detail { | ||||
@@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase { | |||||
size_t m_align_axis, m_align_size_in_elements_log2; | size_t m_align_axis, m_align_size_in_elements_log2; | ||||
protected: | protected: | ||||
Image2DTensorFormatBase(Type type, size_t align_axis, | |||||
size_t align_size_in_elements); | |||||
Image2DTensorFormatBase( | |||||
Type type, size_t align_axis, size_t align_size_in_elements); | |||||
virtual ~Image2DTensorFormatBase() = default; | virtual ~Image2DTensorFormatBase() = default; | ||||
public: | public: | ||||
@@ -129,9 +126,7 @@ public: | |||||
size_t align_axis() const { return m_align_axis; } | size_t align_axis() const { return m_align_axis; } | ||||
size_t align_size_in_elements_log2() const { | |||||
return m_align_size_in_elements_log2; | |||||
} | |||||
size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; } | |||||
std::string to_string() const override; | std::string to_string() const override; | ||||
@@ -145,6 +140,7 @@ public: | |||||
size_t image_height(const TensorLayout& layout) const; | size_t image_height(const TensorLayout& layout) const; | ||||
void serialize_append(std::string& result) const override; | void serialize_append(std::string& result) const override; | ||||
protected: | protected: | ||||
struct SerializePack { | struct SerializePack { | ||||
uint8_t align_axis; | uint8_t align_axis; | ||||
@@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase { | |||||
* align COUNT, but mdl needs align size in byte, which equal to | * align COUNT, but mdl needs align size in byte, which equal to | ||||
* (image_width algin count) * sizeof(data_type) * pixel_size | * (image_width algin count) * sizeof(data_type) * pixel_size | ||||
*/ | */ | ||||
size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements, | |||||
const TensorLayout& layout) const; | |||||
size_t image_pitch_alignment_in_bytes( | |||||
size_t align_size_in_elements, const TensorLayout& layout) const; | |||||
protected: | protected: | ||||
Image2DPackedTensorFormatBase(Type type, size_t align_axis, | |||||
size_t align_size_in_elements, | |||||
Handle::HandleVendorType vendor_type) | |||||
: detail::Image2DTensorFormatBase(type, align_axis, | |||||
align_size_in_elements), | |||||
Image2DPackedTensorFormatBase( | |||||
Type type, size_t align_axis, size_t align_size_in_elements, | |||||
Handle::HandleVendorType vendor_type) | |||||
: detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements), | |||||
m_vendor_type(vendor_type) {} | m_vendor_type(vendor_type) {} | ||||
virtual ~Image2DPackedTensorFormatBase() = default; | virtual ~Image2DPackedTensorFormatBase() = default; | ||||
@@ -197,13 +192,12 @@ public: | |||||
bool is_contiguous_spec(const TensorLayout& layout) const override; | bool is_contiguous_spec(const TensorLayout& layout) const override; | ||||
TensorLayout collapse_contiguous_spec( | |||||
const TensorLayout& layout) const override; | |||||
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||||
}; | }; | ||||
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | ||||
/*! | /*! | ||||
* \brief used for tensors storing lowbit data | |||||
* \brief used for tensors storing lowbit data | |||||
* | * | ||||
* \param m_size_nbits size in bits of elements in the tensor | * \param m_size_nbits size in bits of elements in the tensor | ||||
* \param m_align_size_in_bits aligned size in bits | * \param m_align_size_in_bits aligned size in bits | ||||
@@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase { | |||||
size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; | size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; | ||||
protected: //? | protected: //? | ||||
LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits, | |||||
size_t align_size_in_bits); | |||||
LowbitsAlignedTensorFormatBase( | |||||
Type type, size_t size_nbits, size_t align_size_in_bits); | |||||
virtual ~LowbitsAlignedTensorFormatBase() = default; | virtual ~LowbitsAlignedTensorFormatBase() = default; | ||||
public: | public: | ||||
size_t align_size_in_bits() const { return m_align_size_in_bits; } | size_t align_size_in_bits() const { return m_align_size_in_bits; } | ||||
size_t size_nbits() const { return m_size_nbits; } | size_t size_nbits() const { return m_size_nbits; } | ||||
std::string to_string() const override; | std::string to_string() const override; | ||||
@@ -238,8 +232,8 @@ public: | |||||
bool is_contiguous_spec(const TensorLayout& layout) const override; | bool is_contiguous_spec(const TensorLayout& layout) const override; | ||||
TensorLayout collapse_contiguous_spec( | |||||
const TensorLayout& layout) const override; | |||||
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||||
protected: | protected: | ||||
struct SerializePack { | struct SerializePack { | ||||
uint8_t size_nbits; | uint8_t size_nbits; | ||||
@@ -254,16 +248,14 @@ protected: | |||||
* | * | ||||
* This is used for OpenCL. | * This is used for OpenCL. | ||||
*/ | */ | ||||
class Image2DPack4TensorFormat final | |||||
: public detail::Image2DPack4TensorFormatBase { | |||||
class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase { | |||||
public: | public: | ||||
static constexpr Type TYPE = Type::IMAGE2D_PACK4; | static constexpr Type TYPE = Type::IMAGE2D_PACK4; | ||||
//! for internal usage or test purposes | //! for internal usage or test purposes | ||||
static TensorFormat make_raw(size_t align_axis, | |||||
size_t align_size_in_elements, | |||||
Handle::HandleVendorType vendor_type = | |||||
Handle::HandleVendorType::NOT_SPEC); | |||||
static TensorFormat make_raw( | |||||
size_t align_axis, size_t align_size_in_elements, | |||||
Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC); | |||||
static TensorFormat make(size_t align_axis, const Handle* handle); | static TensorFormat make(size_t align_axis, const Handle* handle); | ||||
@@ -273,13 +265,11 @@ public: | |||||
* Note that the alignment may be different if deserialized on another | * Note that the alignment may be different if deserialized on another | ||||
* handle | * handle | ||||
*/ | */ | ||||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||||
size_t size); | |||||
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||||
static bool is_valid_image(const TensorLayout& layout) { | static bool is_valid_image(const TensorLayout& layout) { | ||||
if (layout.format.type() == TYPE) { | if (layout.format.type() == TYPE) { | ||||
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid( | |||||
layout); | |||||
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout); | |||||
return true; | return true; | ||||
} | } | ||||
return false; | return false; | ||||
@@ -288,8 +278,9 @@ public: | |||||
TensorFormat change_axis(size_t axis) const override; | TensorFormat change_axis(size_t axis) const override; | ||||
private: | private: | ||||
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements, | |||||
Handle::HandleVendorType vendor_type) | |||||
Image2DPack4TensorFormat( | |||||
size_t align_axis, size_t align_size_in_elements, | |||||
Handle::HandleVendorType vendor_type) | |||||
: detail::Image2DPack4TensorFormatBase( | : detail::Image2DPack4TensorFormatBase( | ||||
TYPE, align_axis, align_size_in_elements, vendor_type) {} | TYPE, align_axis, align_size_in_elements, vendor_type) {} | ||||
}; | }; | ||||
@@ -306,13 +297,12 @@ public: | |||||
static TensorFormat make(size_t size_nbits); | static TensorFormat make(size_t size_nbits); | ||||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||||
size_t size); | |||||
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||||
static bool is_valid_layout(const TensorLayout& layout) { | static bool is_valid_layout(const TensorLayout& layout) { | ||||
if (layout.format.type() == TYPE) { | if (layout.format.type() == TYPE) { | ||||
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>() | |||||
.assert_valid(layout); | |||||
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid( | |||||
layout); | |||||
return true; | return true; | ||||
} | } | ||||
return false; | return false; | ||||
@@ -320,8 +310,7 @@ public: | |||||
private: | private: | ||||
LowbitsAlignedToBytesTensorFormat(size_t size_nbits) | LowbitsAlignedToBytesTensorFormat(size_t size_nbits) | ||||
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, | |||||
BYTE_IN_BITS) {} | |||||
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {} | |||||
}; | }; | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -167,13 +167,11 @@ public: | |||||
TensorIter(const TensorND& tensor) : m_tensor(tensor) {} | TensorIter(const TensorND& tensor) : m_tensor(tensor) {} | ||||
Iter begin() const { | |||||
return Iter::make(const_cast<TensorND&>(m_tensor), 0); | |||||
} | |||||
Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); } | |||||
Iter end() const { | Iter end() const { | ||||
return Iter::make(const_cast<TensorND&>(m_tensor), | |||||
m_tensor.layout.total_nr_elems()); | |||||
return Iter::make( | |||||
const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems()); | |||||
} | } | ||||
}; | }; | ||||
/*! | /*! | ||||
@@ -11,19 +11,19 @@ | |||||
#pragma once | #pragma once | ||||
#include <type_traits> | |||||
#include <cstdlib> | |||||
#include <functional> | #include <functional> | ||||
#include <utility> | |||||
#include <memory> | #include <memory> | ||||
#include <cstdlib> | |||||
#include <type_traits> | |||||
#include <utility> | |||||
#include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
template<typename Signature> | |||||
template <typename Signature> | |||||
using thin_function = ::std::function<Signature>; | using thin_function = ::std::function<Signature>; | ||||
} // namespace megdnn | |||||
} // namespace megdnn | |||||
#include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
@@ -58,18 +58,16 @@ protected: | |||||
m_end_ptr(first_elm), | m_end_ptr(first_elm), | ||||
m_capacity_ptr(static_cast<char*>(first_elm) + size) {} | m_capacity_ptr(static_cast<char*>(first_elm) + size) {} | ||||
void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, | |||||
size_t type_size); | |||||
void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size); | |||||
public: | public: | ||||
size_t size_in_bytes() const { | size_t size_in_bytes() const { | ||||
return size_t(static_cast<char*>(m_end_ptr) - | |||||
static_cast<char*>(m_begin_ptr)); | |||||
return size_t(static_cast<char*>(m_end_ptr) - static_cast<char*>(m_begin_ptr)); | |||||
} | } | ||||
size_t capacity_in_bytes() const { | size_t capacity_in_bytes() const { | ||||
return size_t(static_cast<char*>(m_capacity_ptr) - | |||||
static_cast<char*>(m_begin_ptr)); | |||||
return size_t( | |||||
static_cast<char*>(m_capacity_ptr) - static_cast<char*>(m_begin_ptr)); | |||||
} | } | ||||
bool empty() const { return m_begin_ptr == m_end_ptr; } | bool empty() const { return m_begin_ptr == m_end_ptr; } | ||||
@@ -85,20 +83,15 @@ private: | |||||
U m_first_elm; | U m_first_elm; | ||||
protected: | protected: | ||||
SmallVectorTemplateCommon(size_t size) | |||||
: SmallVectorBase(&m_first_elm, size) {} | |||||
SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {} | |||||
void grow_pod(size_t min_sz_in_bytes, size_t type_size) { | void grow_pod(size_t min_sz_in_bytes, size_t type_size) { | ||||
SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); | SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); | ||||
} | } | ||||
bool is_small() { | |||||
return m_begin_ptr == static_cast<const void*>(&m_first_elm); | |||||
} | |||||
bool is_small() { return m_begin_ptr == static_cast<const void*>(&m_first_elm); } | |||||
void reset_to_small() { | |||||
m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; | |||||
} | |||||
void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; } | |||||
void set_end(T* p) { m_end_ptr = p; } | void set_end(T* p) { m_end_ptr = p; } | ||||
@@ -128,20 +121,12 @@ protected: | |||||
public: | public: | ||||
// forwarding iterator creation | // forwarding iterator creation | ||||
iterator begin() { return static_cast<iterator>(m_begin_ptr); } | iterator begin() { return static_cast<iterator>(m_begin_ptr); } | ||||
const_iterator begin() const { | |||||
return static_cast<const_iterator>(m_begin_ptr); | |||||
} | |||||
const_iterator cbegin() const { | |||||
return static_cast<const_iterator>(m_begin_ptr); | |||||
} | |||||
const_iterator begin() const { return static_cast<const_iterator>(m_begin_ptr); } | |||||
const_iterator cbegin() const { return static_cast<const_iterator>(m_begin_ptr); } | |||||
iterator end() { return static_cast<iterator>(m_end_ptr); } | iterator end() { return static_cast<iterator>(m_end_ptr); } | ||||
const_iterator end() const { | |||||
return static_cast<const_iterator>(m_end_ptr); | |||||
} | |||||
const_iterator cend() const { | |||||
return static_cast<const_iterator>(m_end_ptr); | |||||
} | |||||
const_iterator end() const { return static_cast<const_iterator>(m_end_ptr); } | |||||
const_iterator cend() const { return static_cast<const_iterator>(m_end_ptr); } | |||||
reference at(size_type idx) { | reference at(size_type idx) { | ||||
if (idx >= size()) { | if (idx >= size()) { | ||||
@@ -167,13 +152,9 @@ public: | |||||
// reverse iterator creation method. | // reverse iterator creation method. | ||||
reverse_iterator rbegin() { return reverse_iterator(end()); } | reverse_iterator rbegin() { return reverse_iterator(end()); } | ||||
const_reverse_iterator rbegin() const { | |||||
return const_reverse_iterator(end()); | |||||
} | |||||
const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); } | |||||
reverse_iterator rend() { return reverse_iterator(begin()); } | reverse_iterator rend() { return reverse_iterator(begin()); } | ||||
const_reverse_iterator rend() const { | |||||
return const_reverse_iterator(begin()); | |||||
} | |||||
const_reverse_iterator rend() const { return const_reverse_iterator(begin()); } | |||||
pointer data() { return pointer(begin()); } | pointer data() { return pointer(begin()); } | ||||
const_pointer data() const { return const_pointer(begin()); } | const_pointer data() const { return const_pointer(begin()); } | ||||
@@ -207,8 +188,8 @@ protected: | |||||
template <typename It1, typename It2> | template <typename It1, typename It2> | ||||
static void uninitialized_move(It1 first, It1 last, It2 dest) { | static void uninitialized_move(It1 first, It1 last, It2 dest) { | ||||
std::uninitialized_copy(std::make_move_iterator(first), | |||||
std::make_move_iterator(last), dest); | |||||
std::uninitialized_copy( | |||||
std::make_move_iterator(first), std::make_move_iterator(last), dest); | |||||
} | } | ||||
template <typename It1, typename It2> | template <typename It1, typename It2> | ||||
@@ -293,9 +274,7 @@ protected: | |||||
memcpy(dest, first, (last - first) * sizeof(T)); | memcpy(dest, first, (last - first) * sizeof(T)); | ||||
} | } | ||||
void grow(size_t min_sz = 0) { | |||||
this->grow_pod(min_sz * sizeof(T), sizeof(T)); | |||||
} | |||||
void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); } | |||||
public: | public: | ||||
void push_back(const T& _elm) { | void push_back(const T& _elm) { | ||||
@@ -318,8 +297,7 @@ public: | |||||
* SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N | * SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N | ||||
*/ | */ | ||||
template <typename T> | template <typename T> | ||||
class SmallVectorImpl | |||||
: public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||||
class SmallVectorImpl : public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||||
using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; | using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; | ||||
public: | public: | ||||
@@ -329,8 +307,7 @@ public: | |||||
protected: | protected: | ||||
explicit SmallVectorImpl(unsigned n) | explicit SmallVectorImpl(unsigned n) | ||||
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) { | |||||
} | |||||
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {} | |||||
public: | public: | ||||
SmallVectorImpl(const SmallVectorImpl&) = delete; | SmallVectorImpl(const SmallVectorImpl&) = delete; | ||||
@@ -354,8 +331,7 @@ public: | |||||
} else if (n > this->size()) { | } else if (n > this->size()) { | ||||
if (this->capacity() < n) | if (this->capacity() < n) | ||||
this->grow(n); | this->grow(n); | ||||
for (auto it = this->end(), end = this->begin() + n; it != end; | |||||
++it) | |||||
for (auto it = this->end(), end = this->begin() + n; it != end; ++it) | |||||
new (&*it) T(); | new (&*it) T(); | ||||
this->set_end(this->begin() + n); | this->set_end(this->begin() + n); | ||||
} | } | ||||
@@ -389,10 +365,11 @@ public: | |||||
void swap(SmallVectorImpl<T>& rhs); | void swap(SmallVectorImpl<T>& rhs); | ||||
/// Add the specified range to the end of the SmallVector. | /// Add the specified range to the end of the SmallVector. | ||||
template <typename in_iter, | |||||
typename = typename std::enable_if<std::is_convertible< | |||||
typename std::iterator_traits<in_iter>::iterator_category, | |||||
std::input_iterator_tag>::value>::type> | |||||
template < | |||||
typename in_iter, | |||||
typename = typename std::enable_if<std::is_convertible< | |||||
typename std::iterator_traits<in_iter>::iterator_category, | |||||
std::input_iterator_tag>::value>::type> | |||||
void append(in_iter in_start, in_iter in_end) { | void append(in_iter in_start, in_iter in_end) { | ||||
size_type num_inputs = std::distance(in_start, in_end); | size_type num_inputs = std::distance(in_start, in_end); | ||||
// Grow allocated space if needed. | // Grow allocated space if needed. | ||||
@@ -432,10 +409,11 @@ public: | |||||
std::uninitialized_fill(this->begin(), this->end(), elm); | std::uninitialized_fill(this->begin(), this->end(), elm); | ||||
} | } | ||||
template <typename in_iter, | |||||
typename = typename std::enable_if<std::is_convertible< | |||||
typename std::iterator_traits<in_iter>::iterator_category, | |||||
std::input_iterator_tag>::value>::type> | |||||
template < | |||||
typename in_iter, | |||||
typename = typename std::enable_if<std::is_convertible< | |||||
typename std::iterator_traits<in_iter>::iterator_category, | |||||
std::input_iterator_tag>::value>::type> | |||||
void assign(in_iter in_start, in_iter in_end) { | void assign(in_iter in_start, in_iter in_end) { | ||||
clear(); | clear(); | ||||
append(in_start, in_end); | append(in_start, in_end); | ||||
@@ -571,8 +549,7 @@ public: | |||||
std::fill_n(it, num_overwritten, elm); | std::fill_n(it, num_overwritten, elm); | ||||
// Insert the non-overwritten middle part. | // Insert the non-overwritten middle part. | ||||
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, | |||||
elm); | |||||
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm); | |||||
return it; | return it; | ||||
} | } | ||||
@@ -646,8 +623,7 @@ public: | |||||
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | ||||
this->grow(); | this->grow(); | ||||
} | } | ||||
new (static_cast<void*>(this->end())) | |||||
T(std::forward<ArgTypes>(args)...); | |||||
new (static_cast<void*>(this->end())) T(std::forward<ArgTypes>(args)...); | |||||
this->set_end(this->end() + 1); | this->set_end(this->end() + 1); | ||||
} | } | ||||
@@ -661,13 +637,11 @@ public: | |||||
return std::equal(this->begin(), this->end(), rhs.begin()); | return std::equal(this->begin(), this->end(), rhs.begin()); | ||||
} | } | ||||
bool operator!=(const SmallVectorImpl<T>& rhs) const { | |||||
return !(*this == rhs); | |||||
} | |||||
bool operator!=(const SmallVectorImpl<T>& rhs) const { return !(*this == rhs); } | |||||
bool operator<(const SmallVectorImpl<T>& rhs) const { | bool operator<(const SmallVectorImpl<T>& rhs) const { | ||||
return std::lexicographical_compare(this->begin(), this->end(), | |||||
rhs.begin(), rhs.end()); | |||||
return std::lexicographical_compare( | |||||
this->begin(), this->end(), rhs.begin(), rhs.end()); | |||||
} | } | ||||
}; | }; | ||||
@@ -698,15 +672,13 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||||
// Copy over the extra elms. | // Copy over the extra elms. | ||||
if (this->size() > rhs.size()) { | if (this->size() > rhs.size()) { | ||||
size_t elm_diff = this->size() - rhs.size(); | size_t elm_diff = this->size() - rhs.size(); | ||||
this->uninitialized_move(this->begin() + num_shared, this->end(), | |||||
rhs.end()); | |||||
this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end()); | |||||
rhs.set_end(rhs.end() + elm_diff); | rhs.set_end(rhs.end() + elm_diff); | ||||
this->destroy_range(this->begin() + num_shared, this->end()); | this->destroy_range(this->begin() + num_shared, this->end()); | ||||
this->set_end(this->begin() + num_shared); | this->set_end(this->begin() + num_shared); | ||||
} else if (rhs.size() > this->size()) { | } else if (rhs.size() > this->size()) { | ||||
size_t elm_diff = rhs.size() - this->size(); | size_t elm_diff = rhs.size() - this->size(); | ||||
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), | |||||
this->end()); | |||||
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end()); | |||||
this->set_end(this->end() + elm_diff); | this->set_end(this->end() + elm_diff); | ||||
this->destroy_range(rhs.begin() + num_shared, rhs.end()); | this->destroy_range(rhs.begin() + num_shared, rhs.end()); | ||||
rhs.set_end(rhs.begin() + num_shared); | rhs.set_end(rhs.begin() + num_shared); | ||||
@@ -714,8 +686,7 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||||
} | } | ||||
template <typename T> | template <typename T> | ||||
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||||
const SmallVectorImpl<T>& rhs) { | |||||
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(const SmallVectorImpl<T>& rhs) { | |||||
if (this == &rhs) | if (this == &rhs) | ||||
return *this; | return *this; | ||||
size_t rhs_sz = rhs.size(); | size_t rhs_sz = rhs.size(); | ||||
@@ -740,8 +711,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||||
} else if (cur_sz) { | } else if (cur_sz) { | ||||
std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | ||||
} | } | ||||
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), | |||||
this->begin() + cur_sz); | |||||
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); | |||||
this->set_end(this->begin() + rhs_sz); | this->set_end(this->begin() + rhs_sz); | ||||
return *this; | return *this; | ||||
} | } | ||||
@@ -785,8 +755,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) { | |||||
std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | ||||
} | } | ||||
this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), | |||||
this->begin() + cur_sz); | |||||
this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); | |||||
this->set_end(this->begin() + rhs_sz); | this->set_end(this->begin() + rhs_sz); | ||||
@@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl<T> { | |||||
public: | public: | ||||
SmallVector() : SmallVectorImpl<T>(N) {} | SmallVector() : SmallVectorImpl<T>(N) {} | ||||
explicit SmallVector(size_t size, const T& value = T()) | |||||
: SmallVectorImpl<T>(N) { | |||||
explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl<T>(N) { | |||||
this->assign(size, value); | this->assign(size, value); | ||||
} | } | ||||
@@ -901,15 +869,13 @@ namespace std { | |||||
/// Implement std::swap in terms of SmallVector swap. | /// Implement std::swap in terms of SmallVector swap. | ||||
template <typename T> | template <typename T> | ||||
inline void swap(megdnn::SmallVectorImpl<T>& lhs, | |||||
megdnn::SmallVectorImpl<T>& rhs) { | |||||
inline void swap(megdnn::SmallVectorImpl<T>& lhs, megdnn::SmallVectorImpl<T>& rhs) { | |||||
lhs.swap(rhs); | lhs.swap(rhs); | ||||
} | } | ||||
/// Implement std::swap in terms of SmallVector swap. | /// Implement std::swap in terms of SmallVector swap. | ||||
template <typename T, unsigned N> | template <typename T, unsigned N> | ||||
inline void swap(megdnn::SmallVector<T, N>& lhs, | |||||
megdnn::SmallVector<T, N>& rhs) { | |||||
inline void swap(megdnn::SmallVector<T, N>& lhs, megdnn::SmallVector<T, N>& rhs) { | |||||
lhs.swap(rhs); | lhs.swap(rhs); | ||||
} | } | ||||
} // end namespace std | } // end namespace std | ||||
@@ -17,13 +17,13 @@ | |||||
#include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
struct Version { | |||||
int major, minor, patch; | |||||
}; | |||||
struct Version { | |||||
int major, minor, patch; | |||||
}; | |||||
//! get megdnn version of the binary | |||||
Version get_version(); | |||||
} | |||||
//! get megdnn version of the binary | |||||
Version get_version(); | |||||
} // namespace megdnn | |||||
#include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
@@ -22,18 +22,17 @@ using namespace aarch64; | |||||
/* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | ||||
bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
return param.filter_meta.format == param::Convolution::Format::NCHW && | return param.filter_meta.format == param::Convolution::Format::NCHW && | ||||
param.src_type.enumv() == DTypeEnum::Float16 && | param.src_type.enumv() == DTypeEnum::Float16 && | ||||
param.filter_type.enumv() == DTypeEnum::Float16 && | param.filter_type.enumv() == DTypeEnum::Float16 && | ||||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
FH == fm.spatial[1] && | |||||
param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | (FH == 2 || FH == 3 || FH == 5 || FH == 7); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | |||||
return 0; | return 0; | ||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { | ||||
return get_kimpls(param); | return get_kimpls(param); | ||||
@@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||||
return {}; | return {}; | ||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto fm = param.filter_meta; | auto fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
@@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | bool large_group = group >= param.nr_threads; | ||||
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | |||||
size_t, size_t, size_t, size_t, size_t)>; | |||||
using Func = std::function<void( | |||||
const __fp16*, const __fp16*, __fp16*, size_t, size_t, size_t, size_t, | |||||
size_t)>; | |||||
Func conv = nullptr; | Func conv = nullptr; | ||||
if (FH == 2) { | if (FH == 2) { | ||||
conv = fp16::conv_stride2::do_conv_2x2_stride2; | conv = fp16::conv_stride2::do_conv_2x2_stride2; | ||||
@@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
for (size_t ic = 0; ic < IC; ic++) { | for (size_t ic = 0; ic < IC; ic++) { | ||||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | ||||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||||
{ncb_index.thread_id, 0, ic}); | |||||
copy_padding_kern_stride( | |||||
bundle, kern_param, ncb_index, | |||||
{ncb_index.thread_id, 0, ic}); | |||||
} | } | ||||
for (size_t oc = 0; oc < OC; oc++) { | for (size_t oc = 0; oc < OC; oc++) { | ||||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | ||||
do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||||
{ncb_index.thread_id, 0, oc}); | |||||
do_conv_kern_stride( | |||||
bundle, kern_param, ncb_index, conv, | |||||
{ncb_index.thread_id, 0, oc}); | |||||
} | } | ||||
}; | }; | ||||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | ||||
} else { | } else { | ||||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||||
const NCBKernIndex& ncb_index) mutable { | |||||
auto copy_padding = [bundle]( | |||||
const NCBKernParam& kern_param, | |||||
const NCBKernIndex& ncb_index) mutable { | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | ||||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||||
ncb_index.ndrange_id); | |||||
copy_padding_kern_stride( | |||||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||||
}; | }; | ||||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | ret_kerns.push_back({copy_padding, {group, N, IC}}); | ||||
auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||||
const NCBKernIndex& ncb_index) mutable { | |||||
auto do_conv = [bundle, conv]( | |||||
const NCBKernParam& kern_param, | |||||
const NCBKernIndex& ncb_index) mutable { | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | ||||
do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||||
ncb_index.ndrange_id); | |||||
do_conv_kern_stride( | |||||
bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); | |||||
}; | }; | ||||
ret_kerns.push_back({do_conv, {group, N, OC}}); | ret_kerns.push_back({do_conv, {group, N, OC}}); | ||||
} | } | ||||
@@ -18,13 +18,13 @@ namespace aarch64 { | |||||
/* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "ARMV8F16STRD2"; } | const char* name() const override { return "ARMV8F16STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
bool usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
@@ -20,9 +20,9 @@ namespace aarch64 { | |||||
namespace fp16 { | namespace fp16 { | ||||
namespace conv_stride2 { | namespace conv_stride2 { | ||||
static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||||
__fp16* dst, size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t IC) { | |||||
static void do_conv_2x2_stride2( | |||||
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
size_t width = OW >> 3; | size_t width = OW >> 3; | ||||
size_t mod4_left = width & 3; | size_t mod4_left = width & 3; | ||||
@@ -162,10 +162,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||||
"5: \n" | "5: \n" | ||||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | ||||
: "r"(mod4_left), "w"(_k0123) | : "r"(mod4_left), "w"(_k0123) | ||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", | |||||
"v31"); | |||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v19", "v28", "v29", "v30", "v31"); | |||||
r0 += tail_step; | r0 += tail_step; | ||||
r1 += tail_step; | r1 += tail_step; | ||||
@@ -175,9 +174,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||||
} | } | ||||
} | } | ||||
static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||||
__fp16* dst, size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t IC) { | |||||
static void do_conv_3x3_stride2( | |||||
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
size_t width = OW >> 3; | size_t width = OW >> 3; | ||||
size_t mod3_left = width % 3; | size_t mod3_left = width % 3; | ||||
@@ -352,10 +351,10 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||||
"3: \n" | "3: \n" | ||||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | ||||
: "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | ||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", | |||||
"v25", "v26", "v27", "v28", "v29"); | |||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29"); | |||||
r0 += tail_step; | r0 += tail_step; | ||||
r1 += tail_step; | r1 += tail_step; | ||||
@@ -366,9 +365,9 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||||
} | } | ||||
} | } | ||||
static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||||
__fp16* dst, size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t IC) { | |||||
static void do_conv_5x5_stride2( | |||||
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
size_t width = OW >> 3; | size_t width = OW >> 3; | ||||
size_t mod2_left = width & 1; | size_t mod2_left = width & 1; | ||||
@@ -384,18 +383,12 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||||
const __fp16* r4 = src_ptr + IW * 4; | const __fp16* r4 = src_ptr + IW * 4; | ||||
register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | ||||
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = | |||||
MEGDNN_SIMD_LOADU(filter + 4); | |||||
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = | |||||
MEGDNN_SIMD_LOADU(filter + 8); | |||||
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = | |||||
MEGDNN_SIMD_LOADU(filter + 12); | |||||
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = | |||||
MEGDNN_SIMD_LOADU(filter + 16); | |||||
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = | |||||
MEGDNN_SIMD_LOADU(filter + 20); | |||||
register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = | |||||
MEGDNN_SIMD_SET1(filter[24]); | |||||
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); | |||||
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); | |||||
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); | |||||
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); | |||||
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); | |||||
register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = MEGDNN_SIMD_SET1(filter[24]); | |||||
for (size_t i = 0; i < OH; i++) { | for (size_t i = 0; i < OH; i++) { | ||||
asm volatile( | asm volatile( | ||||
@@ -592,15 +585,14 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||||
"bne 2b \n" | "bne 2b \n" | ||||
"3: \n" | "3: \n" | ||||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), | |||||
"+r"(r3), "+r"(r4) | |||||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||||
"+r"(r4) | |||||
: "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | ||||
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), | |||||
"r"(mod2_left) | |||||
: "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31"); | |||||
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) | |||||
: "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
"v31"); | |||||
r0 += tail_step; | r0 += tail_step; | ||||
r1 += tail_step; | r1 += tail_step; | ||||
@@ -613,9 +605,9 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||||
} | } | ||||
} | } | ||||
static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||||
__fp16* dst, size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t IC) { | |||||
static void do_conv_7x7_stride2( | |||||
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
size_t width = OW >> 3; | size_t width = OW >> 3; | ||||
@@ -632,30 +624,20 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||||
const __fp16* r6 = src_ptr + IW * 6; | const __fp16* r6 = src_ptr + IW * 6; | ||||
register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | ||||
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = | |||||
MEGDNN_SIMD_LOADU(filter + 4); | |||||
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = | |||||
MEGDNN_SIMD_LOADU(filter + 8); | |||||
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = | |||||
MEGDNN_SIMD_LOADU(filter + 12); | |||||
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = | |||||
MEGDNN_SIMD_LOADU(filter + 16); | |||||
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = | |||||
MEGDNN_SIMD_LOADU(filter + 20); | |||||
register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = | |||||
MEGDNN_SIMD_LOADU(filter + 24); | |||||
register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = | |||||
MEGDNN_SIMD_LOADU(filter + 28); | |||||
register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = | |||||
MEGDNN_SIMD_LOADU(filter + 32); | |||||
register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = | |||||
MEGDNN_SIMD_LOADU(filter + 36); | |||||
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); | |||||
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); | |||||
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); | |||||
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); | |||||
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); | |||||
register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = MEGDNN_SIMD_LOADU(filter + 24); | |||||
register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = MEGDNN_SIMD_LOADU(filter + 28); | |||||
register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = MEGDNN_SIMD_LOADU(filter + 32); | |||||
register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = MEGDNN_SIMD_LOADU(filter + 36); | |||||
register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = | register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = | ||||
MEGDNN_SIMD_LOADU(filter + 40); | MEGDNN_SIMD_LOADU(filter + 40); | ||||
register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = | register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = | ||||
MEGDNN_SIMD_LOADU(filter + 44); | MEGDNN_SIMD_LOADU(filter + 44); | ||||
register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = | |||||
MEGDNN_SIMD_SET1(filter[48]); | |||||
register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = MEGDNN_SIMD_SET1(filter[48]); | |||||
for (size_t i = 0; i < OH; i++) { | for (size_t i = 0; i < OH; i++) { | ||||
asm volatile( | asm volatile( | ||||
@@ -1005,16 +987,15 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||||
"bne 2b \n" | "bne 2b \n" | ||||
"3: \n" | "3: \n" | ||||
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||||
"+r"(r4), "+r"(r5), "+r"(r6) | |||||
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), | |||||
"+r"(r5), "+r"(r6) | |||||
: "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | ||||
"w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | ||||
"w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | ||||
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), | |||||
"w"(_k48484848) | |||||
: "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", | |||||
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||||
"v26", "v27", "v28", "v29", "v30", "v31"); | |||||
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) | |||||
: "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31"); | |||||
r0 += tail_step; | r0 += tail_step; | ||||
r1 += tail_step; | r1 += tail_step; | ||||
@@ -21,18 +21,17 @@ using namespace megdnn; | |||||
using namespace aarch64; | using namespace aarch64; | ||||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | ||||
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy) const { | |||||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | return param.filter_meta.format == param::ConvBias::Format::NCHW && | ||||
param.src_type.enumv() == DTypeEnum::Float32 && | param.src_type.enumv() == DTypeEnum::Float32 && | ||||
param.filter_type.enumv() == DTypeEnum::Float32 && | param.filter_type.enumv() == DTypeEnum::Float32 && | ||||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
FH == fm.spatial[1] && | |||||
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && | |||||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | (FH == 2 || FH == 3 || FH == 5 || FH == 7); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -50,8 +49,7 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | return 0; | ||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | ||||
return get_kimpls(param); | return get_kimpls(param); | ||||
@@ -60,8 +58,7 @@ ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||||
return {}; | return {}; | ||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | |||||
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto fm = param.filter_meta; | auto fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
@@ -70,8 +67,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
size_t group = fm.group; | size_t group = fm.group; | ||||
bool large_group = group >= param.nr_threads; | bool large_group = group >= param.nr_threads; | ||||
using Func = std::function<void(const float*, const float*, float*, size_t, | |||||
size_t, size_t, size_t, size_t)>; | |||||
using Func = std::function<void( | |||||
const float*, const float*, float*, size_t, size_t, size_t, size_t, | |||||
size_t)>; | |||||
Func conv = nullptr; | Func conv = nullptr; | ||||
if (FH == 2) { | if (FH == 2) { | ||||
conv = fp32::conv_stride2::do_conv_2x2_stride2; | conv = fp32::conv_stride2::do_conv_2x2_stride2; | ||||
@@ -83,8 +81,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
conv = fp32::conv_stride2::do_conv_7x7_stride2; | conv = fp32::conv_stride2::do_conv_7x7_stride2; | ||||
} | } | ||||
WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | |||||
float, float>::get_bundle_stride(param, large_group); | |||||
WorkspaceBundle bundle = | |||||
arm_common::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||||
param, large_group); | |||||
SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
//! Dense conv and small group | //! Dense conv and small group | ||||
@@ -99,34 +98,34 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
for (size_t ic = 0; ic < IC; ic++) { | for (size_t ic = 0; ic < IC; ic++) { | ||||
arm_common::MultithreadDirectConvCommon<float, float>:: | arm_common::MultithreadDirectConvCommon<float, float>:: | ||||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||||
{ncb_index.thread_id, 0, ic}); | |||||
copy_padding_kern_stride( | |||||
bundle, kern_param, ncb_index, | |||||
{ncb_index.thread_id, 0, ic}); | |||||
} | } | ||||
for (size_t oc = 0; oc < OC; oc++) { | for (size_t oc = 0; oc < OC; oc++) { | ||||
arm_common::MultithreadDirectConvCommon< | |||||
float, float>::do_conv_kern_stride(bundle, kern_param, | |||||
ncb_index, conv, | |||||
{ncb_index.thread_id, | |||||
0, oc}); | |||||
arm_common::MultithreadDirectConvCommon<float, float>:: | |||||
do_conv_kern_stride( | |||||
bundle, kern_param, ncb_index, conv, | |||||
{ncb_index.thread_id, 0, oc}); | |||||
} | } | ||||
}; | }; | ||||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | ||||
} else { | } else { | ||||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||||
const NCBKernIndex& ncb_index) mutable { | |||||
auto copy_padding = [bundle]( | |||||
const NCBKernParam& kern_param, | |||||
const NCBKernIndex& ncb_index) mutable { | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
arm_common::MultithreadDirectConvCommon<float, float>:: | arm_common::MultithreadDirectConvCommon<float, float>:: | ||||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||||
ncb_index.ndrange_id); | |||||
copy_padding_kern_stride( | |||||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||||
}; | }; | ||||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | ret_kerns.push_back({copy_padding, {group, N, IC}}); | ||||
auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||||
const NCBKernIndex& ncb_index) mutable { | |||||
auto do_conv = [bundle, conv]( | |||||
const NCBKernParam& kern_param, | |||||
const NCBKernIndex& ncb_index) mutable { | |||||
bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
arm_common::MultithreadDirectConvCommon< | |||||
float, float>::do_conv_kern_stride(bundle, kern_param, | |||||
ncb_index, conv, | |||||
ncb_index.ndrange_id); | |||||
arm_common::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride( | |||||
bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); | |||||
}; | }; | ||||
ret_kerns.push_back({do_conv, {group, N, OC}}); | ret_kerns.push_back({do_conv, {group, N, OC}}); | ||||
} | } | ||||
@@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "ARMV8F32STRD2"; } | const char* name() const override { return "ARMV8F32STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
bool usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
@@ -16,16 +16,15 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace fp32{ | |||||
namespace fp32 { | |||||
namespace conv_stride2 { | namespace conv_stride2 { | ||||
//! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp` | //! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp` | ||||
// refer to function do_conv_2x2_stride2_asm_unroll4 | // refer to function do_conv_2x2_stride2_asm_unroll4 | ||||
static void do_conv_2x2_stride2(const float* src, const float* filter, | |||||
float* dst, size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t IC) { | |||||
static void do_conv_2x2_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
size_t width = OW >> 2; | size_t width = OW >> 2; | ||||
size_t mod4_left = width & 3; | size_t mod4_left = width & 3; | ||||
@@ -165,10 +164,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, | |||||
"5: \n" | "5: \n" | ||||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | ||||
: "r"(mod4_left), "w"(_k0123) | : "r"(mod4_left), "w"(_k0123) | ||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", | |||||
"v31"); | |||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v19", "v28", "v29", "v30", "v31"); | |||||
r0 += tail_step; | r0 += tail_step; | ||||
r1 += tail_step; | r1 += tail_step; | ||||
@@ -179,9 +177,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, | |||||
} | } | ||||
// refer to function do_conv_3x3_stride2_asm_unroll3 | // refer to function do_conv_3x3_stride2_asm_unroll3 | ||||
static void do_conv_3x3_stride2(const float* src, const float* filter, | |||||
float* dst, size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t IC) { | |||||
static void do_conv_3x3_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
size_t width = OW >> 2; | size_t width = OW >> 2; | ||||
size_t mod3_left = width % 3; | size_t mod3_left = width % 3; | ||||
@@ -269,7 +267,7 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||||
"ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 | "ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 | ||||
"ld2 {v5.4s, v6.4s}, [%3], #32 \n" | "ld2 {v5.4s, v6.4s}, [%3], #32 \n" | ||||
"ld1 {v3.4s}, [%2] \n" // load src 8 12 ... | |||||
"ld1 {v3.4s}, [%2] \n" // load src 8 12 ... | |||||
"fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] | "fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] | ||||
"ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 | "ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 | ||||
"fmla v0.4s, v2.4s, v22.4s \n" | "fmla v0.4s, v2.4s, v22.4s \n" | ||||
@@ -356,10 +354,10 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||||
"3: \n" | "3: \n" | ||||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | ||||
: "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | ||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", | |||||
"v25", "v26", "v27", "v28", "v29"); | |||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29"); | |||||
r0 += tail_step; | r0 += tail_step; | ||||
r1 += tail_step; | r1 += tail_step; | ||||
@@ -371,9 +369,9 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||||
} | } | ||||
// refer to function do_conv_5x5_stride2_asm_unroll2 | // refer to function do_conv_5x5_stride2_asm_unroll2 | ||||
static void do_conv_5x5_stride2(const float* src, const float* filter, | |||||
float* dst, size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t IC) { | |||||
static void do_conv_5x5_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
size_t width = OW >> 2; | size_t width = OW >> 2; | ||||
size_t mod2_left = width & 1; | size_t mod2_left = width & 1; | ||||
@@ -591,15 +589,13 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, | |||||
"bne 2b \n" | "bne 2b \n" | ||||
"3: \n" | "3: \n" | ||||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), | |||||
"+r"(r3), "+r"(r4) | |||||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||||
"+r"(r4) | |||||
: "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | ||||
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), | |||||
"r"(mod2_left) | |||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
"v23", "v24"); | |||||
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) | |||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24"); | |||||
r0 += tail_step; | r0 += tail_step; | ||||
r1 += tail_step; | r1 += tail_step; | ||||
@@ -613,9 +609,9 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, | |||||
} | } | ||||
// refer to function do_conv_7x7_stride2_asm_unroll2 | // refer to function do_conv_7x7_stride2_asm_unroll2 | ||||
static void do_conv_7x7_stride2(const float* src, const float* filter, | |||||
float* dst, size_t IH, size_t IW, size_t OH, | |||||
size_t OW, size_t IC) { | |||||
static void do_conv_7x7_stride2( | |||||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
size_t OH, size_t OW, size_t IC) { | |||||
const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
size_t width = OW >> 2; | size_t width = OW >> 2; | ||||
@@ -993,16 +989,15 @@ static void do_conv_7x7_stride2(const float* src, const float* filter, | |||||
"bne 2b \n" | "bne 2b \n" | ||||
"3: \n" | "3: \n" | ||||
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||||
"+r"(r4), "+r"(r5), "+r"(r6) | |||||
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), | |||||
"+r"(r5), "+r"(r6) | |||||
: "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | ||||
"w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | ||||
"w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | ||||
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), | |||||
"w"(_k48484848) | |||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18"); | |||||
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) | |||||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18"); | |||||
r0 += tail_step; | r0 += tail_step; | ||||
r1 += tail_step; | r1 += tail_step; | ||||
@@ -68,9 +68,9 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||||
size_t N = OH * OW; | size_t N = OH * OW; | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
#define DISPATCH_GEMM_STRATEGY( \ | |||||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | ||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | M, N, K, param.filter_type, param.src_type, param.dst_type); \ | ||||
part2 = megdnn::matmul::GemmInterleaved< \ | part2 = megdnn::matmul::GemmInterleaved< \ | ||||
@@ -84,11 +84,12 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | DISPATCH_GEMM_BIAS(s8_4x4, 0) | ||||
} | } | ||||
#else | #else | ||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
#define DISPATCH_GEMM_STRATEGY( \ | |||||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN( \ | |||||
megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | ||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | M, N, K, param.filter_type, param.src_type, param.dst_type); \ | ||||
part2 = megdnn::matmul::GemmInterleaved< \ | part2 = megdnn::matmul::GemmInterleaved< \ | ||||
@@ -104,8 +105,8 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||||
return {nullptr, {part0, part1, part2}}; | return {nullptr, {part0, part1, part2}}; | ||||
} | } | ||||
void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
const NCBKernIndex& ncb_index) { | |||||
void ConvBiasImpl::AlgoS8MatrixMul::kimpl( | |||||
const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||||
auto is_xcorr = !param.filter_meta.should_flip; | auto is_xcorr = !param.filter_meta.should_flip; | ||||
UNPACK_CONV_NCB_KERN_SIZES(param); | UNPACK_CONV_NCB_KERN_SIZES(param); | ||||
auto bundle = get_bundle(param); | auto bundle = get_bundle(param); | ||||
@@ -157,29 +158,28 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | ||||
} else { | } else { | ||||
if (is_xcorr) | if (is_xcorr) | ||||
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
FW, SH, SW); | |||||
img2col_stride<true>( | |||||
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||||
else | else | ||||
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
FW, SH, SW); | |||||
img2col_stride<false>( | |||||
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||||
} | } | ||||
} | } | ||||
{ | { | ||||
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||||
bundle.get_size(2)); | |||||
Workspace workspace( | |||||
static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2)); | |||||
size_t M = OC; | size_t M = OC; | ||||
size_t K = IC * FH * FW; | size_t K = IC * FH * FW; | ||||
size_t N = OH * OW; | size_t N = OH * OW; | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved< \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
#define DISPATCH_GEMM_STRATEGY( \ | |||||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | ||||
if (cpuinfo_has_arm_neon_dot()) { | if (cpuinfo_has_arm_neon_dot()) { | ||||
@@ -188,19 +188,18 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | DISPATCH_GEMM_BIAS(s8_4x4, 0) | ||||
} | } | ||||
#else | #else | ||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved< \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||||
bias); \ | |||||
} \ | |||||
#define DISPATCH_GEMM_STRATEGY( \ | |||||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN( \ | |||||
megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ | |||||
} \ | |||||
MIDOUT_END() | MIDOUT_END() | ||||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | DISPATCH_GEMM_BIAS(s8_4x4, 0) | ||||
#endif | #endif | ||||
@@ -12,8 +12,8 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/aarch64/conv_bias/opr_impl.h" | #include "src/aarch64/conv_bias/opr_impl.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace aarch64 { | namespace aarch64 { | ||||
@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { | |||||
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "S8MATMUL"; } | const char* name() const override { return "S8MATMUL"; } | ||||
bool usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
bool usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override { | size_t get_workspace(const NCBKernSizeParam& param) const override { | ||||
return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
} | } | ||||
SmallVector<NCBKern> dispatch_kerns( | |||||
const NCBKernSizeParam& param) const override { | |||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override { | |||||
size_t group = param.filter_meta.group; | size_t group = param.filter_meta.group; | ||||
return {{kimpl, {group, 1_z, 1_z}}}; | return {{kimpl, {group, 1_z, 1_z}}}; | ||||
} | } | ||||
@@ -29,9 +29,10 @@ struct KernCaller; | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
struct KernCaller<bmode, Op, 8, 12> { | struct KernCaller<bmode, Op, 8, 12> { | ||||
static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||||
Op op, const dt_int32* bias, dt_int32* workspace) { | |||||
static void run( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||||
dt_int32* workspace) { | |||||
megdnn_assert(is_first_k); | megdnn_assert(is_first_k); | ||||
constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
@@ -49,19 +50,19 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12, | |||||
is_first_k); | |||||
matmul_8x12x4::kern_8x12( | |||||
packA, cur_packB, K, workspace, 12, is_first_k); | |||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, | |||||
12>::postprocess(bias, workspace, | |||||
output, LDC, op); | |||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, 12>:: | |||||
postprocess(bias, workspace, output, LDC, op); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||||
is_first_k, std::min<size_t>(N - n, 4)); | |||||
matmul_8x12x4::kern_8x4( | |||||
packA, cur_packB, K, workspace, 4, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \ | ||||
@@ -83,9 +84,9 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12, | |||||
is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
matmul_8x12x4::kern_4x12( | |||||
packA, cur_packB, K, workspace, 12, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
@@ -97,14 +98,13 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||||
is_first_k, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_8x12x4::kern_4x4( | |||||
packA, cur_packB, K, workspace, 4, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
#undef cb | #undef cb | ||||
output += 4; | output += 4; | ||||
@@ -122,9 +122,10 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
struct KernCaller<bmode, Op, 4, 4> { | struct KernCaller<bmode, Op, 4, 4> { | ||||
static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||||
Op op, const dt_int32* bias, dt_int32* workspace) { | |||||
static void run( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||||
dt_int32* workspace) { | |||||
megdnn_assert(is_first_k); | megdnn_assert(is_first_k); | ||||
constexpr size_t A_INTERLEAVE = 4; | constexpr size_t A_INTERLEAVE = 4; | ||||
@@ -140,20 +141,18 @@ struct KernCaller<bmode, Op, 4, 4> { | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, | |||||
is_first_k); | |||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, | |||||
4>::postprocess(bias, workspace, | |||||
output, LDC, op); | |||||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, is_first_k); | |||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, 4>::postprocess( | |||||
bias, workspace, output, LDC, op); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace, | |||||
4, is_first_k, 4, | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_4x4x16::kern_4x4_remain( | |||||
packA, cur_packB, K, workspace, 4, is_first_k, 4, | |||||
std::min<size_t>(N - n, 4)); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
@@ -182,8 +181,7 @@ struct KernCaller<bmode, Op, 4, 4> { | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
#undef cb | #undef cb | ||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
@@ -200,21 +198,19 @@ struct KernCaller<bmode, Op, 4, 4> { | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | ||||
void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
void gemm_s8_4x4_nobias_identity::pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int x0, int xmax, int k0, | |||||
int kmax, bool transpose) const { | |||||
void gemm_s8_4x4_nobias_identity::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | ||||
} else { | } else { | ||||
@@ -229,23 +225,21 @@ size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | ||||
void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
void gemm_s8_8x12_nobias_identity::pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); | MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); | ||||
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); | MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); | ||||
if (transpose) { | if (transpose) { | ||||
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int x0, int xmax, int k0, | |||||
int kmax, bool transpose) const { | |||||
void gemm_s8_8x12_nobias_identity::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | ||||
} else { | } else { | ||||
@@ -259,18 +253,17 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { | |||||
#endif | #endif | ||||
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||||
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ | |||||
size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ | |||||
const dt_int32* bias, dt_int32* workspace) const { \ | |||||
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
DEFINE_OP(_OP); \ | |||||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||||
workspace); \ | |||||
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||||
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, \ | |||||
dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ | |||||
dt_int32* workspace) const { \ | |||||
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
DEFINE_OP(_OP); \ | |||||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace); \ | |||||
} | } | ||||
#define DEFINE_OP(_Op) \ | #define DEFINE_OP(_Op) \ | ||||
@@ -286,18 +279,16 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||||
#endif | #endif | ||||
#undef DEFINE_OP | #undef DEFINE_OP | ||||
#define DEFINE_OP(_Op) \ | |||||
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \ | |||||
scale_A* scale_B, scale_C); | |||||
#define DEFINE_OP(_Op) \ | |||||
arm_common::_Op<dt_qint32, dt_qint8> op( \ | |||||
scale_A* scale_B, scale_A* scale_B, scale_C); | |||||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | ||||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | ||||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
FuseAddHSwishOp) | |||||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | ||||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | ||||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
FuseAddHSwishOp) | |||||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
#endif | #endif | ||||
#undef DEFINE_OP | #undef DEFINE_OP | ||||
@@ -20,43 +20,42 @@ namespace matmul { | |||||
* | * | ||||
* \name gemm_<type>_<block>_biasmode_nolinemode | * \name gemm_<type>_<block>_biasmode_nolinemode | ||||
*/ | */ | ||||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16, | |||||
false, true, | |||||
gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||||
dt_int8, dt_int8, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu, | |||||
gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_4x4_nobias_relu, gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish, | |||||
gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_4x4_nobias_hswish, gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity, | |||||
gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_4x4_bias_channel_identity, gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, | |||||
gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_4x4_bias_channel_relu, gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, | |||||
gemm_s8_4x4_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_4x4_bias_channel_hswish, gemm_s8_4x4_nobias_identity); | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, | |||||
false, true, | |||||
gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||||
dt_int8, dt_int8, dt_int32, 8, 12, 4, false, true, | |||||
gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu, | |||||
gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_8x12_nobias_relu, gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish, | |||||
gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_8x12_nobias_hswish, gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity, | |||||
gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_8x12_bias_channel_identity, gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, | |||||
gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_8x12_bias_channel_relu, gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, | |||||
gemm_s8_8x12_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_s8_8x12_bias_channel_hswish, gemm_s8_8x12_nobias_identity); | |||||
#endif | #endif | ||||
} // namespace matmul | } // namespace matmul | ||||
@@ -13,13 +13,13 @@ | |||||
#include "src/aarch64/conv_bias/int8/algos.h" | #include "src/aarch64/conv_bias/int8/algos.h" | ||||
#include "src/aarch64/conv_bias/quint8/algos.h" | #include "src/aarch64/conv_bias/quint8/algos.h" | ||||
#include "src/naive/handle.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/common/metahelper.h" | #include "src/common/metahelper.h" | ||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/fallback/convolution/opr_impl.h" | |||||
#include "src/aarch64/conv_bias/fp32/algos.h" | |||||
#include "src/aarch64/conv_bias/fp16/algos.h" | #include "src/aarch64/conv_bias/fp16/algos.h" | ||||
#include "src/aarch64/conv_bias/fp32/algos.h" | |||||
#include "src/fallback/convolution/opr_impl.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace aarch64; | using namespace aarch64; | ||||
@@ -56,12 +56,10 @@ public: | |||||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | ||||
return m_direct_algos; | return m_direct_algos; | ||||
} | } | ||||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() | |||||
const { | |||||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() const { | |||||
return m_matmul_algos; | return m_matmul_algos; | ||||
} | } | ||||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | ||||
}; | }; | ||||
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | ||||
@@ -71,15 +69,16 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | ||||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||||
ConvBiasImpl::get_all_packed_algo() { | |||||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() { | |||||
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | ||||
algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||||
algo_pack().direct_algos().end()); | |||||
algos.insert( | |||||
algos.begin(), algo_pack().direct_algos().begin(), | |||||
algo_pack().direct_algos().end()); | |||||
//! We put matmul algos at the begin. Because matmul will get privilege when | //! We put matmul algos at the begin. Because matmul will get privilege when | ||||
//! prefer return true. See | //! prefer return true. See | ||||
algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), | |||||
algo_pack().matmul_algos().end()); | |||||
algos.insert( | |||||
algos.begin(), algo_pack().matmul_algos().begin(), | |||||
algo_pack().matmul_algos().end()); | |||||
return std::move(algos); | return std::move(algos); | ||||
} | } | ||||
@@ -9,8 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/common/utils.h" | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
#include "src/common/utils.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace aarch64 { | namespace aarch64 { | ||||
@@ -70,9 +70,9 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||||
size_t N = OH * OW; | size_t N = OH * OW; | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
#define DISPATCH_GEMM_STRATEGY( \ | |||||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | ||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | M, N, K, param.filter_type, param.src_type, param.dst_type); \ | ||||
part2 = megdnn::matmul::GemmInterleaved< \ | part2 = megdnn::matmul::GemmInterleaved< \ | ||||
@@ -86,11 +86,12 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||||
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); | DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); | ||||
} | } | ||||
#else | #else | ||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
#define DISPATCH_GEMM_STRATEGY( \ | |||||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN( \ | |||||
megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | ||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | M, N, K, param.filter_type, param.src_type, param.dst_type); \ | ||||
part2 = megdnn::matmul::GemmInterleaved< \ | part2 = megdnn::matmul::GemmInterleaved< \ | ||||
@@ -106,8 +107,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||||
return {nullptr, {part0, part1, part2}}; | return {nullptr, {part0, part1, part2}}; | ||||
} | } | ||||
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||||
const NCBKernIndex& ncb_index) { | |||||
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl( | |||||
const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||||
auto is_xcorr = !param.filter_meta.should_flip; | auto is_xcorr = !param.filter_meta.should_flip; | ||||
UNPACK_CONV_NCB_KERN_SIZES(param); | UNPACK_CONV_NCB_KERN_SIZES(param); | ||||
auto bundle = get_bundle(param); | auto bundle = get_bundle(param); | ||||
@@ -160,29 +161,28 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||||
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | ||||
} else { | } else { | ||||
if (is_xcorr) | if (is_xcorr) | ||||
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
FW, SH, SW); | |||||
img2col_stride<true>( | |||||
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||||
else | else | ||||
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
FW, SH, SW); | |||||
img2col_stride<false>( | |||||
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||||
} | } | ||||
} | } | ||||
{ | { | ||||
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||||
bundle.get_size(2)); | |||||
Workspace workspace( | |||||
static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2)); | |||||
size_t M = OC; | size_t M = OC; | ||||
size_t K = IC * FH * FW; | size_t K = IC * FH * FW; | ||||
size_t N = OH * OW; | size_t N = OH * OW; | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved< \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
#define DISPATCH_GEMM_STRATEGY( \ | |||||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | ||||
if (cpuinfo_has_arm_neon_dot()) { | if (cpuinfo_has_arm_neon_dot()) { | ||||
@@ -191,19 +191,18 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||||
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | ||||
} | } | ||||
#else | #else | ||||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
_bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved< \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||||
bias); \ | |||||
} \ | |||||
#define DISPATCH_GEMM_STRATEGY( \ | |||||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
_nonline_midout_enum) \ | |||||
MIDOUT_BEGIN( \ | |||||
megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||||
_bias_midout_enum, _nonline_midout_enum) { \ | |||||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ | |||||
} \ | |||||
MIDOUT_END() | MIDOUT_END() | ||||
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | ||||
@@ -12,8 +12,8 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/aarch64/conv_bias/opr_impl.h" | #include "src/aarch64/conv_bias/opr_impl.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace aarch64 { | namespace aarch64 { | ||||
@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { | |||||
static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "QU8MATMUL"; } | const char* name() const override { return "QU8MATMUL"; } | ||||
bool usable(const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
bool usable( | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override { | size_t get_workspace(const NCBKernSizeParam& param) const override { | ||||
return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
} | } | ||||
SmallVector<NCBKern> dispatch_kerns( | |||||
const NCBKernSizeParam& param) const override { | |||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override { | |||||
size_t group = param.filter_meta.group; | size_t group = param.filter_meta.group; | ||||
return {{kimpl, {group, 1_z, 1_z}}}; | return {{kimpl, {group, 1_z, 1_z}}}; | ||||
} | } | ||||
@@ -14,8 +14,8 @@ | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||||
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | ||||
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||||
#include "src/arm_common/conv_bias/matmul_postprocess.h" | #include "src/arm_common/conv_bias/matmul_postprocess.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -29,10 +29,10 @@ struct KernCaller; | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
struct KernCaller<bmode, Op, 8, 8, true> { | struct KernCaller<bmode, Op, 8, 8, true> { | ||||
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||||
size_t N, size_t K, dt_uint8* C, size_t LDC, | |||||
bool is_first_k, Op op, const dt_int32* bias, | |||||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||||
static void run( | |||||
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, | |||||
dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||||
megdnn_assert(is_first_k); | megdnn_assert(is_first_k); | ||||
constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
constexpr size_t B_INTERLEAVE = 8; | constexpr size_t B_INTERLEAVE = 8; | ||||
@@ -50,20 +50,19 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_uint8* cur_packB = packB; | const dt_uint8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8, | |||||
is_first_k, zp_A, zp_B, zAB); | |||||
matmul_8x8x4::kern_8x8( | |||||
packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B, zAB); | |||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||||
8>::postprocess(bias, workspace, | |||||
output, LDC, op); | |||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>:: | |||||
postprocess(bias, workspace, output, LDC, op); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||||
is_first_k, std::min<size_t>(N - n, 4), | |||||
zp_A, zp_B, zAB); | |||||
matmul_8x8x4::kern_8x4( | |||||
packA, cur_packB, K, workspace, 4, is_first_k, | |||||
std::min<size_t>(N - n, 4), zp_A, zp_B, zAB); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
@@ -84,9 +83,9 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||||
const dt_uint8* cur_packB = packB; | const dt_uint8* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8, | |||||
is_first_k, std::min<size_t>(M - m, 4), | |||||
zp_A, zp_B, zAB); | |||||
matmul_8x8x4::kern_4x8( | |||||
packA, cur_packB, K, workspace, 8, is_first_k, | |||||
std::min<size_t>(M - m, 4), zp_A, zp_B, zAB); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
@@ -98,15 +97,14 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||||
is_first_k, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4), zp_A, zp_B, | |||||
zAB); | |||||
matmul_8x8x4::kern_4x4( | |||||
packA, cur_packB, K, workspace, 4, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A, | |||||
zp_B, zAB); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
#undef cb | #undef cb | ||||
output += 4; | output += 4; | ||||
@@ -124,10 +122,10 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||||
template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
struct KernCaller<bmode, Op, 8, 8, false> { | struct KernCaller<bmode, Op, 8, 8, false> { | ||||
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||||
size_t N, size_t K, dt_uint8* C, size_t LDC, | |||||
bool is_first_k, Op op, const dt_int32* bias, | |||||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||||
static void run( | |||||
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, | |||||
dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||||
megdnn_assert(is_first_k); | megdnn_assert(is_first_k); | ||||
constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
@@ -144,27 +142,25 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_uint8* cur_packB = packB; | const dt_uint8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8, | |||||
is_first_k, zp_A, zp_B); | |||||
matmul_8x8x8::kern_8x8( | |||||
packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B); | |||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||||
8>::postprocess(bias, workspace, | |||||
output, LDC, op); | |||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>:: | |||||
postprocess(bias, workspace, output, LDC, op); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4, | |||||
is_first_k, std::min<size_t>(N - n, 4), | |||||
zp_A, zp_B); | |||||
matmul_8x8x8::kern_8x4( | |||||
packA, cur_packB, K, workspace, 4, is_first_k, | |||||
std::min<size_t>(N - n, 4), zp_A, zp_B); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | ||||
#undef cb | #undef cb | ||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -179,9 +175,9 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||||
const dt_uint8* cur_packB = packB; | const dt_uint8* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, | |||||
is_first_k, std::min<size_t>(M - m, 4), | |||||
zp_A, zp_B); | |||||
matmul_8x8x8::kern_4x8( | |||||
packA, cur_packB, K, workspace, 8, is_first_k, | |||||
std::min<size_t>(M - m, 4), zp_A, zp_B); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
@@ -193,17 +189,16 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, | |||||
is_first_k, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4), zp_A, zp_B); | |||||
matmul_8x8x8::kern_4x4( | |||||
packA, cur_packB, K, workspace, 4, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A, | |||||
zp_B); | |||||
#define cb(m, n) \ | #define cb(m, n) \ | ||||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | ||||
bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
#undef cb | #undef cb | ||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -219,27 +214,27 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) | ||||
void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
void gemm_u8_8x8_dot_nobias_identity::pack_A( | |||||
uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, | |||||
ymax, k0, kmax); | |||||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( | |||||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, | |||||
y0, ymax, k0, kmax); | |||||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( | |||||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, | |||||
int ldin, int x0, int xmax, int k0, | |||||
int kmax, bool transpose) const { | |||||
void gemm_u8_8x8_dot_nobias_identity::pack_B( | |||||
uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | |||||
xmax, k0, kmax); | |||||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( | |||||
out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, | |||||
k0, kmax); | |||||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( | |||||
out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
@@ -249,30 +244,27 @@ size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const { | |||||
#endif | #endif | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) | ||||
void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr, | |||||
const dt_uint8* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_u8_8x8_nodot_nobias_identity::pack_A( | |||||
dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | ||||
if (transpose) { | if (transpose) { | ||||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||||
ymax, k0, kmax, zA); | |||||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n( | |||||
outptr, inptr, ldin, y0, ymax, k0, kmax, zA); | |||||
} else { | } else { | ||||
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax, zA); | |||||
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA); | |||||
} | } | ||||
} | } | ||||
void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, | |||||
int ldin, int x0, int xmax, int k0, | |||||
int kmax, bool transpose) const { | |||||
void gemm_u8_8x8_nodot_nobias_identity::pack_B( | |||||
dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | ||||
if (transpose) { | if (transpose) { | ||||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||||
k0, kmax, zB); | |||||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n( | |||||
out, in, ldin, x0, xmax, k0, kmax, zB); | |||||
} else { | } else { | ||||
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, | |||||
zB); | |||||
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB); | |||||
} | } | ||||
} | } | ||||
@@ -280,22 +272,21 @@ size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const { | |||||
return 8 * 8 * sizeof(dt_int32); | return 8 * 8 * sizeof(dt_int32); | ||||
} | } | ||||
#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \ | |||||
_OP) \ | |||||
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \ | |||||
kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \ | |||||
size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ | |||||
const dt_int32* bias, dt_int32* workspace) const { \ | |||||
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
DEFINE_OP(_OP); \ | |||||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||||
workspace, zp_A, zp_B); \ | |||||
#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, _OP) \ | |||||
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline::kern( \ | |||||
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ | |||||
size_t K, dt_uint8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ | |||||
dt_int32* workspace) const { \ | |||||
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
DEFINE_OP(_OP); \ | |||||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace, zp_A, \ | |||||
zp_B); \ | |||||
} | } | ||||
#define DEFINE_OP(_Op) \ | #define DEFINE_OP(_Op) \ | ||||
@@ -311,17 +302,22 @@ KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||||
KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | ||||
#undef DEFINE_OP | #undef DEFINE_OP | ||||
#define DEFINE_OP(_Op) \ | |||||
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \ | |||||
scale_A* scale_B, scale_C, zp_C); | |||||
#define DEFINE_OP(_Op) \ | |||||
arm_common::_Op<dt_qint32, dt_quint8> op( \ | |||||
scale_A* scale_B, scale_A* scale_B, scale_C, zp_C); | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | ||||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, | |||||
FuseAddReluOp) | |||||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
FuseAddHSwishOp) | |||||
#endif | #endif | ||||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, | |||||
AddOp) | |||||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, | |||||
FuseAddReluOp) | |||||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
FuseAddHSwishOp) | |||||
#undef DEFINE_OP | #undef DEFINE_OP | ||||
#undef KERN | #undef KERN | ||||
@@ -16,46 +16,44 @@ namespace aarch64 { | |||||
namespace matmul { | namespace matmul { | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, | |||||
false, true, | |||||
gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||||
dt_uint8, dt_uint8, dt_int32, 8, 8, 4, false, true, | |||||
gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu, | |||||
gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_dot_nobias_relu, gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish, | |||||
gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_dot_nobias_hswish, gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity, | |||||
gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_dot_bias_channel_identity, gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu, | |||||
gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish, | |||||
gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_dot_bias_channel_relu, gemm_u8_8x8_dot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_dot_bias_channel_hswish, gemm_u8_8x8_dot_nobias_identity); | |||||
#endif | #endif | ||||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, | |||||
false, true, | |||||
gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu, | |||||
gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||||
dt_uint8, dt_uint8, dt_int32, 8, 8, 8, false, true, | |||||
gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish, | |||||
gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_nodot_nobias_relu, gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity, | |||||
gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_nodot_nobias_hswish, gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu, | |||||
gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_nodot_bias_channel_identity, gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish, | |||||
gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_nodot_bias_channel_relu, gemm_u8_8x8_nodot_nobias_identity); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
gemm_u8_8x8_nodot_bias_channel_hswish, gemm_u8_8x8_nodot_nobias_identity); | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -11,11 +11,11 @@ | |||||
#include "src/common/handle_impl.h" | #include "src/common/handle_impl.h" | ||||
#include "src/aarch64/conv_bias/opr_impl.h" | |||||
#include "src/aarch64/handle.h" | #include "src/aarch64/handle.h" | ||||
#include "src/aarch64/matrix_mul/opr_impl.h" | #include "src/aarch64/matrix_mul/opr_impl.h" | ||||
#include "src/aarch64/rotate/opr_impl.h" | |||||
#include "src/aarch64/relayout/opr_impl.h" | #include "src/aarch64/relayout/opr_impl.h" | ||||
#include "src/aarch64/conv_bias/opr_impl.h" | |||||
#include "src/aarch64/rotate/opr_impl.h" | |||||
#include "src/aarch64/warp_perspective/opr_impl.h" | #include "src/aarch64/warp_perspective/opr_impl.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
@@ -38,7 +38,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective) | |||||
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | ||||
#pragma GCC diagnostic pop | #pragma GCC diagnostic pop | ||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -14,20 +14,18 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace aarch64 { | namespace aarch64 { | ||||
class HandleImpl: public arm_common::HandleImpl { | |||||
public: | |||||
HandleImpl(megcoreComputingHandle_t computing_handle, | |||||
HandleType type = HandleType::AARCH64): | |||||
arm_common::HandleImpl::HandleImpl(computing_handle, type) | |||||
{} | |||||
class HandleImpl : public arm_common::HandleImpl { | |||||
public: | |||||
HandleImpl( | |||||
megcoreComputingHandle_t computing_handle, | |||||
HandleType type = HandleType::AARCH64) | |||||
: arm_common::HandleImpl::HandleImpl(computing_handle, type) {} | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> create_operator(); | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> create_operator(); | |||||
}; | }; | ||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
@@ -21,9 +21,7 @@ namespace aarch64 { | |||||
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_F32K8X12X1"; } | const char* name() const override { return "AARCH64_F32K8X12X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -35,8 +33,7 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -48,9 +45,7 @@ public: | |||||
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_F32K4X16X1"; } | const char* name() const override { return "AARCH64_F32K4X16X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -61,9 +56,7 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -73,8 +66,7 @@ public: | |||||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32Gemv final | |||||
: public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||||
class MatrixMulImpl::AlgoF32Gemv final : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||||
public: | public: | ||||
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | ||||
m_handle_type = Handle::HandleType::AARCH64; | m_handle_type = Handle::HandleType::AARCH64; | ||||
@@ -85,9 +77,7 @@ public: | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_F16_K8X24X1"; } | const char* name() const override { return "AARCH64_F16_K8X24X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -98,9 +88,7 @@ public: | |||||
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -115,12 +103,8 @@ public: | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | |||||
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; } | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
@@ -130,12 +114,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | |||||
return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; } | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
@@ -147,8 +127,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -163,9 +142,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -178,9 +155,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -192,9 +167,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -207,9 +180,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -222,8 +193,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -238,12 +208,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | |||||
const char* name() const override { | |||||
return "AARCH64_INT8X8X16_MK4_16X12X4"; | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_16X12X4"; } | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -257,12 +224,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | |||||
const char* name() const override { | |||||
return "AARCH64_INT8X8X16_MK4_K8X8X8"; | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_K8X8X8"; } | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -276,8 +240,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -292,9 +255,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -306,9 +267,7 @@ public: | |||||
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -321,12 +280,8 @@ public: | |||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | |||||
return "AARCH64_QUINT8_K8X8X4_DOTPROD"; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_QUINT8_K8X8X4_DOTPROD"; } | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
@@ -336,8 +291,7 @@ public: | |||||
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -352,9 +306,7 @@ public: | |||||
#endif | #endif | ||||
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -16,11 +16,11 @@ namespace megdnn { | |||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace matmul { | namespace matmul { | ||||
MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false, | |||||
true, hgemm_8x24); | |||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_float16, dt_float16, dt_float16, 8, 24, 1, false, true, hgemm_8x24); | |||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1, | |||||
false, true, gemm_nopack_f16_8x8); | |||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||||
dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true, gemm_nopack_f16_8x8); | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -9,8 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/fp16/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
#include "src/aarch64/matrix_mul/fp16/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -21,8 +21,9 @@ using namespace aarch64::matmul; | |||||
namespace { | namespace { | ||||
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
void kern_8x1( | |||||
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
LDB *= sizeof(dt_float16); | LDB *= sizeof(dt_float16); | ||||
asm volatile( | asm volatile( | ||||
".arch armv8.2-a+fp16\n" | ".arch armv8.2-a+fp16\n" | ||||
@@ -86,9 +87,8 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||||
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||||
"memory"); | |||||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||||
"v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -115,8 +115,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
// |v23[0-7]| |v27[0-7]| | // |v23[0-7]| |v27[0-7]| | ||||
// +--------+ +--------+ | // +--------+ +--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
void kern_8x4( | |||||
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
//! LDB means number of elements in one block in B. we will read 24 numbers | //! LDB means number of elements in one block in B. we will read 24 numbers | ||||
//! first. so minus 24 * 2 bytes here. | //! first. so minus 24 * 2 bytes here. | ||||
LDB = (LDB - 24) * sizeof(dt_float16); | LDB = (LDB - 24) * sizeof(dt_float16); | ||||
@@ -263,8 +264,8 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
"v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -295,8 +296,9 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
// | v7[0-7]| |v31[0-7]| | // | v7[0-7]| |v31[0-7]| | ||||
// +--------+ +--------+ | // +--------+ +--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
void kern_8x8( | |||||
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
dt_float16* output) { | |||||
//! As each load 128 number from B, but the pos add 112 * 2, so we minus 112 | //! As each load 128 number from B, but the pos add 112 * 2, so we minus 112 | ||||
//! here. | //! here. | ||||
LDB = (LDB - 32) * sizeof(dt_float16); | LDB = (LDB - 32) * sizeof(dt_float16); | ||||
@@ -467,20 +469,19 @@ void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", "v28", "v29", | |||||
"v30", "v31", "cc", "memory"); | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); | ||||
void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||||
const dt_float16* B, size_t LDB, dt_float16* C, | |||||
size_t LDC, size_t M, size_t K, size_t N, | |||||
const dt_float16*, void*, bool trA, | |||||
bool trB) const { | |||||
void gemm_nopack_f16_8x8::kern( | |||||
const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C, | |||||
size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA, | |||||
bool trB) const { | |||||
constexpr static size_t MB = 8; | constexpr static size_t MB = 8; | ||||
constexpr static size_t KB = 8; | constexpr static size_t KB = 8; | ||||
constexpr static size_t NB = 8; | constexpr static size_t NB = 8; | ||||
@@ -17,21 +17,23 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace aarch64 { | namespace aarch64 { | ||||
MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M, | |||||
size_t K, size_t LDA, const float* alpha); | |||||
MEGDNN_NOINLINE void sgemm_packA_n( | |||||
const float* A, float* Apacked, size_t M, size_t K, size_t LDA, | |||||
const float* alpha); | |||||
MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M, | |||||
size_t K, size_t LDA, const float* alpha); | |||||
MEGDNN_NOINLINE void sgemm_packA_t( | |||||
const float* A, float* Apacked, size_t M, size_t K, size_t LDA, | |||||
const float* alpha); | |||||
MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K, | |||||
size_t N, size_t LDB); | |||||
MEGDNN_NOINLINE void sgemm_packB_n( | |||||
const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); | |||||
MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K, | |||||
size_t N, size_t LDB); | |||||
MEGDNN_NOINLINE void sgemm_packB_t( | |||||
const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); | |||||
MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C, | |||||
size_t LDC, size_t M, size_t N, size_t K, | |||||
int type, const float* beta); | |||||
MEGDNN_NOINLINE void sgemm_kernel12x8( | |||||
const float* A, const float* B, float* C, size_t LDC, size_t M, size_t N, | |||||
size_t K, int type, const float* beta); | |||||
} // namespace aarch64 | } // namespace aarch64 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -12,7 +12,6 @@ | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace matmul_general_4x16 { | namespace matmul_general_4x16 { | ||||
@@ -39,8 +38,9 @@ namespace matmul_general_4x16 { | |||||
// +--+ - - - - +--------+--------+--------+--------+ | // +--+ - - - - +--------+--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
void kern_4x16(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, int m_remain) { | |||||
void kern_4x16( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int m_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -224,14 +224,14 @@ void kern_4x16(const float* packA, const float* packB, int K, | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[m_remain] "+r"(m_remain), [outptr] "+r"(outptr) | [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", | |||||
"x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", "x10", "cc", | |||||
"memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -263,8 +263,9 @@ void kern_4x16(const float* packA, const float* packB, int K, | |||||
// +--+--+ - - - - +--------+ | // +--+--+ - - - - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||||
int LDC, bool is_first_k, int m_remain, int n_remain) { | |||||
void kern_4x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -330,99 +331,100 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||||
STORE_LINE("6", "2") \ | STORE_LINE("6", "2") \ | ||||
STORE_LINE("7", "3") \ | STORE_LINE("7", "3") \ | ||||
"105:\n" | "105:\n" | ||||
// clang-format on | |||||
asm volatile( | |||||
// load accumulator C | |||||
"add x1, x0, %x[LDC]\n" | |||||
"add x2, x1, %x[LDC]\n" | |||||
"add x3, x2, %x[LDC]\n" | |||||
"cmp %w[is_first_k], #1\n" | |||||
"beq 1f\n" LOAD_C | |||||
"b 2f\n" | |||||
"1:\n" | |||||
"eor v4.16b, v4.16b, v4.16b\n" | |||||
"eor v5.16b, v5.16b, v5.16b\n" | |||||
"eor v6.16b, v6.16b, v6.16b\n" | |||||
"eor v7.16b, v7.16b, v7.16b\n" | |||||
"2: \n" | |||||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
"cmp %w[K], #0\n" | |||||
"beq 4f\n" | |||||
"3:\n" | |||||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||||
"subs %w[K], %w[K], #1\n" | |||||
"bne 3b\n" | |||||
"4:\n" | |||||
"cmp %w[oddk], #1\n" | |||||
"beq 5f\n" | |||||
// Even tail | |||||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||||
"b 6f\n" | |||||
// odd tail | |||||
"5:\n" | |||||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
"6:\n" STORE_C | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [m_remain] "+r"(m_remain), | |||||
[n_remain] "+r"(n_remain), [outptr] "+r"(outptr) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", | |||||
"x2", "x3", "x10", "cc", "memory"); | |||||
// clang-format on | |||||
asm volatile( | |||||
// load accumulator C | |||||
"add x1, x0, %x[LDC]\n" | |||||
"add x2, x1, %x[LDC]\n" | |||||
"add x3, x2, %x[LDC]\n" | |||||
"cmp %w[is_first_k], #1\n" | |||||
"beq 1f\n" LOAD_C | |||||
"b 2f\n" | |||||
"1:\n" | |||||
"eor v4.16b, v4.16b, v4.16b\n" | |||||
"eor v5.16b, v5.16b, v5.16b\n" | |||||
"eor v6.16b, v6.16b, v6.16b\n" | |||||
"eor v7.16b, v7.16b, v7.16b\n" | |||||
"2: \n" | |||||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
"cmp %w[K], #0\n" | |||||
"beq 4f\n" | |||||
"3:\n" | |||||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||||
"subs %w[K], %w[K], #1\n" | |||||
"bne 3b\n" | |||||
"4:\n" | |||||
"cmp %w[oddk], #1\n" | |||||
"beq 5f\n" | |||||
// Even tail | |||||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||||
"b 6f\n" | |||||
// odd tail | |||||
"5:\n" | |||||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
"6:\n" STORE_C | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), | |||||
[outptr] "+r"(outptr) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", "x2", "x3", "x10", | |||||
"cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
int ymax, int k0, int kmax) { | |||||
void sgemm_4x16_pack_A_n( | |||||
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
float zerobuff[4]; | float zerobuff[4]; | ||||
std::memset(zerobuff, 0, sizeof(float) * 4); | std::memset(zerobuff, 0, sizeof(float) * 4); | ||||
constexpr int PACK_SIZE = 4*4; | |||||
constexpr int PACK_SIZE = 4 * 4; | |||||
int y = y0; | int y = y0; | ||||
for (; y + 3 < ymax; y += 4) { | for (; y + 3 < ymax; y += 4) { | ||||
// printf("main loop pack_a_n %p \n",outptr); | |||||
// printf("main loop pack_a_n %p \n",outptr); | |||||
const float* inptr0 = inptr + y * ldin + k0; | const float* inptr0 = inptr + y * ldin + k0; | ||||
const float* inptr1 = inptr0 + ldin; | const float* inptr1 = inptr0 + ldin; | ||||
const float* inptr2 = inptr1 + ldin; | const float* inptr2 = inptr1 + ldin; | ||||
@@ -459,9 +461,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
switch ((y + 3) - ymax) { | switch ((y + 3) - ymax) { | ||||
/* Everything falls through in here */ | /* Everything falls through in here */ | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -478,9 +482,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -493,8 +499,8 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
} | } | ||||
} | } | ||||
void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax) { | |||||
void sgemm_4x16_pack_A_t( | |||||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int ksize = kmax - k0; | int ksize = kmax - k0; | ||||
int ksize4 = (ksize << 2); | int ksize4 = (ksize << 2); | ||||
float* outptr_base = out; | float* outptr_base = out; | ||||
@@ -515,8 +521,7 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||||
auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
for (; x + 4 <= xmax; x += 4) { | for (; x + 4 <= xmax; x += 4) { | ||||
auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
@@ -546,8 +551,8 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||||
} | } | ||||
} | } | ||||
void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
void sgemm_4x16_pack_B_n( | |||||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int ksize = kmax - k0; | int ksize = kmax - k0; | ||||
int ksize16 = ksize * 16; | int ksize16 = ksize * 16; | ||||
int ksize4 = (ksize << 2); | int ksize4 = (ksize << 2); | ||||
@@ -570,15 +575,13 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||||
auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
for (; x + 16 <= xmax; x += 16) { | for (; x + 16 <= xmax; x += 16) { | ||||
auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
outptr += ksize16; | outptr += ksize16; | ||||
} | } | ||||
outptr = outptr_base4; | outptr = outptr_base4; | ||||
for (; x + 4 <= xmax; x += 4) { | for (; x + 4 <= xmax; x += 4) { | ||||
auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
@@ -616,8 +619,8 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||||
} | } | ||||
} | } | ||||
void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
int y0, int ymax, int k0, int kmax) { | |||||
void sgemm_4x16_pack_B_t( | |||||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { | |||||
float* outptr = out; | float* outptr = out; | ||||
const float* inptr = in; | const float* inptr = in; | ||||
float zerobuff[4]; | float zerobuff[4]; | ||||
@@ -642,8 +645,7 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
int x = (kmax - k0); | int x = (kmax - k0); | ||||
for (; x > 3; x -= 4) { | for (; x > 3; x -= 4) { | ||||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, | |||||
64); | |||||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 64); | |||||
outptr_inner += 64; | outptr_inner += 64; | ||||
} | } | ||||
for (; x > 0; x--) { | for (; x > 0; x--) { | ||||
@@ -676,9 +678,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
switch ((y + 3) - ymax) { | switch ((y + 3) - ymax) { | ||||
/* Everything falls through in here */ | /* Everything falls through in here */ | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -696,9 +700,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
switch ((y + 3) - ymax) { | switch ((y + 3) - ymax) { | ||||
/* Everything falls through in here */ | /* Everything falls through in here */ | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -711,8 +717,8 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
} | } | ||||
} | } | ||||
} // matmul_general_4x16 | |||||
} // aarch64 | |||||
} // megdnn | |||||
} // namespace matmul_general_4x16 | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -43,8 +43,9 @@ struct matmul_general_8x12 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -306,14 +307,13 @@ struct matmul_general_8x12 { | |||||
"6:\n" | "6:\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||||
"x6", "x7", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -348,9 +348,9 @@ struct matmul_general_8x12 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_8x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -520,13 +520,12 @@ struct matmul_general_8x12 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[n_remain] "+r"(n_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||||
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
"cc", "memory"); | |||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||||
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", | |||||
"memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -557,9 +556,9 @@ struct matmul_general_8x12 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int m_remain) { | |||||
static void kern_4x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int m_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -717,13 +716,12 @@ struct matmul_general_8x12 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[m_remain] "+r"(m_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "x1", "x2", "x3", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
"x2", "x3", "x10", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -754,9 +752,9 @@ struct matmul_general_8x12 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_4x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -895,20 +893,21 @@ struct matmul_general_8x12 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||||
[m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||||
"x3", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||||
"x10", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax) { | |||||
static void sgemm_8x12_pack_A_n( | |||||
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
float zerobuff[8]; | float zerobuff[8]; | ||||
std::memset(zerobuff, 0, sizeof(float) * 8); | std::memset(zerobuff, 0, sizeof(float) * 8); | ||||
constexpr int PACK_SIZE_32 = 4 * 8; | constexpr int PACK_SIZE_32 = 4 * 8; | ||||
@@ -933,8 +932,9 @@ struct matmul_general_8x12 { | |||||
prefetch_2x(inptr7); | prefetch_2x(inptr7); | ||||
int x = (kmax - k0); | int x = (kmax - k0); | ||||
for (; x > 3; x -= 4) { | for (; x > 3; x -= 4) { | ||||
transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
inptr5, inptr6, inptr7, outptr); | |||||
transpose_8x4_1_s( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
outptr += PACK_SIZE_32; | outptr += PACK_SIZE_32; | ||||
} | } | ||||
for (; x > 0; x--) { | for (; x > 0; x--) { | ||||
@@ -1004,8 +1004,8 @@ struct matmul_general_8x12 { | |||||
} | } | ||||
} | } | ||||
static void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void sgemm_8x12_pack_A_t( | |||||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int ksize = kmax - k0; | int ksize = kmax - k0; | ||||
int ksize8 = (ksize << 3); | int ksize8 = (ksize << 3); | ||||
int ksize4 = (ksize << 2); | int ksize4 = (ksize << 2); | ||||
@@ -1028,20 +1028,17 @@ struct matmul_general_8x12 { | |||||
auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
for (; x + 8 <= xmax; x += 8) { | for (; x + 8 <= xmax; x += 8) { | ||||
auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
outptr += ksize8; | outptr += ksize8; | ||||
} | } | ||||
outptr = outptr_base4; | outptr = outptr_base4; | ||||
for (; x + 4 <= xmax; x += 4) { | for (; x + 4 <= xmax; x += 4) { | ||||
auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
if (x < xmax) { | if (x < xmax) { | ||||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, | |||||
xmax - x); | |||||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||||
} | } | ||||
outptr_base += 4 * 8; | outptr_base += 4 * 8; | ||||
outptr_base4 += 4 * 4; | outptr_base4 += 4 * 4; | ||||
@@ -1071,8 +1068,8 @@ struct matmul_general_8x12 { | |||||
} | } | ||||
} | } | ||||
static void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void sgemm_8x12_pack_B_n( | |||||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int ksize = kmax - k0; | int ksize = kmax - k0; | ||||
int ksize12 = ksize * 12; | int ksize12 = ksize * 12; | ||||
int ksize4 = (ksize << 2); | int ksize4 = (ksize << 2); | ||||
@@ -1095,20 +1092,17 @@ struct matmul_general_8x12 { | |||||
auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
for (; x + 12 <= xmax; x += 12) { | for (; x + 12 <= xmax; x += 12) { | ||||
auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
outptr += ksize12; | outptr += ksize12; | ||||
} | } | ||||
outptr = outptr_base4; | outptr = outptr_base4; | ||||
for (; x + 4 <= xmax; x += 4) { | for (; x + 4 <= xmax; x += 4) { | ||||
auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
outptr_interleave); | |||||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
if (x < xmax) { | if (x < xmax) { | ||||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, | |||||
xmax - x); | |||||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||||
} | } | ||||
outptr_base += 12 * 4; | outptr_base += 12 * 4; | ||||
outptr_base4 += 4 * 4; | outptr_base4 += 4 * 4; | ||||
@@ -1138,8 +1132,8 @@ struct matmul_general_8x12 { | |||||
} | } | ||||
} | } | ||||
static void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin, | |||||
int y0, int ymax, int k0, int kmax) { | |||||
static void sgemm_8x12_pack_B_t( | |||||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { | |||||
float* outptr = out; | float* outptr = out; | ||||
const float* inptr = in; | const float* inptr = in; | ||||
float zerobuff[12]; | float zerobuff[12]; | ||||
@@ -1172,9 +1166,9 @@ struct matmul_general_8x12 { | |||||
prefetch_2x(inptr11); | prefetch_2x(inptr11); | ||||
int x = (kmax - k0); | int x = (kmax - k0); | ||||
for (; x > 3; x -= 4) { | for (; x > 3; x -= 4) { | ||||
transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
inptr5, inptr6, inptr7, inptr8, inptr9, | |||||
inptr10, inptr11, outptr); | |||||
transpose_12x4_1_s( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
inptr8, inptr9, inptr10, inptr11, outptr); | |||||
outptr += 48; | outptr += 48; | ||||
} | } | ||||
for (; x > 0; x--) { | for (; x > 0; x--) { | ||||
@@ -43,8 +43,9 @@ struct matmul_general_8x12_a53 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -575,15 +576,14 @@ struct matmul_general_8x12_a53 { | |||||
"6:\n" | "6:\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||||
"x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", | |||||
"x11", "x12", "x13", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
} | } | ||||
@@ -615,9 +615,9 @@ struct matmul_general_8x12_a53 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_8x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -856,13 +856,12 @@ struct matmul_general_8x12_a53 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[n_remain] "+r"(n_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||||
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
"x8", "x9", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||||
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", | |||||
"x10", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -893,9 +892,9 @@ struct matmul_general_8x12_a53 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int m_remain) { | |||||
static void kern_4x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int m_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -1133,14 +1132,12 @@ struct matmul_general_8x12_a53 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[m_remain] "+r"(m_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "x1", "x2", "x3", "x8", "x9", "x10", "x20", "x21", | |||||
"x22", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
"x2", "x3", "x8", "x9", "x10", "x20", "x21", "x22", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -1171,9 +1168,9 @@ struct matmul_general_8x12_a53 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_4x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -1312,12 +1309,12 @@ struct matmul_general_8x12_a53 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||||
[m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||||
"x3", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||||
"x10", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -43,8 +43,9 @@ struct matmul_general_8x12_a55 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -525,15 +526,14 @@ struct matmul_general_8x12_a55 { | |||||
"6:\n" | "6:\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||||
"x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", | |||||
"x11", "x12", "x13", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
} | } | ||||
@@ -565,9 +565,9 @@ struct matmul_general_8x12_a55 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_8x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -742,13 +742,12 @@ struct matmul_general_8x12_a55 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[n_remain] "+r"(n_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||||
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
"x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||||
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x10", "cc", | |||||
"memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -779,9 +778,9 @@ struct matmul_general_8x12_a55 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int m_remain) { | |||||
static void kern_4x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int m_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -972,14 +971,12 @@ struct matmul_general_8x12_a55 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[m_remain] "+r"(m_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "x1", "x2", "x3", "x10", "x20", "x21", "x22", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
"x2", "x3", "x10", "x20", "x21", "x22", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -1010,9 +1007,9 @@ struct matmul_general_8x12_a55 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_4x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
@@ -1151,12 +1148,12 @@ struct matmul_general_8x12_a55 { | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||||
[m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||||
"x3", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||||
"x10", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -44,8 +44,9 @@ struct matmul_mk4_8x12 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
float* output0 = output; | float* output0 = output; | ||||
@@ -307,10 +308,10 @@ struct matmul_mk4_8x12 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0), [output1] "+r"(output1) | [output0] "+r"(output0), [output1] "+r"(output1) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
"v31", "x1", "x2", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -340,9 +341,9 @@ struct matmul_mk4_8x12 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_8x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
float* output0 = output; | float* output0 = output; | ||||
@@ -500,8 +501,8 @@ struct matmul_mk4_8x12 { | |||||
[output0] "+r"(output0), [output1] "+r"(output1), | [output0] "+r"(output0), [output1] "+r"(output1), | ||||
[n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||||
"v13", "v14", "v15", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "cc", "memory"); | |||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_C | #undef STORE_C | ||||
@@ -531,8 +532,9 @@ struct matmul_mk4_8x12 { | |||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_4x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
@@ -669,9 +671,9 @@ struct matmul_mk4_8x12 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0) | [output0] "+r"(output0) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "x1", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
"cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -697,9 +699,9 @@ struct matmul_mk4_8x12 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_4x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
@@ -818,15 +820,15 @@ struct matmul_mk4_8x12 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0), [n_remain] "+r"(n_remain) | [output0] "+r"(output0), [n_remain] "+r"(n_remain) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax) { | |||||
static void sgemm_8x12_pack_A( | |||||
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); | megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); | ||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
constexpr int PACK_SIZE_32 = 4 * 8; | constexpr int PACK_SIZE_32 = 4 * 8; | ||||
@@ -855,8 +857,8 @@ struct matmul_mk4_8x12 { | |||||
} | } | ||||
} | } | ||||
static void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax) { | |||||
static void sgemm_8x12_pack_B( | |||||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
float tmpbuff[16] = {0.0f}; | float tmpbuff[16] = {0.0f}; | ||||
@@ -886,8 +888,7 @@ struct matmul_mk4_8x12 { | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
if (x < xmax) { | if (x < xmax) { | ||||
std::memcpy(tmpbuff, inptr, | |||||
sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||||
std::memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||||
auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
const float* tmp_ptr = &tmpbuff[0]; | const float* tmp_ptr = &tmpbuff[0]; | ||||
transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); | transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); | ||||
@@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a53 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
float* output0 = output; | float* output0 = output; | ||||
@@ -553,11 +554,11 @@ struct matmul_mk4_8x12_a53 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0), [output1] "+r"(output1) | [output0] "+r"(output0), [output1] "+r"(output1) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", | |||||
"x11", "x12", "x13", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
"v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||||
"memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -587,9 +588,9 @@ struct matmul_mk4_8x12_a53 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_8x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
float* output0 = output; | float* output0 = output; | ||||
@@ -831,8 +832,8 @@ struct matmul_mk4_8x12_a53 { | |||||
[output0] "+r"(output0), [output1] "+r"(output1), | [output0] "+r"(output0), [output1] "+r"(output1), | ||||
[n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||||
"v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "x8", "x9", "x10", "cc", "memory"); | |||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_C | #undef STORE_C | ||||
@@ -862,8 +863,9 @@ struct matmul_mk4_8x12_a53 { | |||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_4x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
@@ -1098,9 +1100,9 @@ struct matmul_mk4_8x12_a53 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0) | [output0] "+r"(output0) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
"x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -1126,9 +1128,9 @@ struct matmul_mk4_8x12_a53 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_4x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
@@ -1246,8 +1248,7 @@ struct matmul_mk4_8x12_a53 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0), [n_remain] "+r"(n_remain) | [output0] "+r"(output0), [n_remain] "+r"(n_remain) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_C | #undef STORE_C | ||||
@@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a55 { | |||||
// +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
float* output0 = output; | float* output0 = output; | ||||
@@ -519,11 +520,11 @@ struct matmul_mk4_8x12_a55 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0), [output1] "+r"(output1) | [output0] "+r"(output0), [output1] "+r"(output1) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
"v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", | |||||
"x11", "x12", "x13", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
"v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||||
"memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -553,9 +554,9 @@ struct matmul_mk4_8x12_a55 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_8x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_8x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
float* output0 = output; | float* output0 = output; | ||||
@@ -749,8 +750,8 @@ struct matmul_mk4_8x12_a55 { | |||||
[output0] "+r"(output0), [output1] "+r"(output1), | [output0] "+r"(output0), [output1] "+r"(output1), | ||||
[n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||||
"v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "x8", "x9", "x10", "cc", "memory"); | |||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_C | #undef STORE_C | ||||
@@ -780,8 +781,9 @@ struct matmul_mk4_8x12_a55 { | |||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x12(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k) { | |||||
static void kern_4x12( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k) { | |||||
MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
@@ -997,9 +999,9 @@ struct matmul_mk4_8x12_a55 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0) | [output0] "+r"(output0) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
"x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -1025,9 +1027,9 @@ struct matmul_mk4_8x12_a55 { | |||||
// +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
// | // | ||||
// Accumulator | // Accumulator | ||||
static void kern_4x4(const float* packA, const float* packB, int K, | |||||
float* output, int LDC, bool is_first_k, | |||||
int n_remain) { | |||||
static void kern_4x4( | |||||
const float* packA, const float* packB, int K, float* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
@@ -1146,8 +1148,7 @@ struct matmul_mk4_8x12_a55 { | |||||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
[output0] "+r"(output0), [n_remain] "+r"(n_remain) | [output0] "+r"(output0), [n_remain] "+r"(n_remain) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_C | #undef STORE_C | ||||
@@ -10,6 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | ||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | ||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" | #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" | ||||
@@ -17,44 +18,40 @@ | |||||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | ||||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" | #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" | ||||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" | #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" | ||||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace aarch64; | using namespace aarch64; | ||||
using namespace aarch64::matmul; | using namespace aarch64::matmul; | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | ||||
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||||
int k0, int kmax, bool transpose_A) const { | |||||
void sgemm_4x16::pack_A( | |||||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool transpose_A) const { | |||||
if (transpose_A) { | if (transpose_A) { | ||||
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||||
int k0, int kmax, bool transpose_B) const { | |||||
void sgemm_4x16::pack_B( | |||||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose_B) const { | |||||
if (transpose_B) { | if (transpose_B) { | ||||
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||||
size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||||
const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
void sgemm_4x16::kern( | |||||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k, const float*, float*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
@@ -71,9 +68,9 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||||
size_t n = 0; | size_t n = 0; | ||||
const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, | |||||
is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
matmul_general_4x16::kern_4x16( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K16; | cur_packB += K16; | ||||
} | } | ||||
@@ -92,32 +89,30 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | ||||
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||||
int k0, int kmax, bool transpose_A) const { | |||||
void sgemm_8x12::pack_A( | |||||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool transpose_A) const { | |||||
if (transpose_A) { | if (transpose_A) { | ||||
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||||
int k0, int kmax, bool transpose_B) const { | |||||
void sgemm_8x12::pack_B( | |||||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose_B) const { | |||||
if (transpose_B) { | if (transpose_B) { | ||||
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
template <typename gemm_class> | template <typename gemm_class> | ||||
static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k) { | |||||
static inline void sgemm_8x12_helper( | |||||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k) { | |||||
constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
constexpr size_t A_INTERLEAVE4 = 4; | constexpr size_t A_INTERLEAVE4 = 4; | ||||
constexpr size_t B_INTERLEAVE = 12; | constexpr size_t B_INTERLEAVE = 12; | ||||
@@ -138,8 +133,9 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
gemm_class::kern_8x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -150,16 +146,17 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
gemm_class::kern_4x12( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
gemm_class::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -167,56 +164,55 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
} | } | ||||
} | } | ||||
void sgemm_8x12::kern(const float* packA, const float* packB, size_t M, | |||||
size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||||
const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
void sgemm_8x12::kern( | |||||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k, const float*, float*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
#if !MGB_ENABLE_CPUINFO | #if !MGB_ENABLE_CPUINFO | ||||
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||||
is_first_k); | |||||
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, is_first_k); | |||||
#else | #else | ||||
auto arch = cpuinfo_get_current_core()->uarch; | auto arch = cpuinfo_get_current_core()->uarch; | ||||
#ifdef __IN_TEE_ENV__ | #ifdef __IN_TEE_ENV__ | ||||
arch = cpuinfo_uarch_unknown; | arch = cpuinfo_uarch_unknown; | ||||
#endif | #endif | ||||
if (arch == cpuinfo_uarch_cortex_a53) { | if (arch == cpuinfo_uarch_cortex_a53) { | ||||
sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C, | |||||
LDC, is_first_k); | |||||
sgemm_8x12_helper<matmul_general_8x12_a53>( | |||||
packA, packB, M, N, K, C, LDC, is_first_k); | |||||
} else if (arch == cpuinfo_uarch_cortex_a55) { | } else if (arch == cpuinfo_uarch_cortex_a55) { | ||||
sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C, | |||||
LDC, is_first_k); | |||||
sgemm_8x12_helper<matmul_general_8x12_a55>( | |||||
packA, packB, M, N, K, C, LDC, is_first_k); | |||||
} else { | } else { | ||||
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||||
is_first_k); | |||||
sgemm_8x12_helper<matmul_general_8x12>( | |||||
packA, packB, M, N, K, C, LDC, is_first_k); | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | ||||
void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, | |||||
bool transpose_A) const { | |||||
void sgemm_mk4_8x12::pack_A( | |||||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool transpose_A) const { | |||||
megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A"); | megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A"); | ||||
matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); | matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); | ||||
} | } | ||||
void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, | |||||
bool transpose_B) const { | |||||
void sgemm_mk4_8x12::pack_B( | |||||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose_B) const { | |||||
megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B"); | megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B"); | ||||
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | ||||
} | } | ||||
template <typename gemm_name> | template <typename gemm_name> | ||||
static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||||
size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k) { | |||||
static inline void sgemm_mk4_8x12_helper( | |||||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k) { | |||||
const int K12 = K * 12; | const int K12 = K * 12; | ||||
const int K8 = K * 8; | const int K8 = K * 8; | ||||
const int K4 = K * 4; | const int K4 = K * 4; | ||||
@@ -237,8 +233,9 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
gemm_name::kern_8x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4 * PACK_C_SIZE; | output += 4 * PACK_C_SIZE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -254,41 +251,41 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
gemm_name::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4 * PACK_C_SIZE; | output += 4 * PACK_C_SIZE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
packA += K4; | packA += K4; | ||||
} | } | ||||
} | } | ||||
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M, | |||||
size_t N, size_t K, float* C, size_t LDC, | |||||
bool is_first_k, const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
void sgemm_mk4_8x12::kern( | |||||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k, const float*, float*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | ||||
#if !MGB_ENABLE_CPUINFO | #if !MGB_ENABLE_CPUINFO | ||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||||
is_first_k); | |||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, is_first_k); | |||||
#else | #else | ||||
auto arch = cpuinfo_get_current_core()->uarch; | auto arch = cpuinfo_get_current_core()->uarch; | ||||
#ifdef __IN_TEE_ENV__ | #ifdef __IN_TEE_ENV__ | ||||
arch = cpuinfo_uarch_unknown; | arch = cpuinfo_uarch_unknown; | ||||
#endif | #endif | ||||
if (arch == cpuinfo_uarch_cortex_a53) { | if (arch == cpuinfo_uarch_cortex_a53) { | ||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C, | |||||
LDC, is_first_k); | |||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>( | |||||
packA, packB, M, N, K, C, LDC, is_first_k); | |||||
} else if (arch == cpuinfo_uarch_cortex_a55) { | } else if (arch == cpuinfo_uarch_cortex_a55) { | ||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C, | |||||
LDC, is_first_k); | |||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>( | |||||
packA, packB, M, N, K, C, LDC, is_first_k); | |||||
} else { | } else { | ||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||||
is_first_k); | |||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>( | |||||
packA, packB, M, N, K, C, LDC, is_first_k); | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
@@ -15,17 +15,14 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace matmul { | namespace matmul { | ||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, | |||||
sgemm_8x12); | |||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); | |||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, | |||||
sgemm_4x16); | |||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, sgemm_4x16); | |||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, | |||||
sgemm_mk4_8x12); | |||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, sgemm_mk4_8x12); | |||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, | |||||
sgemm_nopack_4x16); | |||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||||
float, float, float, 4, 16, 1, false, true, sgemm_nopack_4x16); | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -20,8 +20,8 @@ using namespace aarch64::matmul; | |||||
namespace { | namespace { | ||||
void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
float* output) { | |||||
void kern_4x1( | |||||
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||||
LDB *= sizeof(float); | LDB *= sizeof(float); | ||||
asm volatile( | asm volatile( | ||||
"subs %w[K], %w[K], #4\n" | "subs %w[K], %w[K], #4\n" | ||||
@@ -64,8 +64,7 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", | |||||
"memory"); | |||||
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -89,8 +88,8 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
// +--------+ - - - - -+--------+ | // +--------+ - - - - -+--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
float* output) { | |||||
void kern_4x4( | |||||
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||||
//! As each load 16 number from B, but the pos add 12 * 4, so we minus 12 | //! As each load 16 number from B, but the pos add 12 * 4, so we minus 12 | ||||
//! here. | //! here. | ||||
LDB = (LDB - 12) * sizeof(float); | LDB = (LDB - 12) * sizeof(float); | ||||
@@ -165,8 +164,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||||
"v18", "v19", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", | |||||
"v19", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -195,8 +194,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
// +--------+ - - - - -+--------+ | // +--------+ - - - - -+--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
float* output) { | |||||
void kern_4x8( | |||||
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||||
//! As each load 32 number from B, but the pos add 24 * 4, so we minus 24 | //! As each load 32 number from B, but the pos add 24 * 4, so we minus 24 | ||||
//! here. | //! here. | ||||
LDB = (LDB - 24) * sizeof(float); | LDB = (LDB - 24) * sizeof(float); | ||||
@@ -304,9 +303,9 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||||
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", | |||||
"v27", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc", | |||||
"memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -342,8 +341,7 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
// +--------+ | // +--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||||
float* output) { | |||||
void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, float* output) { | |||||
//! As each load 64 number from B, but the pos add 56 * 4, so we minus 56 | //! As each load 64 number from B, but the pos add 56 * 4, so we minus 56 | ||||
//! here. | //! here. | ||||
LDB = (LDB - 56) * sizeof(float); | LDB = (LDB - 56) * sizeof(float); | ||||
@@ -565,20 +563,18 @@ void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||||
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||||
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
} | } | ||||
} // namespace | } // namespace | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); | MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); | ||||
void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||||
size_t LDB, float* C, size_t LDC, size_t M, | |||||
size_t K, size_t N, const float*, void*, bool trA, | |||||
bool trB) const { | |||||
void sgemm_nopack_4x16::kern( | |||||
const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, | |||||
size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { | |||||
constexpr static size_t MB = 4; | constexpr static size_t MB = 4; | ||||
constexpr static size_t KB = 4; | constexpr static size_t KB = 4; | ||||
constexpr static size_t NB = 16; | constexpr static size_t NB = 16; | ||||
@@ -46,8 +46,9 @@ namespace matmul_12x8x1 { | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
static void kern_12x8( | |||||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k) { | |||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
@@ -155,15 +156,13 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||||
"stp q25, q26, [x9]\n" | "stp q25, q26, [x9]\n" | ||||
"stp q27, q28, [x10]\n" | "stp q27, q28, [x10]\n" | ||||
"stp q29, q30, [x11]\n" | "stp q29, q30, [x11]\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[output] "+r"(output) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||||
: | : | ||||
: "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", | |||||
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "x1", | |||||
"x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", | |||||
"cc", "memory"); | |||||
: "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||||
"v25", "v26", "v27", "v28", "v29", "v30", "x1", "x2", "x3", "x4", "x5", | |||||
"x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -196,8 +195,9 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
static void kern_8x8( | |||||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k) { | |||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
@@ -276,13 +276,12 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||||
"stp q17, q18, [x5]\n" | "stp q17, q18, [x5]\n" | ||||
"stp q19, q20, [x6]\n" | "stp q19, q20, [x6]\n" | ||||
"stp q21, q22, [x7]\n" | "stp q21, q22, [x7]\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[output] "+r"(output) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||||
: | : | ||||
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", | |||||
"x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||||
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||||
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", "x2", "x3", "x4", | |||||
"x5", "x6", "x7", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -311,9 +310,9 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, | |||||
size_t m_remain) { | |||||
static void kern_4x8( | |||||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, size_t m_remain) { | |||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
@@ -388,14 +387,13 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||||
[m_remain] "+r"(m_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"cc", "memory"); | |||||
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "cc", | |||||
"memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -432,9 +430,9 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, | |||||
size_t n_remain) { | |||||
static void kern_12x4( | |||||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, size_t n_remain) { | |||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
@@ -573,18 +571,16 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), | |||||
[outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), | |||||
[outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), | |||||
[x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
[outptr7] "=r"(outptr7), [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), | |||||
[outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), [x0] "+r"(x0), | |||||
[n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v19", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||||
"v16", "v17", "v18", "v19", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -618,9 +614,9 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, | |||||
size_t n_remain) { | |||||
static void kern_8x4( | |||||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, size_t n_remain) { | |||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
@@ -734,16 +730,14 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||||
[n_remain] "+r"(n_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||||
"cc", "memory"); | |||||
: "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", | |||||
"memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -773,9 +767,9 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, size_t m_remain, | |||||
size_t n_remain) { | |||||
static void kern_4x4( | |||||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, size_t m_remain, size_t n_remain) { | |||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
@@ -874,11 +868,10 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||||
[m_remain] "+r"(m_remain), [x1] "+r"(x1), | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[x0] "+r"(x0), [m_remain] "+r"(m_remain), [x1] "+r"(x1), | |||||
[n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
: | : | ||||
: "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); | : "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); | ||||
@@ -889,9 +882,9 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_s16_12x8x1_pack_A_n( | |||||
int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int16_t zerobuff[4]; | int16_t zerobuff[4]; | ||||
std::memset(zerobuff, 0, sizeof(int16_t) * 4); | std::memset(zerobuff, 0, sizeof(int16_t) * 4); | ||||
@@ -925,15 +918,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 3; K -= 4) { | for (; K > 3; K -= 4) { | ||||
interleave_12x1_4_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, inptr8, inptr9, inptr10, | |||||
inptr11, outptr); | |||||
interleave_12x1_4_h( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
inptr8, inptr9, inptr10, inptr11, outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, | |||||
outptr, 1, K); | |||||
interleave_12( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
inptr8, inptr9, inptr10, inptr11, outptr, 1, K); | |||||
} | } | ||||
} | } | ||||
@@ -949,13 +942,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 7; K -= 8) { | for (; K > 7; K -= 8) { | ||||
interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
interleave_8x1_8_h( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 1, K); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 1, K); | |||||
} | } | ||||
} | } | ||||
@@ -975,9 +970,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -992,9 +989,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -1007,9 +1006,8 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
} | } | ||||
} | } | ||||
static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, | |||||
int ldin, int x0, int xmax, | |||||
int k0, int kmax) { | |||||
static void gemm_s16_12x8x1_transpose_pack_A_n( | |||||
int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
const int ksize4 = ksize * 4; | const int ksize4 = ksize * 4; | ||||
const int ksize8 = ksize4 * 2; | const int ksize8 = ksize4 * 2; | ||||
@@ -1054,8 +1052,8 @@ static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, | |||||
} | } | ||||
} | } | ||||
static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_s16_12x8x1_pack_B_n( | |||||
int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
const int ksize4 = ksize * 4; | const int ksize4 = ksize * 4; | ||||
const int ksize8 = ksize4 * 2; | const int ksize8 = ksize4 * 2; | ||||
@@ -1090,10 +1088,9 @@ static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, | |||||
} | } | ||||
} | } | ||||
static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||||
const int16_t* inptr, int ldin, | |||||
int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_s16_12x8x1_transpose_pack_B_n( | |||||
int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int16_t zerobuff[4]; | int16_t zerobuff[4]; | ||||
std::memset(zerobuff, 0, sizeof(int16_t) * 4); | std::memset(zerobuff, 0, sizeof(int16_t) * 4); | ||||
@@ -1110,13 +1107,15 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 7; K -= 8) { | for (; K > 7; K -= 8) { | ||||
interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
interleave_8x1_8_h( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 1, K); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 1, K); | |||||
} | } | ||||
} | } | ||||
@@ -1136,9 +1135,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -1153,9 +1154,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -22,39 +22,37 @@ using namespace aarch64::matmul; | |||||
///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// | ///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); | ||||
void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s16_12x8x1::pack_A( | |||||
dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin, | |||||
y0, ymax, k0, kmax); | |||||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n( | |||||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax, | |||||
k0, kmax); | |||||
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n( | |||||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s16_12x8x1::pack_B( | |||||
dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0, | |||||
xmax, k0, kmax); | |||||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n( | |||||
out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||||
size_t M, size_t N, size_t K, dt_int32* C, | |||||
size_t LDC, bool is_first_k, const dt_int32*, | |||||
dt_int32*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
(A_dtype.enumv() == DTypeEnum::Int16 && | |||||
C_dtype.enumv() == DTypeEnum::Int32), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_s16_12x8x1::kern( | |||||
const dt_int16* packA, const dt_int16* packB, size_t M, size_t N, size_t K, | |||||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && | |||||
(A_dtype.enumv() == DTypeEnum::Int16 && | |||||
C_dtype.enumv() == DTypeEnum::Int32), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
@@ -72,15 +70,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int16* cur_packB = packB; | const dt_int16* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(N - n, 4)); | |||||
matmul_12x8x1::kern_12x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -92,15 +90,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||||
const dt_int16* cur_packB = packB; | const dt_int16* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(N - n, 4)); | |||||
matmul_12x8x1::kern_8x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -112,16 +110,17 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||||
const dt_int16* cur_packB = packB; | const dt_int16* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(M - m, 4)); | |||||
matmul_12x8x1::kern_4x8( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_12x8x1::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -16,11 +16,11 @@ namespace megdnn { | |||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace matmul { | namespace matmul { | ||||
MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, | |||||
gemm_s16_12x8x1); | |||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, gemm_s16_12x8x1); | |||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false, | |||||
true, gemm_nopack_s16_8x8); | |||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||||
dt_int16, dt_int32, dt_int32, 8, 8, 1, false, true, gemm_nopack_s16_8x8); | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -9,8 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/int16/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
#include "src/aarch64/matrix_mul/int16/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -20,8 +20,9 @@ using namespace aarch64::matmul; | |||||
namespace { | namespace { | ||||
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
void kern_8x1( | |||||
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
//! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | ||||
//! here. | //! here. | ||||
LDB *= sizeof(dt_int16); | LDB *= sizeof(dt_int16); | ||||
@@ -91,9 +92,8 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||||
"v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -120,8 +120,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
// | v31[0-7]| |v23[0-3]| | // | v31[0-7]| |v23[0-3]| | ||||
// +---------+ +--------+ | // +---------+ +--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
void kern_8x4( | |||||
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
//! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | ||||
//! here. | //! here. | ||||
LDB = (LDB - 24) * sizeof(dt_int16); | LDB = (LDB - 24) * sizeof(dt_int16); | ||||
@@ -349,9 +350,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||||
"memory"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -382,8 +383,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
// | v7[0-7]| |v30[0-3]|v31[0-3]| | // | v7[0-7]| |v30[0-3]|v31[0-3]| | ||||
// +--------+ +--------+--------+ | // +--------+ +--------+--------+ | ||||
// Accumulator | // Accumulator | ||||
void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
void kern_8x8( | |||||
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
dt_int32* output) { | |||||
//! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 | //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 | ||||
//! here. | //! here. | ||||
LDB = (LDB - 48) * sizeof(dt_int16); | LDB = (LDB - 48) * sizeof(dt_int16); | ||||
@@ -693,20 +695,20 @@ void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
[output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
"cc", "memory"); | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); | ||||
void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||||
size_t LDB, dt_int32* C, size_t LDC, size_t M, | |||||
size_t K, size_t N, const dt_int32*, void*, | |||||
bool trA, bool trB) const { | |||||
void gemm_nopack_s16_8x8::kern( | |||||
const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C, | |||||
size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA, | |||||
bool trB) const { | |||||
constexpr static size_t MB = 8; | constexpr static size_t MB = 8; | ||||
constexpr static size_t KB = 8; | constexpr static size_t KB = 8; | ||||
constexpr static size_t NB = 8; | constexpr static size_t NB = 8; | ||||
@@ -36,9 +36,9 @@ namespace matmul_s4_4x4x16 { | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void s4_kern_8x8_remain( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
K /= 8; | K /= 8; | ||||
LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
@@ -170,7 +170,7 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
"dup v5.8b,v20.b[5]\n" | "dup v5.8b,v20.b[5]\n" | ||||
"dup v6.8b,v20.b[6]\n" | "dup v6.8b,v20.b[6]\n" | ||||
"dup v7.8b,v20.b[7]\n" | "dup v7.8b,v20.b[7]\n" | ||||
"ld1 {v17.8b}, [%[b_ptr]], 8\n" | "ld1 {v17.8b}, [%[b_ptr]], 8\n" | ||||
"dup v8.8b,v20.b[8]\n" | "dup v8.8b,v20.b[8]\n" | ||||
@@ -318,16 +318,16 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
STORE_C | STORE_C | ||||
: | : | ||||
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||||
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||||
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||||
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||||
[a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
[m_remain] "+r"(m_remain), | |||||
[n_remain] "+r"( | |||||
n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||||
: | : | ||||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31"); | |||||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||||
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -335,14 +335,14 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void s4_kern_8x8( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
K /= 8; | K /= 8; | ||||
LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
// clang-format off | |||||
// clang-format off | |||||
#define LOAD_C_8 \ | #define LOAD_C_8 \ | ||||
"ld1 {v24.8h}, [x0], #16\n" \ | "ld1 {v24.8h}, [x0], #16\n" \ | ||||
@@ -363,9 +363,9 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"st1 {v28.8h}, [x4], #16\n" \ | "st1 {v28.8h}, [x4], #16\n" \ | ||||
"st1 {v29.8h}, [x5], #16\n" \ | "st1 {v29.8h}, [x5], #16\n" \ | ||||
"st1 {v30.8h}, [x6], #16\n" \ | "st1 {v30.8h}, [x6], #16\n" \ | ||||
"st1 {v31.8h}, [x7], #16\n" \ | |||||
"st1 {v31.8h}, [x7], #16\n" | |||||
// clang-format on | |||||
// clang-format on | |||||
register int16_t* outptr asm("x0") = output; | register int16_t* outptr asm("x0") = output; | ||||
asm volatile( | asm volatile( | ||||
"add x1, x0, %x[LDC]\n" | "add x1, x0, %x[LDC]\n" | ||||
@@ -395,8 +395,8 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"PRFM PLDL1KEEP, [%[a_ptr], #512]\n" | "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" | ||||
"PRFM PLDL1KEEP, [%[b_ptr], #512]\n" | "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" | ||||
"1:\n" | "1:\n" | ||||
// "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||||
// "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||||
// "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||||
// "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||||
"dup v0.8b,v20.b[0]\n" | "dup v0.8b,v20.b[0]\n" | ||||
"ld1 {v22.16b}, [%[a_ptr]],#16\n" | "ld1 {v22.16b}, [%[a_ptr]],#16\n" | ||||
"dup v1.8b,v20.b[1]\n" | "dup v1.8b,v20.b[1]\n" | ||||
@@ -409,7 +409,6 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"dup v5.8b,v20.b[5]\n" | "dup v5.8b,v20.b[5]\n" | ||||
"dup v6.8b,v20.b[6]\n" | "dup v6.8b,v20.b[6]\n" | ||||
"dup v7.8b,v20.b[7]\n" | "dup v7.8b,v20.b[7]\n" | ||||
"dup v8.8b,v20.b[8]\n" | "dup v8.8b,v20.b[8]\n" | ||||
"smlal v24.8h, v0.8b, v16.8b\n" | "smlal v24.8h, v0.8b, v16.8b\n" | ||||
@@ -560,26 +559,26 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
STORE_C_8 | STORE_C_8 | ||||
: | : | ||||
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||||
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||||
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||||
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||||
[a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
[m_remain] "+r"(m_remain), | |||||
[n_remain] "+r"( | |||||
n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||||
: | : | ||||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31"); | |||||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||||
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
//packa | |||||
static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
// packa | |||||
static void gemm_s4x4x16_8x8x8_transpose_pack( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[8]; | int8_t zerobuff[8]; | ||||
int8_t tmpbuff0[8]; | int8_t tmpbuff0[8]; | ||||
int8_t tmpbuff1[8]; | int8_t tmpbuff1[8]; | ||||
@@ -617,22 +616,23 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
prefetch_2x(inptr5); | prefetch_2x(inptr5); | ||||
prefetch_2x(inptr6); | prefetch_2x(inptr6); | ||||
prefetch_2x(inptr7); | prefetch_2x(inptr7); | ||||
int K = (kmax - k0)/2; | |||||
int K = (kmax - k0) / 2; | |||||
//! read 4 * 16 in each row | //! read 4 * 16 in each row | ||||
for (; K > 3; K -= 4) { | for (; K > 3; K -= 4) { | ||||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
inptr5, inptr6, inptr7, outptr); | |||||
transpose_4x8_1_b_with_shift( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
std::memcpy(tmpbuff0,inptr0,K); | |||||
std::memcpy(tmpbuff1,inptr1,K); | |||||
std::memcpy(tmpbuff2,inptr2,K); | |||||
std::memcpy(tmpbuff3,inptr3,K); | |||||
std::memcpy(tmpbuff4,inptr4,K); | |||||
std::memcpy(tmpbuff5,inptr5,K); | |||||
std::memcpy(tmpbuff6,inptr6,K); | |||||
std::memcpy(tmpbuff7,inptr7,K); | |||||
std::memcpy(tmpbuff0, inptr0, K); | |||||
std::memcpy(tmpbuff1, inptr1, K); | |||||
std::memcpy(tmpbuff2, inptr2, K); | |||||
std::memcpy(tmpbuff3, inptr3, K); | |||||
std::memcpy(tmpbuff4, inptr4, K); | |||||
std::memcpy(tmpbuff5, inptr5, K); | |||||
std::memcpy(tmpbuff6, inptr6, K); | |||||
std::memcpy(tmpbuff7, inptr7, K); | |||||
inptr0 = tmpbuff0; | inptr0 = tmpbuff0; | ||||
inptr1 = tmpbuff1; | inptr1 = tmpbuff1; | ||||
inptr2 = tmpbuff2; | inptr2 = tmpbuff2; | ||||
@@ -641,8 +641,9 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
inptr5 = tmpbuff5; | inptr5 = tmpbuff5; | ||||
inptr6 = tmpbuff6; | inptr6 = tmpbuff6; | ||||
inptr7 = tmpbuff7; | inptr7 = tmpbuff7; | ||||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
inptr5, inptr6, inptr7, outptr); | |||||
transpose_4x8_1_b_with_shift( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
} | } | ||||
for (; y < ymax; y += 8) { | for (; y < ymax; y += 8) { | ||||
@@ -655,23 +656,29 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
const int8_t* inptr6 = inptr5 + ldin; | const int8_t* inptr6 = inptr5 + ldin; | ||||
const int8_t* inptr7 = inptr6 + ldin; | const int8_t* inptr7 = inptr6 + ldin; | ||||
int K = (kmax - k0)/2; | |||||
int K = (kmax - k0) / 2; | |||||
//! read 4 * 16 in each row | //! read 4 * 16 in each row | ||||
for (; K > 3; K -= 4) { | for (; K > 3; K -= 4) { | ||||
if (y + 7 >= ymax) { | if (y + 7 >= ymax) { | ||||
switch (y + 7 - ymax) { | switch (y + 7 - ymax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -679,24 +686,31 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
megdnn_assert(0); | megdnn_assert(0); | ||||
} | } | ||||
} | } | ||||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
inptr5, inptr6, inptr7, outptr); | |||||
transpose_4x8_1_b_with_shift( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
if (y + 7 >= ymax) { | if (y + 7 >= ymax) { | ||||
switch (y + 7 - ymax) { | switch (y + 7 - ymax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -705,14 +719,14 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
} | } | ||||
} | } | ||||
std::memcpy(tmpbuff0,inptr0,K); | |||||
std::memcpy(tmpbuff1,inptr1,K); | |||||
std::memcpy(tmpbuff2,inptr2,K); | |||||
std::memcpy(tmpbuff3,inptr3,K); | |||||
std::memcpy(tmpbuff4,inptr4,K); | |||||
std::memcpy(tmpbuff5,inptr5,K); | |||||
std::memcpy(tmpbuff6,inptr6,K); | |||||
std::memcpy(tmpbuff7,inptr7,K); | |||||
std::memcpy(tmpbuff0, inptr0, K); | |||||
std::memcpy(tmpbuff1, inptr1, K); | |||||
std::memcpy(tmpbuff2, inptr2, K); | |||||
std::memcpy(tmpbuff3, inptr3, K); | |||||
std::memcpy(tmpbuff4, inptr4, K); | |||||
std::memcpy(tmpbuff5, inptr5, K); | |||||
std::memcpy(tmpbuff6, inptr6, K); | |||||
std::memcpy(tmpbuff7, inptr7, K); | |||||
inptr0 = tmpbuff0; | inptr0 = tmpbuff0; | ||||
inptr1 = tmpbuff1; | inptr1 = tmpbuff1; | ||||
inptr2 = tmpbuff2; | inptr2 = tmpbuff2; | ||||
@@ -721,14 +735,15 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
inptr5 = tmpbuff5; | inptr5 = tmpbuff5; | ||||
inptr6 = tmpbuff6; | inptr6 = tmpbuff6; | ||||
inptr7 = tmpbuff7; | inptr7 = tmpbuff7; | ||||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
inptr5, inptr6, inptr7, outptr); | |||||
transpose_4x8_1_b_with_shift( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
//packb | |||||
static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
// packb | |||||
static void gemm_s4x4x16_8x8x8_interleave_pack( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[8]; | int8_t zerobuff[8]; | ||||
int8_t tmpbuff0[8]; | int8_t tmpbuff0[8]; | ||||
int8_t tmpbuff1[8]; | int8_t tmpbuff1[8]; | ||||
@@ -748,7 +763,7 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); | std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); | ||||
std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); | std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
const int ksize8 = round_up(ksize, 8) * 8; //pack to int8 *8 packto s4 *4 | |||||
const int ksize8 = round_up(ksize, 8) * 8; // pack to int8 *8 packto s4 *4 | |||||
int8_t* outptr = out; | int8_t* outptr = out; | ||||
int8_t* outptr_interleave = nullptr; | int8_t* outptr_interleave = nullptr; | ||||
@@ -776,21 +791,22 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
int8_t* outptr_inner = outptr; | int8_t* outptr_inner = outptr; | ||||
for (; x + 3 < xmax; x += 4) { | for (; x + 3 < xmax; x += 4) { | ||||
outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr_interleave); | |||||
interleave_8x4_1_b_with_shift( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave); | |||||
outptr_inner += ksize8; | outptr_inner += ksize8; | ||||
} | } | ||||
if (x < xmax) { | if (x < xmax) { | ||||
int remainx = xmax - x; | int remainx = xmax - x; | ||||
std::memcpy(tmpbuff0,inptr0,remainx); | |||||
std::memcpy(tmpbuff1,inptr1,remainx); | |||||
std::memcpy(tmpbuff2,inptr2,remainx); | |||||
std::memcpy(tmpbuff3,inptr3,remainx); | |||||
std::memcpy(tmpbuff4,inptr4,remainx); | |||||
std::memcpy(tmpbuff5,inptr5,remainx); | |||||
std::memcpy(tmpbuff6,inptr6,remainx); | |||||
std::memcpy(tmpbuff7,inptr7,remainx); | |||||
std::memcpy(tmpbuff0, inptr0, remainx); | |||||
std::memcpy(tmpbuff1, inptr1, remainx); | |||||
std::memcpy(tmpbuff2, inptr2, remainx); | |||||
std::memcpy(tmpbuff3, inptr3, remainx); | |||||
std::memcpy(tmpbuff4, inptr4, remainx); | |||||
std::memcpy(tmpbuff5, inptr5, remainx); | |||||
std::memcpy(tmpbuff6, inptr6, remainx); | |||||
std::memcpy(tmpbuff7, inptr7, remainx); | |||||
inptr0 = tmpbuff0; | inptr0 = tmpbuff0; | ||||
inptr1 = tmpbuff1; | inptr1 = tmpbuff1; | ||||
inptr2 = tmpbuff2; | inptr2 = tmpbuff2; | ||||
@@ -801,8 +817,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
inptr7 = tmpbuff7; | inptr7 = tmpbuff7; | ||||
outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr_interleave); | |||||
interleave_8x4_1_b_with_shift( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave); | |||||
outptr_inner += ksize8; | outptr_inner += ksize8; | ||||
} | } | ||||
outptr += 64; | outptr += 64; | ||||
@@ -847,8 +864,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
break; | break; | ||||
} | } | ||||
outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr_interleave); | |||||
interleave_8x4_1_b_with_shift( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave); | |||||
outptr_inner += ksize8; | outptr_inner += ksize8; | ||||
} | } | ||||
if (x < xmax) { | if (x < xmax) { | ||||
@@ -880,14 +898,14 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
} | } | ||||
int remainx = xmax - x; | int remainx = xmax - x; | ||||
outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
std::memcpy(tmpbuff0,inptr0,remainx); | |||||
std::memcpy(tmpbuff1,inptr1,remainx); | |||||
std::memcpy(tmpbuff2,inptr2,remainx); | |||||
std::memcpy(tmpbuff3,inptr3,remainx); | |||||
std::memcpy(tmpbuff4,inptr4,remainx); | |||||
std::memcpy(tmpbuff5,inptr5,remainx); | |||||
std::memcpy(tmpbuff6,inptr6,remainx); | |||||
std::memcpy(tmpbuff7,inptr7,remainx); | |||||
std::memcpy(tmpbuff0, inptr0, remainx); | |||||
std::memcpy(tmpbuff1, inptr1, remainx); | |||||
std::memcpy(tmpbuff2, inptr2, remainx); | |||||
std::memcpy(tmpbuff3, inptr3, remainx); | |||||
std::memcpy(tmpbuff4, inptr4, remainx); | |||||
std::memcpy(tmpbuff5, inptr5, remainx); | |||||
std::memcpy(tmpbuff6, inptr6, remainx); | |||||
std::memcpy(tmpbuff7, inptr7, remainx); | |||||
inptr0 = tmpbuff0; | inptr0 = tmpbuff0; | ||||
inptr1 = tmpbuff1; | inptr1 = tmpbuff1; | ||||
inptr2 = tmpbuff2; | inptr2 = tmpbuff2; | ||||
@@ -898,16 +916,16 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
inptr7 = tmpbuff7; | inptr7 = tmpbuff7; | ||||
outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr_interleave); | |||||
interleave_8x4_1_b_with_shift( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave); | |||||
outptr_inner += ksize8; | outptr_inner += ksize8; | ||||
} | } | ||||
} | } | ||||
} | } | ||||
} // namespace matmul_4x4x16 | |||||
} // namespace matmul_s4_4x4x16 | |||||
} // namespace aarch64 | } // namespace aarch64 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -10,9 +10,9 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" | #include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" | ||||
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
@@ -23,39 +23,38 @@ using namespace aarch64::matmul; | |||||
// ===========================gemm_s4x4x16_s4_8x8x8================================== | // ===========================gemm_s4x4x16_s4_8x8x8================================== | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); | ||||
void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s4x4x16_s4_8x8x8::pack_A( | |||||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( | |||||
out, in, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( | |||||
out, in, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s4x4x16_s4_8x8x8::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( | |||||
out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( | |||||
out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int16* C, | |||||
size_t LDC, bool is_first_k, const dt_int16*, | |||||
dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_s4x4x16_s4_8x8x8::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
@@ -72,16 +71,17 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||||
matmul_s4_4x4x16::s4_kern_8x8( | |||||
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||||
B_INTERLEAVE); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||||
is_first_k, A_INTERLEAVE, | |||||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
matmul_s4_4x4x16::s4_kern_8x8_remain( | |||||
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
@@ -94,10 +94,10 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||||
is_first_k, | |||||
std::min<size_t>(M - m, A_INTERLEAVE), | |||||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
matmul_s4_4x4x16::s4_kern_8x8_remain( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, A_INTERLEAVE), | |||||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
@@ -105,5 +105,4 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
} | } | ||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -17,8 +17,8 @@ namespace megdnn { | |||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace matmul { | namespace matmul { | ||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, | |||||
gemm_s4x4x16_s4_8x8x8); | |||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s4x4x16_s4_8x8x8); | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -51,8 +51,9 @@ namespace matmul_4x4x16 { | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
static void kern_4x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k) { | |||||
K /= 16; | K /= 16; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -472,9 +473,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
); | ); | ||||
} | } | ||||
static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, | |||||
int m_remain, int n_remain) { | |||||
static void kern_4x4_remain( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
megdnn_assert(K > 0); | megdnn_assert(K > 0); | ||||
K /= 16; | K /= 16; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
@@ -655,16 +656,14 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
STORE_C | STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[output] "+r"(output), [m_remain] "+r"(m_remain), | |||||
[n_remain] "+r"(n_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output), | |||||
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
"x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -672,8 +671,9 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, int kmax) { | |||||
static void gemm_s8_4x4_pack_A_n( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
@@ -716,9 +716,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -734,9 +736,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -749,8 +753,8 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_s8_4x4_pack_B_n( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
@@ -777,19 +781,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (remain >= 0) { | if (remain >= 0) { | ||||
switch (remain) { | switch (remain) { | ||||
case 7: | case 7: | ||||
inptr0 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr0 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -798,9 +809,9 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
} | } | ||||
transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||||
inptr4, inptr5, inptr6, inptr7, | |||||
outptr_inner); | |||||
transpose_4x16_1_b_helper( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_inner); | |||||
outptr_inner += ksize4; | outptr_inner += ksize4; | ||||
} | } | ||||
@@ -808,19 +819,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (remain >= 0) { | if (remain >= 0) { | ||||
switch (remain) { | switch (remain) { | ||||
case 7: | case 7: | ||||
inptr0 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr0 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -42,8 +42,9 @@ namespace matmul_8x8x8 { | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
static void kern_8x8( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k) { | |||||
K /= 8; | K /= 8; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -272,14 +273,13 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"stp q18, q19, [x5]\n" | "stp q18, q19, [x5]\n" | ||||
"stp q20, q21, [x6]\n" | "stp q20, q21, [x6]\n" | ||||
"stp q22, q23, [x7]\n" | "stp q22, q23, [x7]\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[output] "+r"(output) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v26", "v27", "x1", | |||||
"x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v26", "v27", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
"cc", "memory"); | |||||
} | } | ||||
/** | /** | ||||
@@ -309,9 +309,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, | |||||
size_t n_remain) { | |||||
static void kern_8x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, size_t n_remain) { | |||||
K /= 8; | K /= 8; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -520,16 +520,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||||
[n_remain] "+r"(n_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -559,9 +557,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, | |||||
size_t m_remain) { | |||||
static void kern_4x8( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, size_t m_remain) { | |||||
K /= 8; | K /= 8; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -724,14 +722,13 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||||
[outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -762,9 +759,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, size_t m_remain, | |||||
size_t n_remain) { | |||||
static void kern_4x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, size_t m_remain, size_t n_remain) { | |||||
K /= 8; | K /= 8; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -922,11 +919,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||||
[x1] "+r"(x1), [m_remain] "+r"(m_remain), | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[x0] "+r"(x0), [x1] "+r"(x1), [m_remain] "+r"(m_remain), | |||||
[n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", | : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", | ||||
@@ -938,8 +934,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax) { | |||||
static void gemm_s8_8x8_pack_A_n( | |||||
int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
@@ -965,13 +962,15 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
interleave_8x8_2_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 8, K); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 8, K); | |||||
} | } | ||||
} | } | ||||
@@ -991,9 +990,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -1009,9 +1010,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -1024,9 +1027,8 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
int ldin, int x0, int xmax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8_8x8_transpose_pack_A_n( | |||||
int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
@@ -1063,17 +1065,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1081,8 +1089,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
megdnn_assert(0); | megdnn_assert(0); | ||||
} | } | ||||
} | } | ||||
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
transpose_8x8_1_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
outptr += ksize8; | outptr += ksize8; | ||||
} | } | ||||
@@ -1091,17 +1100,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1110,8 +1125,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
} | } | ||||
} | } | ||||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 4, 4); | |||||
transpose_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 4, 4); | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
@@ -1119,17 +1135,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1138,8 +1160,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
} | } | ||||
} | } | ||||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 4, xmax - x); | |||||
transpose_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 4, xmax - x); | |||||
} | } | ||||
outptr_base += 8 * 8; | outptr_base += 8 * 8; | ||||
@@ -1147,8 +1170,8 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_s8_8x8_pack_B_n( | |||||
int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
@@ -1186,17 +1209,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1205,8 +1234,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
} | } | ||||
} | } | ||||
outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr_interleave); | |||||
interleave_8x8_1_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave); | |||||
outptr += ksize8; | outptr += ksize8; | ||||
} | } | ||||
@@ -1215,17 +1245,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1235,8 +1271,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
} | } | ||||
outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr_interleave, 4, 4); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave, 4, 4); | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
@@ -1244,17 +1281,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1264,8 +1307,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
} | } | ||||
outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr_interleave, 4, xmax - x); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave, 4, xmax - x); | |||||
} | } | ||||
outptr_base += 8 * 8; | outptr_base += 8 * 8; | ||||
@@ -1273,9 +1317,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8_8x8_transpose_pack_B_n( | |||||
int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
constexpr int interleave4 = 32; | constexpr int interleave4 = 32; | ||||
@@ -1303,14 +1347,16 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 7; K -= 8) { | for (; K > 7; K -= 8) { | ||||
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
transpose_8x8_1_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
outptr += interleave8; | outptr += interleave8; | ||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 8, K); | |||||
transpose_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 8, K); | |||||
outptr += interleave8; | outptr += interleave8; | ||||
} | } | ||||
} | } | ||||
@@ -1331,9 +1377,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -1350,9 +1398,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -50,8 +50,9 @@ namespace matmul_mk4_4x4x16 { | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, bool is_first_k) { | |||||
static void kern_4x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, | |||||
bool is_first_k) { | |||||
K = div_ceil(K, 16); | K = div_ceil(K, 16); | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -366,17 +367,18 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
"6:\n" | "6:\n" | ||||
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" | "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[k] "+r"(K), [output] "+r"(output) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
"cc", "memory"); | |||||
} | } | ||||
static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, bool is_first_k, size_t remain_n) { | |||||
static void kern_4x4_remain( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, | |||||
bool is_first_k, size_t remain_n) { | |||||
K = div_ceil(K, 16); | K = div_ceil(K, 16); | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -718,26 +720,27 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
"7:\n" | "7:\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k), | |||||
[k] "+r"(K), [output] "+r"(output) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [remain_n] "+r"(remain_n), | |||||
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
"cc", "memory"); | |||||
} | } | ||||
static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_mk4_s8_4x4_pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
//! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} | //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} | ||||
int8_t zerobuff[4][64]; | int8_t zerobuff[4][64]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); | std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); | ||||
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, | |||||
"mk4 matmul with m is not times of 4"); | |||||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, | |||||
"mk4 matmul with k is not times of 4"); | |||||
megdnn_assert( | |||||
ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, | |||||
"mk4 matmul with m is not times of 4"); | |||||
megdnn_assert( | |||||
kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, | |||||
"mk4 matmul with k is not times of 4"); | |||||
size_t roundk = round_up(kmax - k0, 16); | size_t roundk = round_up(kmax - k0, 16); | ||||
size_t out_offset = roundk * 4; | size_t out_offset = roundk * 4; | ||||
int y = y0; | int y = y0; | ||||
@@ -754,8 +757,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
prefetch_2x(inptr3); | prefetch_2x(inptr3); | ||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||||
out_offset); | |||||
transpose_interleave_4x4_4_b( | |||||
inptr0, inptr1, inptr2, inptr3, output, out_offset); | |||||
output += 64; | output += 64; | ||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
@@ -767,8 +770,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
inptr1 = zerobuff[1]; | inptr1 = zerobuff[1]; | ||||
inptr2 = zerobuff[2]; | inptr2 = zerobuff[2]; | ||||
inptr3 = zerobuff[3]; | inptr3 = zerobuff[3]; | ||||
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||||
out_offset); | |||||
transpose_interleave_4x4_4_b( | |||||
inptr0, inptr1, inptr2, inptr3, output, out_offset); | |||||
output += 64; | output += 64; | ||||
} | } | ||||
} | } | ||||
@@ -790,21 +793,21 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
} | } | ||||
} | } | ||||
static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_mk4_s8_4x4_pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int32_t zerobuff[4]; | int32_t zerobuff[4]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
const int ICB = (ksize) / 4; | const int ICB = (ksize) / 4; | ||||
const int ksize4 = round_up<int>(ICB, 4) * 4; | const int ksize4 = round_up<int>(ICB, 4) * 4; | ||||
int32_t* outptr = reinterpret_cast<int32_t*>(out); | int32_t* outptr = reinterpret_cast<int32_t*>(out); | ||||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, | |||||
"mk4 matmul with k is not times of 4"); | |||||
megdnn_assert( | |||||
kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, | |||||
"mk4 matmul with k is not times of 4"); | |||||
int k = k0 / 4; | int k = k0 / 4; | ||||
for (; k + 3 < ICB; k += 4) { | for (; k + 3 < ICB; k += 4) { | ||||
const int32_t* inptr0 = | |||||
reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||||
const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||||
const int32_t* inptr1 = | const int32_t* inptr1 = | ||||
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | ||||
const int32_t* inptr2 = | const int32_t* inptr2 = | ||||
@@ -829,8 +832,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
outptr += 4 * 4; | outptr += 4 * 4; | ||||
} | } | ||||
if (k < ICB) { | if (k < ICB) { | ||||
const int32_t* inptr0 = | |||||
reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||||
const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||||
const int32_t* inptr1 = | const int32_t* inptr1 = | ||||
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | ||||
const int32_t* inptr2 = | const int32_t* inptr2 = | ||||
@@ -844,9 +846,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (k + 3 >= ICB) { | if (k + 3 >= ICB) { | ||||
switch (k + 3 - ICB) { | switch (k + 3 - ICB) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -861,9 +865,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (k + 3 >= ICB) { | if (k + 3 >= ICB) { | ||||
switch (k + 3 - ICB) { | switch (k + 3 - ICB) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -882,7 +888,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
} | } | ||||
} // namespace matmul_4x4x16 | |||||
} // namespace matmul_mk4_4x4x16 | |||||
} // namespace aarch64 | } // namespace aarch64 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -24,20 +24,19 @@ using namespace aarch64::matmul; | |||||
///////////////////////// gemm_s8_4x4 //////////////////////////////////// | ///////////////////////// gemm_s8_4x4 //////////////////////////////////// | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); | ||||
void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s8_4x4::pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, bool transpose) const { | |||||
void gemm_s8_4x4::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | ||||
} else { | } else { | ||||
@@ -45,16 +44,16 @@ void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
} | } | ||||
} | } | ||||
void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_s8_4x4::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
@@ -72,16 +71,15 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC, | |||||
is_first_k, 4, | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_4x4x16::kern_4x4_remain( | |||||
packA, cur_packB, K, output, LDC, is_first_k, 4, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -107,33 +105,32 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | ///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | ||||
void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
megdnn_assert(!transpose, | |||||
"the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
void gemm_mk4_s8_4x4::pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
megdnn_assert( | |||||
!transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, bool transpose) const { | |||||
megdnn_assert(!transpose, | |||||
"the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
void gemm_mk4_s8_4x4::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
megdnn_assert( | |||||
!transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_mk4_s8_4x4::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
@@ -151,57 +148,54 @@ void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, | |||||
is_first_k); | |||||
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, is_first_k); | |||||
output += B_INTERLEAVE * 4; | output += B_INTERLEAVE * 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
if (n < N) { | if (n < N) { | ||||
matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, | |||||
is_first_k, N - n); | |||||
matmul_mk4_4x4x16::kern_4x4_remain( | |||||
packA, cur_packB, K, output, is_first_k, N - n); | |||||
} | } | ||||
packA += K4; | packA += K4; | ||||
} | } | ||||
} | } | ||||
///////////////////////// gemm_s8_8x8 //////////////////////////////////// | ///////////////////////// gemm_s8_8x8 //////////////////////////////////// | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); | ||||
void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s8_8x8::pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||||
ymax, k0, kmax); | |||||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n( | |||||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, bool transpose) const { | |||||
void gemm_s8_8x8::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||||
k0, kmax); | |||||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | ||||
} | } | ||||
} | } | ||||
void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_s8_8x8::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
@@ -220,15 +214,15 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_8x8x8::kern_8x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -240,16 +234,17 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
matmul_8x8x8::kern_4x8( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_8x8x8::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -16,14 +16,14 @@ namespace megdnn { | |||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace matmul { | namespace matmul { | ||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, | |||||
gemm_s8_4x4); | |||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4); | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, | |||||
gemm_mk4_s8_4x4); | |||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, gemm_mk4_s8_4x4); | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||||
gemm_s8_8x8); | |||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, gemm_s8_8x8); | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -52,8 +52,9 @@ namespace matmul_8x12x4 { | |||||
#if 1 | #if 1 | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k) { | |||||
K /= 4; | K /= 4; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -410,8 +411,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
} | } | ||||
#else | #else | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k) { | |||||
K /= 4; | K /= 4; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -612,18 +614,17 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
"stp q15, q23, [%[outptr7]]\n" | "stp q15, q23, [%[outptr7]]\n" | ||||
"str q31, [%[outptr7], #32]\n" | "str q31, [%[outptr7], #32]\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), | |||||
[a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), | |||||
[b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), | |||||
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1), | |||||
[a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1), | |||||
[b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk), | |||||
[is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
[outptr7] "=r"(outptr7) | |||||
: | : | ||||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||||
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "cc", "memory"); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -653,8 +654,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
// | // | ||||
// Accumulator | // Accumulator | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, int m_remain) { | |||||
static void kern_4x12( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, int m_remain) { | |||||
K /= 4; | K /= 4; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -796,15 +798,15 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
"4:\n" STORE_C | "4:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | ||||
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), | |||||
[is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), | |||||
[LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), | |||||
[b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), | |||||
[b2a] "=w"(b2a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||||
[outptr3] "=r"(outptr3), [x0] "=r"(x0) | |||||
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||||
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [a0] "=w"(a0), | |||||
[a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), | |||||
[b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[x0] "=r"(x0) | |||||
: | : | ||||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v19", "memory", "cc"); | |||||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "memory", "cc"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -840,8 +842,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
// | // | ||||
// Accumulator | // Accumulator | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||||
static void kern_8x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
K /= 4; | K /= 4; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -1004,12 +1007,11 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | ||||
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | ||||
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | ||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "=r"(x0) | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), | |||||
[outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), | |||||
[x0] "=r"(x0) | |||||
: | : | ||||
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", | |||||
"cc"); | |||||
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -1041,9 +1043,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
// | // | ||||
// Accumulator | // Accumulator | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_4x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
K /= 4; | K /= 4; | ||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | ||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | ||||
@@ -1172,10 +1174,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
"4:\n" STORE_C | "4:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | ||||
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | ||||
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), | |||||
[a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||||
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), | |||||
[b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||||
[outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) | [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) | ||||
: | : | ||||
: "v4", "v5", "v6", "v7", "memory", "cc"); | : "v4", "v5", "v6", "v7", "memory", "cc"); | ||||
@@ -1186,9 +1187,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8_8x12_pack_A_n( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
@@ -1215,13 +1216,15 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
//! read 8 * 4 in each row | //! read 8 * 4 in each row | ||||
for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
interleave_8x4_4_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 4, K); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 4, K); | |||||
} | } | ||||
} | } | ||||
for (; y < ymax; y += 4) { | for (; y < ymax; y += 4) { | ||||
@@ -1274,8 +1277,8 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_s8_8x12_pack_A_t( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
@@ -1361,8 +1364,8 @@ static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_s8_8x12_pack_B_n( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
@@ -1448,9 +1451,9 @@ static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8_8x12_pack_B_t( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
@@ -1485,15 +1488,15 @@ static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
//! read 12 * 4 in each row | //! read 12 * 4 in each row | ||||
for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
interleave_12x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, inptr8, inptr9, inptr10, | |||||
inptr11, outptr); | |||||
interleave_12x4_4_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
inptr8, inptr9, inptr10, inptr11, outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, | |||||
outptr, 4, K); | |||||
interleave_12( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
inptr8, inptr9, inptr10, inptr11, outptr, 4, K); | |||||
} | } | ||||
} | } | ||||
for (; y < ymax; y += 4) { | for (; y < ymax; y += 4) { | ||||
@@ -40,8 +40,9 @@ namespace matmul_mk4_8x12x4 { | |||||
// Accumulator | // Accumulator | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
static void kern_8x12( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k) { | |||||
K /= 4; | K /= 4; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -397,8 +398,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
// Accumulator | // Accumulator | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k) { | |||||
static void kern_4x12( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k) { | |||||
K /= 4; | K /= 4; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -514,13 +516,12 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
"stp q16, q17, [%[outptr0], #128]\n" | "stp q16, q17, [%[outptr0], #128]\n" | ||||
"stp q18, q19, [%[outptr0], #160]\n" | "stp q18, q19, [%[outptr0], #160]\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | ||||
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), | |||||
[is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a), | |||||
[b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), | |||||
[b1a] "=w"(b1a), [b2a] "=w"(b2a) | |||||
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||||
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), | |||||
[b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a) | |||||
: | : | ||||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
"v17", "v18", "v19", "memory", "cc"); | |||||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
"v19", "memory", "cc"); | |||||
} | } | ||||
// Overview of register layout: | // Overview of register layout: | ||||
@@ -544,8 +545,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
// Accumulator | // Accumulator | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||||
static void kern_8x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
K /= 4; | K /= 4; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -689,11 +691,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | ||||
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | ||||
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | ||||
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | |||||
[x0] "=r"(x0) | |||||
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [x0] "=r"(x0) | |||||
: | : | ||||
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", | |||||
"cc"); | |||||
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -720,8 +720,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
// Accumulator | // Accumulator | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||||
static void kern_4x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
bool is_first_k, int n_remain) { | |||||
K /= 4; | K /= 4; | ||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | ||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | ||||
@@ -834,10 +835,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
"4:\n" STORE_C | "4:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | ||||
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | |||||
[LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k), | |||||
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), | |||||
[x0] "=r"(x0) | |||||
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), | |||||
[b0] "=w"(b0), [b0a] "=w"(b0a), [x0] "=r"(x0) | |||||
: | : | ||||
: "v4", "v5", "v6", "v7", "memory", "cc"); | : "v4", "v5", "v6", "v7", "memory", "cc"); | ||||
@@ -847,13 +847,11 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, | |||||
"mk4 matmul with m is not times of 4"); | |||||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, | |||||
"mk4 matmul with k is not times of 4"); | |||||
static void gemm_mk4_s8_8x12_pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, "mk4 matmul with m is not times of 4"); | |||||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 matmul with k is not times of 4"); | |||||
int y = y0; | int y = y0; | ||||
int start_y = y0 / 4; | int start_y = y0 / 4; | ||||
for (; y + 7 < ymax; y += 8, start_y += 2) { | for (; y + 7 < ymax; y += 8, start_y += 2) { | ||||
@@ -869,15 +867,15 @@ static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
interleave_2x4_4_b(inptr0, inptr1, outptr); | interleave_2x4_4_b(inptr0, inptr1, outptr); | ||||
} | } | ||||
} | } | ||||
for (; y + 3 < ymax; y += 4, start_y ++) { | |||||
for (; y + 3 < ymax; y += 4, start_y++) { | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); | const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); | ||||
std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); | std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); | ||||
} | } | ||||
} | } | ||||
static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_mk4_s8_8x12_pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
const int ksize12 = ksize * 12; | const int ksize12 = ksize * 12; | ||||
const int ksize4 = ksize * 4; | const int ksize4 = ksize * 4; | ||||
@@ -12,10 +12,10 @@ | |||||
#include "src/aarch64/matrix_mul/int8_dot/strategy.h" | #include "src/aarch64/matrix_mul/int8_dot/strategy.h" | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
#include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | ||||
#include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" | #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" | ||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace aarch64; | using namespace aarch64; | ||||
@@ -24,20 +24,19 @@ using namespace aarch64::matmul; | |||||
/* ====================== gemm_s8_8x12 ===========================*/ | /* ====================== gemm_s8_8x12 ===========================*/ | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); | ||||
void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s8_8x12::pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, bool transpose) const { | |||||
void gemm_s8_8x12::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | ||||
} else { | } else { | ||||
@@ -45,16 +44,16 @@ void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
} | } | ||||
} | } | ||||
void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_s8_8x12::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
@@ -75,15 +74,15 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(N - n, 4)); | |||||
matmul_8x12x4::kern_8x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -95,16 +94,17 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(M - m, 4)); | |||||
matmul_8x12x4::kern_4x12( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_8x12x4::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -115,32 +115,32 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
/* ====================== gemm_mk4_s8_8x12 ===========================*/ | /* ====================== gemm_mk4_s8_8x12 ===========================*/ | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); | ||||
void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported"); | |||||
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||||
kmax); | |||||
void gemm_mk4_s8_8x12::pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool transpose) const { | |||||
megdnn_assert( | |||||
!transpose, "matrix mul mk4 with transposed matrix A is not supported"); | |||||
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported"); | |||||
void gemm_mk4_s8_8x12::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
megdnn_assert( | |||||
!transpose, "matrix mul mk4 with transposed matrix B is not supported"); | |||||
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | ||||
} | } | ||||
void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int32* C, | |||||
size_t LDC, bool is_first_k, const dt_int32*, | |||||
dt_int32*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_mk4_s8_8x12::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && | |||||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int32) || | |||||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
@@ -161,15 +161,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += (B_INTERLEAVE << 2); | output += (B_INTERLEAVE << 2); | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(N - n, 4)); | |||||
matmul_mk4_8x12x4::kern_8x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 16; | output += 16; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -181,15 +181,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += (B_INTERLEAVE << 2); | output += (B_INTERLEAVE << 2); | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(N - n, 4)); | |||||
matmul_mk4_8x12x4::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 16; | output += 16; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -16,14 +16,14 @@ namespace megdnn { | |||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace matmul { | namespace matmul { | ||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||||
gemm_s8_8x12); | |||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12); | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||||
gemm_mk4_s8_8x12); | |||||
MEGDNN_REG_GEMM_STRATEGY( | |||||
dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_mk4_s8_8x12); | |||||
} // namespace aarch64 | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace aarch64 | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -34,9 +34,9 @@ namespace matmul_4x4x16 { | |||||
* | * | ||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_4x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
K /= 16; | K /= 16; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -230,16 +230,14 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
// Store back into memory | // Store back into memory | ||||
STORE_C | STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain), | |||||
[n_remain] "+r"(n_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", | |||||
"v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
"v31"); | |||||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", "v3", | |||||
"v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||||
"v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -247,9 +245,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8x8x16_4x4_pack_A_n( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
@@ -292,9 +290,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -309,9 +309,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -324,8 +326,8 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_s8x8x16_4x4_pack_B_n( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
@@ -362,19 +364,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (remain >= 0) { | if (remain >= 0) { | ||||
switch (remain) { | switch (remain) { | ||||
case 7: | case 7: | ||||
inptr0 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr0 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -383,9 +392,9 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
} | } | ||||
transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||||
inptr4, inptr5, inptr6, inptr7, | |||||
outptr_inner); | |||||
transpose_4x16_1_b_helper( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_inner); | |||||
outptr_inner += ksize4; | outptr_inner += ksize4; | ||||
} | } | ||||
@@ -393,19 +402,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (remain >= 0) { | if (remain >= 0) { | ||||
switch (remain) { | switch (remain) { | ||||
case 7: | case 7: | ||||
inptr0 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr0 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -42,8 +42,9 @@ namespace matmul_8x8x8 { | |||||
* | * | ||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k) { | |||||
static void kern_8x8( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k) { | |||||
K /= 8; | K /= 8; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -217,13 +218,12 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"bne 2b\n" | "bne 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[outptr] "+r"(outptr) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", | |||||
"x4", "x5", "x6", "x7", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", "x4", "x5", | |||||
"x6", "x7", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -258,9 +258,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, | |||||
size_t n_remain) { | |||||
static void kern_8x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, size_t n_remain) { | |||||
K /= 8; | K /= 8; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -471,16 +471,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||||
[n_remain] "+r"(n_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -514,9 +512,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, | |||||
size_t m_remain) { | |||||
static void kern_4x8( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, size_t m_remain) { | |||||
K /= 8; | K /= 8; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -646,11 +644,10 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||||
[m_remain] "+r"(m_remain) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
[x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", | : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", | ||||
"memory"); | "memory"); | ||||
@@ -686,9 +683,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
* | * | ||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, size_t m_remain, | |||||
size_t n_remain) { | |||||
static void kern_4x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, size_t m_remain, size_t n_remain) { | |||||
K /= 8; | K /= 8; | ||||
const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -853,11 +850,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
"3:\n" STORE_C | "3:\n" STORE_C | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), | ||||
[K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), | [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), | ||||
[x0] "+r"(x0), [m_remain] "+r"(m_remain), | |||||
[n_remain] "+r"(n_remain) | |||||
[x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", | |||||
"cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", "cc", | |||||
"memory"); | |||||
#undef LOAD_LINE | #undef LOAD_LINE | ||||
#undef LOAD_C | #undef LOAD_C | ||||
@@ -865,9 +861,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8x8x16_8x8_pack_A_n( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
@@ -893,13 +889,15 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
interleave_8x8_2_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 8, K); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 8, K); | |||||
} | } | ||||
} | } | ||||
for (; y < ymax; y += 4) { | for (; y < ymax; y += 4) { | ||||
@@ -918,9 +916,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -936,9 +936,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -951,9 +953,8 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
int ldin, int x0, int xmax, | |||||
int k0, int kmax) { | |||||
static void gemm_s8x8x16_8x8_transpose_pack_A_n( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
@@ -991,17 +992,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1009,8 +1016,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
megdnn_assert(0); | megdnn_assert(0); | ||||
} | } | ||||
} | } | ||||
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
transpose_8x8_1_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
outptr += ksize8; | outptr += ksize8; | ||||
} | } | ||||
@@ -1019,17 +1027,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1038,8 +1052,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
} | } | ||||
} | } | ||||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 4, 4); | |||||
transpose_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 4, 4); | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
@@ -1047,17 +1062,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1066,8 +1087,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
} | } | ||||
} | } | ||||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 4, xmax - x); | |||||
transpose_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 4, xmax - x); | |||||
} | } | ||||
outptr_base += 8 * 8; | outptr_base += 8 * 8; | ||||
@@ -1075,8 +1097,8 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | |||||
static void gemm_s8x8x16_8x8_pack_B_n( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
@@ -1113,17 +1135,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1132,8 +1160,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
} | } | ||||
outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr_interleave); | |||||
interleave_8x8_1_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave); | |||||
outptr += ksize8; | outptr += ksize8; | ||||
} | } | ||||
@@ -1142,17 +1171,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1162,8 +1197,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr_interleave, 4, 4); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave, 4, 4); | |||||
outptr += ksize4; | outptr += ksize4; | ||||
} | } | ||||
@@ -1171,17 +1207,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
case 6: | case 6: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 5: | case 5: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 4: | case 4: | ||||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr3 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 3: | case 3: | ||||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr4 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 2: | case 2: | ||||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr5 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr6 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
break; | break; | ||||
@@ -1191,8 +1233,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr_interleave, 4, xmax - x); | |||||
interleave_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr_interleave, 4, xmax - x); | |||||
} | } | ||||
outptr_base += 8 * 8; | outptr_base += 8 * 8; | ||||
@@ -1200,10 +1243,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||||
const dt_int8* inptr, int ldin, | |||||
int y0, int ymax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8x8x16_8x8_transpose_pack_B_n( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
int kmax) { | |||||
int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
constexpr int interleave4 = 32; | constexpr int interleave4 = 32; | ||||
@@ -1231,14 +1273,16 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||||
int K = kmax - k0; | int K = kmax - k0; | ||||
for (; K > 7; K -= 8) { | for (; K > 7; K -= 8) { | ||||
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
inptr6, inptr7, outptr); | |||||
transpose_8x8_1_b( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr); | |||||
outptr += interleave8; | outptr += interleave8; | ||||
} | } | ||||
if (K > 0) { | if (K > 0) { | ||||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
inptr7, outptr, 8, K); | |||||
transpose_8( | |||||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
outptr, 8, K); | |||||
outptr += interleave8; | outptr += interleave8; | ||||
} | } | ||||
} | } | ||||
@@ -1259,9 +1303,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -1278,9 +1324,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||||
if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
case 2: | case 2: | ||||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr1 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 1: | case 1: | ||||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
inptr2 = zerobuff; | |||||
MEGDNN_FALLTHRU | |||||
case 0: | case 0: | ||||
inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
break; | break; | ||||
@@ -40,11 +40,9 @@ namespace matmul_mk4_16x12x4_a53 { | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
// clang-format on | // clang-format on | ||||
static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||||
const int8_t* packB, int K, | |||||
int16_t* output, int LDC, | |||||
bool is_first_k, | |||||
int remain_n) { | |||||
static __attribute__((noinline)) void kern_16x12( | |||||
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int remain_n) { | |||||
K /= 4; | K /= 4; | ||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -521,15 +519,15 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
"101:\n" | "101:\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||||
[remain_n] "+r"(remain_n) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
"x8", "x9", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||||
"memory"); | |||||
#undef STORE_C | #undef STORE_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -554,10 +552,9 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
// clang-format on | // clang-format on | ||||
static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||||
const int8_t* packB, int K, | |||||
int16_t* output, int LDC, | |||||
bool is_first_k, int remain_n) { | |||||
static __attribute__((noinline)) void kern_8x12( | |||||
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int remain_n) { | |||||
K /= 4; | K /= 4; | ||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -858,14 +855,13 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
"101:\n" | "101:\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||||
[remain_n] "+r"(remain_n) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", | |||||
"x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); | |||||
#undef STORE_C | #undef STORE_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -890,10 +886,9 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
// clang-format on | // clang-format on | ||||
static __attribute__((noinline)) void kern_4x12(const int16_t* packA, | |||||
const int8_t* packB, int K, | |||||
int16_t* output, int LDC, | |||||
bool is_first_k, int remain_n) { | |||||
static __attribute__((noinline)) void kern_4x12( | |||||
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int remain_n) { | |||||
K /= 4; | K /= 4; | ||||
const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
@@ -1162,22 +1157,21 @@ static __attribute__((noinline)) void kern_4x12(const int16_t* packA, | |||||
"6:\n" STORE_C | "6:\n" STORE_C | ||||
"101:\n" | "101:\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||||
[remain_n] "+r"(remain_n) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||||
"memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", | |||||
"x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); | |||||
#undef STORE_C | #undef STORE_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
} | } | ||||
static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, | |||||
const dt_int8* inptr, int ldin, | |||||
int m0, int mmax, int k0, int kmax) { | |||||
static void gemm_s8x8x16_mk4_16x12_pack_A( | |||||
dt_int16* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, | |||||
int kmax) { | |||||
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | ||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
constexpr int pack_m = 16; | constexpr int pack_m = 16; | ||||
@@ -1224,9 +1218,8 @@ static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8x8x16_mk4_16x12_pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int n0, int nmax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8x8x16_mk4_16x12_pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
constexpr int pack_n = 12; | constexpr int pack_n = 12; | ||||
@@ -43,8 +43,9 @@ namespace matmul_mk4_4x4x8_a72 { | |||||
*/ | */ | ||||
// clang-format on | // clang-format on | ||||
static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool, int remain_n) { | |||||
static inline void kern_4x4( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, bool, | |||||
int remain_n) { | |||||
K = div_ceil(K, 8); | K = div_ceil(K, 8); | ||||
int oddk = (K & 1); | int oddk = (K & 1); | ||||
K = ((K + 1) / 2) - 1; | K = ((K + 1) / 2) - 1; | ||||
@@ -261,15 +262,14 @@ static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
"7:\n" STORE_C | "7:\n" STORE_C | ||||
"101:\n" | "101:\n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
[remain_n] "+r"(remain_n) | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [oddk] "+r"(oddk), | |||||
[LDC] "+r"(LDC), [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
"x8", "x9", "x10", "cc", "memory"); | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||||
"memory"); | |||||
#undef STORE_C | #undef STORE_C | ||||
#undef STORE_LINE | #undef STORE_LINE | ||||
@@ -282,26 +282,23 @@ static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) { | |||||
vst1_s8(outptr + 3 * 8, in0.val[3]); | vst1_s8(outptr + 3 * 8, in0.val[3]); | ||||
} | } | ||||
static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2, | |||||
dt_int8* outptr) { | |||||
static inline void interleve_8x4_b( | |||||
const dt_int8* inptr, const dt_int8* inptr2, dt_int8* outptr) { | |||||
int8x16_t in0 = vld1q_s8(inptr); | int8x16_t in0 = vld1q_s8(inptr); | ||||
int8x16_t in1 = vld1q_s8(inptr2); | int8x16_t in1 = vld1q_s8(inptr2); | ||||
int32x4x2_t in_x2 = { | |||||
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | ||||
} | } | ||||
static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { | static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { | ||||
int8x16_t in0 = vld1q_s8(inptr); | int8x16_t in0 = vld1q_s8(inptr); | ||||
int8x16_t in1 = vdupq_n_s8(0); | int8x16_t in1 = vdupq_n_s8(0); | ||||
int32x4x2_t in_x2 = { | |||||
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | ||||
} | } | ||||
static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||||
int ldin, int m0, int mmax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8x8x16_mk4_4x4x8_pack_A( | |||||
dt_int8* out, const dt_int8* in, int ldin, int m0, int mmax, int k0, int kmax) { | |||||
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | ||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
constexpr int pack_m = 4; | constexpr int pack_m = 4; | ||||
@@ -330,9 +327,8 @@ static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||||
} | } | ||||
} | } | ||||
static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int n0, int nmax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8x8x16_mk4_4x4x8_pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
constexpr int pack_n = 4; | constexpr int pack_n = 4; | ||||
@@ -18,7 +18,6 @@ namespace megdnn { | |||||
namespace aarch64 { | namespace aarch64 { | ||||
namespace matmul_mk4_8x8x8 { | namespace matmul_mk4_8x8x8 { | ||||
/** | /** | ||||
* Overview of register layout: | * Overview of register layout: | ||||
* | * | ||||
@@ -39,18 +38,18 @@ namespace matmul_mk4_8x8x8 { | |||||
* | v16 | | v28 | | * | v16 | | v28 | | ||||
* | v17 | | v29 | | * | v17 | | v29 | | ||||
* | v16 | | v30 | | * | v16 | | v30 | | ||||
* | v17 | | v31 | | |||||
* | v17 | | v31 | | |||||
* +--------+ - - - - +---------------------------------+ | * +--------+ - - - - +---------------------------------+ | ||||
* | * | ||||
* Accumulator | * Accumulator | ||||
*/ | */ | ||||
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_8x8( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
K /= 8; | K /= 8; | ||||
LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
const int8_t* a_ptr = packB;//packA; | |||||
const int8_t* b_ptr = packA;//packB; | |||||
const int8_t* a_ptr = packB; // packA; | |||||
const int8_t* b_ptr = packA; // packB; | |||||
// clang-format off | // clang-format off | ||||
#define LOAD_C_8 \ | #define LOAD_C_8 \ | ||||
"ld1 {v0.8h}, [x0], #16\n" \ | "ld1 {v0.8h}, [x0], #16\n" \ | ||||
@@ -291,17 +290,17 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | ||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | ||||
"v29", "v30", "v31"); | "v29", "v30", "v31"); | ||||
// clang-format on | |||||
// clang-format on | |||||
} | } | ||||
static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_8x8_remain( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
K /= 8; | K /= 8; | ||||
LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
const int8_t* a_ptr = packB; | const int8_t* a_ptr = packB; | ||||
const int8_t* b_ptr = packA; | const int8_t* b_ptr = packA; | ||||
// clang-format off | |||||
// clang-format off | |||||
register int16_t* outptr asm("x0") = output; | register int16_t* outptr asm("x0") = output; | ||||
asm volatile( | asm volatile( | ||||
"add x1, x0, %x[LDC]\n" | "add x1, x0, %x[LDC]\n" | ||||
@@ -476,7 +475,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
"cbnz %w[K], 1b\n" | "cbnz %w[K], 1b\n" | ||||
"cmp %w[is_first_k], #1\n" | "cmp %w[is_first_k], #1\n" | ||||
"beq 2f\n" | |||||
"beq 2f\n" | |||||
"cmp %x[m_remain], #8 \n" | "cmp %x[m_remain], #8 \n" | ||||
"beq 8f \n" | "beq 8f \n" | ||||
"cmp %x[m_remain], #4 \n" | "cmp %x[m_remain], #4 \n" | ||||
@@ -633,7 +632,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
"zip2 v15.2d, v30.2d, v31.2d \n" | "zip2 v15.2d, v30.2d, v31.2d \n" | ||||
"add v6.8h, v6.8h, v13.8h \n" | "add v6.8h, v6.8h, v13.8h \n" | ||||
"add v7.8h, v7.8h, v15.8h \n" | "add v7.8h, v7.8h, v15.8h \n" | ||||
//save to memory | |||||
// save to memory | |||||
"cmp %x[m_remain], #8 \n" | "cmp %x[m_remain], #8 \n" | ||||
"beq 4f \n" | "beq 4f \n" | ||||
"cmp %x[m_remain], #4 \n" | "cmp %x[m_remain], #4 \n" | ||||
@@ -766,31 +765,27 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
"b 1000f \n" | "b 1000f \n" | ||||
"1000: \n" | "1000: \n" | ||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||||
: | : | ||||
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||||
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||||
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||||
[ n_remain ] "+r"(n_remain) | |||||
: | |||||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31"); | |||||
// clang-format on | |||||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||||
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
// clang-format on | |||||
#undef LOAD_C_8 | #undef LOAD_C_8 | ||||
#undef STORE_C_8 | #undef STORE_C_8 | ||||
} | } | ||||
static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_4x8( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
K /= 8; | K /= 8; | ||||
LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
const int8_t* a_ptr = packB;//packA; | |||||
const int8_t* b_ptr = packA;//packB; | |||||
const int8_t* a_ptr = packB; // packA; | |||||
const int8_t* b_ptr = packA; // packB; | |||||
// clang-format off | // clang-format off | ||||
#define LOAD_C_4 \ | #define LOAD_C_4 \ | ||||
"ld1 {v0.8h}, [x0], #16\n" \ | "ld1 {v0.8h}, [x0], #16\n" \ | ||||
@@ -1018,14 +1013,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef LOAD_C_4 | #undef LOAD_C_4 | ||||
#undef STORE_C_4 | #undef STORE_C_4 | ||||
} | } | ||||
static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
int n_remain) { | |||||
static void kern_4x8_remain( | |||||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
bool is_first_k, int m_remain, int n_remain) { | |||||
K /= 8; | K /= 8; | ||||
LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
const int8_t* a_ptr = packB;//packA; | |||||
const int8_t* b_ptr = packA;//packB; | |||||
// clang-format off | |||||
const int8_t* a_ptr = packB; // packA; | |||||
const int8_t* b_ptr = packA; // packB; | |||||
// clang-format off | |||||
register int16_t* outptr asm("x0") = output; | register int16_t* outptr asm("x0") = output; | ||||
asm volatile( | asm volatile( | ||||
@@ -1324,13 +1319,12 @@ static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C_4 | #undef STORE_C_4 | ||||
} | } | ||||
//! pack to icxoc | //! pack to icxoc | ||||
//! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7)) | //! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7)) | ||||
//! if M K is not times of 8,pack 0 instead | |||||
static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||||
const dt_int8* inptr, int ldin, | |||||
int m0, int mmax, int k0, int kmax) { | |||||
//! if M K is not times of 8,pack 0 instead | |||||
static void gemm_s8x8x16_mk4_8x8x8_pack_A( | |||||
dt_int8* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, | |||||
int kmax) { | |||||
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | ||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
constexpr int pack_m = 8; | constexpr int pack_m = 8; | ||||
@@ -1349,8 +1343,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||||
prefetch_2x(inptr0); | prefetch_2x(inptr0); | ||||
prefetch_2x(inptr1); | prefetch_2x(inptr1); | ||||
int k_idx = k0; | int k_idx = k0; | ||||
for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
interleave_8x8_mk4_b(inptr0,inptr1,outptr); | |||||
for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
interleave_8x8_mk4_b(inptr0, inptr1, outptr); | |||||
} | } | ||||
if (k_idx < kmax) { | if (k_idx < kmax) { | ||||
@@ -1368,9 +1362,9 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||||
prefetch_2x(inptr0); | prefetch_2x(inptr0); | ||||
prefetch_2x(inptr1); | prefetch_2x(inptr1); | ||||
int k_idx = k0; | int k_idx = k0; | ||||
for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
inptr1 = zerobuff; | inptr1 = zerobuff; | ||||
interleave_8x8_mk4_b(inptr0,inptr1,outptr); | |||||
interleave_8x8_mk4_b(inptr0, inptr1, outptr); | |||||
} | } | ||||
if (k_idx < kmax) { | if (k_idx < kmax) { | ||||
@@ -1383,9 +1377,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||||
} | } | ||||
//! pack to nxic | //! pack to nxic | ||||
//! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead. | //! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead. | ||||
static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int n0, int nmax, int k0, | |||||
int kmax) { | |||||
static void gemm_s8x8x16_mk4_8x8x8_pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
constexpr int pack_n = 8; | constexpr int pack_n = 8; | ||||
@@ -1394,14 +1387,14 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
int8_t tmpbuff0[pack_n * pack_size] = {0}; | int8_t tmpbuff0[pack_n * pack_size] = {0}; | ||||
int8_t tmpbuff1[pack_n * pack_size] = {0}; | int8_t tmpbuff1[pack_n * pack_size] = {0}; | ||||
int8_t zerobuff[pack_n * pack_size] = {0}; | int8_t zerobuff[pack_n * pack_size] = {0}; | ||||
const int ksize = round_up<int>((kmax - k0),8); | |||||
const int ksize = round_up<int>((kmax - k0), 8); | |||||
const int nsize = nmax - n0; | const int nsize = nmax - n0; | ||||
const int n_end = nsize / pack_n * pack_n + n0; | const int n_end = nsize / pack_n * pack_n + n0; | ||||
const int remain_n = nsize % pack_n; | const int remain_n = nsize % pack_n; | ||||
int output_stride = ksize * pack_n; | int output_stride = ksize * pack_n; | ||||
int8_t* outptr_base = out; | int8_t* outptr_base = out; | ||||
int k_idx = k0; | int k_idx = k0; | ||||
for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | ||||
const int8_t* inptr1 = inptr0 + ldin; | const int8_t* inptr1 = inptr0 + ldin; | ||||
prefetch_3x(inptr0); | prefetch_3x(inptr0); | ||||
@@ -1410,7 +1403,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { | for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { | ||||
transpose_8x8_mk4_b(inptr0, inptr1, outptr); | transpose_8x8_mk4_b(inptr0, inptr1, outptr); | ||||
outptr += output_stride; | |||||
outptr += output_stride; | |||||
} | } | ||||
if (remain_n > 0) { | if (remain_n > 0) { | ||||
memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size); | memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size); | ||||
@@ -1422,8 +1415,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
} | } | ||||
outptr_base += pack_n * pack_k; | outptr_base += pack_n * pack_k; | ||||
} | } | ||||
if(k_idx < kmax){ | |||||
if (k_idx < kmax) { | |||||
const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | ||||
const int8_t* inptr1 = nullptr; | const int8_t* inptr1 = nullptr; | ||||
prefetch_3x(inptr0); | prefetch_3x(inptr0); | ||||
@@ -1444,7 +1437,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
} | } | ||||
} | } | ||||
} // namespace matmul_mk4_16x12x4_a53 | |||||
} // namespace matmul_mk4_8x8x8 | |||||
} // namespace aarch64 | } // namespace aarch64 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -10,13 +10,13 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
@@ -28,39 +28,35 @@ using namespace aarch64::matmul; | |||||
// ===========================gemm_s8x8x16_4x4================================== | // ===========================gemm_s8x8x16_4x4================================== | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); | ||||
void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s8x8x16_8x8::pack_A( | |||||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0, | |||||
ymax, k0, kmax); | |||||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n( | |||||
out, in, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s8x8x16_8x8::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0, | |||||
xmax, k0, kmax); | |||||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n( | |||||
out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int16* C, | |||||
size_t LDC, bool is_first_k, const dt_int16*, | |||||
dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
(A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int16), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_s8x8x16_8x8::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int16), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
@@ -79,15 +75,15 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_8x8x8::kern_8x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -99,16 +95,17 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
size_t n = 0; | size_t n = 0; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
matmul_8x8x8::kern_4x8( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K8; | cur_packB += K8; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
matmul_8x8x8::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -119,39 +116,33 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
// ===========================gemm_s8x8x16_4x4================================== | // ===========================gemm_s8x8x16_4x4================================== | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); | ||||
void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s8x8x16_4x4::pack_A( | |||||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
void gemm_s8x8x16_4x4::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool transpose) const { | |||||
if (transpose) { | if (transpose) { | ||||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
} else { | } else { | ||||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
} | } | ||||
void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int16* C, | |||||
size_t LDC, bool is_first_k, const dt_int16*, | |||||
dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
(A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int16), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
C_dtype.name()); | |||||
void gemm_s8x8x16_4x4::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && | |||||
C_dtype.enumv() == DTypeEnum::Int16), | |||||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
@@ -169,16 +160,17 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||||
matmul_4x4x16::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||||
B_INTERLEAVE); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, A_INTERLEAVE, | |||||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
matmul_4x4x16::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -191,10 +183,10 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, | |||||
std::min<size_t>(M - m, A_INTERLEAVE), | |||||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
matmul_4x4x16::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, A_INTERLEAVE), | |||||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -205,28 +197,26 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
// ===========================gemm_s8x8x16_mk4_16x12================================== | // ===========================gemm_s8x8x16_mk4_16x12================================== | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); | ||||
void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool) const { | |||||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0, | |||||
ymax, k0, kmax); | |||||
void gemm_s8x8x16_mk4_16x12_a53::pack_A( | |||||
dt_int16* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A( | |||||
out, in, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int x0, int xmax, int k0, | |||||
int kmax, bool) const { | |||||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0, | |||||
xmax, k0, kmax); | |||||
void gemm_s8x8x16_mk4_16x12_a53::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B( | |||||
out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
const dt_int8* packB, size_t M, size_t N, | |||||
size_t K, dt_int16* C, size_t LDC, | |||||
bool is_first_k, const dt_int16*, | |||||
dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
void gemm_s8x8x16_mk4_16x12_a53::kern( | |||||
const dt_int16* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
megdnn_assert(is_first_k == true, "only impl is_first_k"); | megdnn_assert(is_first_k == true, "only impl is_first_k"); | ||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
@@ -246,14 +236,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
size_t n_idx = 0; | size_t n_idx = 0; | ||||
const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
matmul_mk4_16x12x4_a53::kern_16x12( | |||||
packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||||
output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
} | } | ||||
if (remain_n > 0) { | if (remain_n > 0) { | ||||
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
matmul_mk4_16x12x4_a53::kern_16x12( | |||||
packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||||
output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
} | } | ||||
@@ -265,14 +255,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
size_t n_idx = 0; | size_t n_idx = 0; | ||||
const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
matmul_mk4_16x12x4_a53::kern_8x12( | |||||
packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||||
output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
} | } | ||||
if (remain_n > 0) { | if (remain_n > 0) { | ||||
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
matmul_mk4_16x12x4_a53::kern_8x12( | |||||
packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||||
output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
} | } | ||||
@@ -286,14 +276,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
size_t n_idx = 0; | size_t n_idx = 0; | ||||
const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
matmul_mk4_16x12x4_a53::kern_4x12( | |||||
packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||||
output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
} | } | ||||
if (remain_n > 0) { | if (remain_n > 0) { | ||||
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
matmul_mk4_16x12x4_a53::kern_4x12( | |||||
packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||||
output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
} | } | ||||
@@ -303,27 +293,26 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
// ===========================gemm_s8x8x16_mk4_4x4_a72================================== | // ===========================gemm_s8x8x16_mk4_4x4_a72================================== | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); | ||||
void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax, | |||||
k0, kmax); | |||||
void gemm_s8x8x16_mk4_4x4_a72::pack_A( | |||||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A( | |||||
out, in, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax, | |||||
k0, kmax); | |||||
void gemm_s8x8x16_mk4_4x4_a72::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B( | |||||
out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int16* C, | |||||
size_t LDC, bool is_first_k, | |||||
const dt_int16*, dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
void gemm_s8x8x16_mk4_4x4_a72::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
megdnn_assert(is_first_k == true, "only impl is_first_k"); | megdnn_assert(is_first_k == true, "only impl is_first_k"); | ||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
@@ -343,14 +332,14 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||||
const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { | for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { | ||||
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
matmul_mk4_4x4x8_a72::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||||
output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
cur_packB += pack_n * packed_k; | cur_packB += pack_n * packed_k; | ||||
} | } | ||||
if (remain_n > 0) { | if (remain_n > 0) { | ||||
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
matmul_mk4_4x4x8_a72::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||||
output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
cur_packB += pack_n * packed_k; | cur_packB += pack_n * packed_k; | ||||
} | } | ||||
@@ -361,27 +350,24 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||||
// ===========================gemm_s8x8x16_mk4_8x8x8================================== | // ===========================gemm_s8x8x16_mk4_8x8x8================================== | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); | ||||
void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool) const { | |||||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, | |||||
ymax, k0, kmax); | |||||
void gemm_s8x8x16_mk4_8x8x8::pack_A( | |||||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, ymax, k0, kmax); | |||||
} | } | ||||
void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int x0, int xmax, int k0, | |||||
int kmax, bool) const { | |||||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, | |||||
xmax, k0, kmax); | |||||
void gemm_s8x8x16_mk4_8x8x8::pack_B( | |||||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||||
} | } | ||||
void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int16* C, | |||||
size_t LDC, bool is_first_k, const dt_int16*, | |||||
dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
void gemm_s8x8x16_mk4_8x8x8::kern( | |||||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
megdnn_assert( | |||||
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
megdnn_assert(is_first_k == true, "only impl is_first_k"); | megdnn_assert(is_first_k == true, "only impl is_first_k"); | ||||
MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
@@ -402,14 +388,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t n_idx = 0; | size_t n_idx = 0; | ||||
const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_m, pack_n); | |||||
matmul_mk4_8x8x8::kern_8x8( | |||||
packA, cur_packB, K, output, LDC, is_first_k, pack_m, pack_n); | |||||
output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
cur_packB += KSIZE8; | cur_packB += KSIZE8; | ||||
} | } | ||||
if (remain_n > 0) { | if (remain_n > 0) { | ||||
matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_m, remain_n); | |||||
matmul_mk4_8x8x8::kern_8x8_remain( | |||||
packA, cur_packB, K, output, LDC, is_first_k, pack_m, remain_n); | |||||
output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
cur_packB += KSIZE8; | cur_packB += KSIZE8; | ||||
} | } | ||||
@@ -421,14 +407,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t n_idx = 0; | size_t n_idx = 0; | ||||
const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k, 4, pack_n); | |||||
matmul_mk4_8x8x8::kern_4x8( | |||||
packA, cur_packB, K, output, LDC, is_first_k, 4, pack_n); | |||||
output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
} | } | ||||
if (remain_n > 0) { | if (remain_n > 0) { | ||||
matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC, | |||||
is_first_k, 4, remain_n); | |||||
matmul_mk4_8x8x8::kern_4x8_remain( | |||||
packA, cur_packB, K, output, LDC, is_first_k, 4, remain_n); | |||||
output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
} | } | ||||