@@ -23,9 +23,9 @@
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#pragma GCC diagnostic ignored "-Wsign-compare"
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#pragma GCC diagnostic pop
#if !defined(__HIP_PLATFORM_HCC__)
@@ -11,10 +11,10 @@
#pragma once
#include "megdnn/thin/function.h"
#include "megcore_cdefs.h"
#include <cstddef>
#include <memory>
#include "megcore_cdefs.h"
#include "megdnn/thin/function.h"
#include "megdnn/internal/visibility_prologue.h"
@@ -26,36 +26,35 @@ namespace megcore {
 * the caller thread immediately.
 */
class CPUDispatcher {
public:
    using Task = megdnn::thin_function<void()>;
    using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>;
    virtual ~CPUDispatcher() noexcept;
    /*!
     * \brief dispatch a task on the computing thread
     * \param task the task that would be moved away
     */
    virtual void dispatch(Task&& task) = 0;
    /*!
     * \brief dispatch a multithreading task on the computing thread
     * \param task the task would be moved away
     * \param parallelism the parallelism of the task.
     */
    virtual void dispatch(MultiThreadingTask&& task,
                          size_t parallelism) = 0;
    /*!
     * \brief synchronize the calling thread with the computing thread
     */
    virtual void sync() = 0;
    /*!
     * \brief the computing thread number.
     */
    virtual size_t nr_threads() = 0;
public:
    using Task = megdnn::thin_function<void()>;
    using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>;
    virtual ~CPUDispatcher() noexcept;
    /*!
     * \brief dispatch a task on the computing thread
     * \param task the task that would be moved away
     */
    virtual void dispatch(Task&& task) = 0;
    /*!
     * \brief dispatch a multithreading task on the computing thread
     * \param task the task would be moved away
     * \param parallelism the parallelism of the task.
     */
    virtual void dispatch(MultiThreadingTask&& task, size_t parallelism) = 0;
    /*!
     * \brief synchronize the calling thread with the computing thread
     */
    virtual void sync() = 0;
    /*!
     * \brief the computing thread number.
     */
    virtual size_t nr_threads() = 0;
};
} // namespace megcore
} // namespace megcore
using MegcoreCPUDispatcher = megcore::CPUDispatcher;
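The dispatcher interface above is abstract; for illustration, a minimal synchronous implementation that runs every task inline on the calling thread might look like the sketch below (the class name InplaceDispatcher and the reading of the two MultiThreadingTask arguments as (index, worker id) are our assumptions, not part of megcore):

#include "megcore.h"  // assumed location of megcore::CPUDispatcher

class InplaceDispatcher final : public megcore::CPUDispatcher {
public:
    ~InplaceDispatcher() noexcept override = default;
    void dispatch(Task&& task) override {
        task();  // run immediately on the caller thread
    }
    void dispatch(MultiThreadingTask&& task, size_t parallelism) override {
        for (size_t i = 0; i < parallelism; ++i)
            task(i, 0);  // serialize: one invocation per index, worker id 0
    }
    void sync() override {}  // dispatch is synchronous, nothing to wait for
    size_t nr_threads() override { return 1; }
};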
@@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher;
 * \brief Layer 1: device handle
 */
struct megcoreDeviceContext;
typedef struct megcoreDeviceContext *megcoreDeviceHandle_t;
typedef struct megcoreDeviceContext* megcoreDeviceHandle_t;
megcoreStatus_t megcoreCreateDeviceHandle(
        megcoreDeviceHandle_t *handle,
        megcorePlatform_t platform,
        int deviceID = -1,
        megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID = -1,
        unsigned int flags = 0);
megcoreStatus_t megcoreDestroyDeviceHandle(
        megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle,
                                   megcorePlatform_t *platform);
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle,
                                   int *deviceID);
megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle,
                                       size_t *memAlignmentInBytes);
megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreGetPlatform(
        megcoreDeviceHandle_t handle, megcorePlatform_t* platform);
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID);
megcoreStatus_t megcoreGetMemAlignment(
        megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes);
megcoreStatus_t megcoreGetDeviceFlags(
        megcoreDeviceHandle_t handle,
        unsigned int *flags);
        megcoreDeviceHandle_t handle, unsigned int* flags);
megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle);
megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle,
                              void **devPtr, size_t sizeInBytes);
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle,
                            void *devPtr);
megcoreStatus_t megcoreMalloc(
        megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes);
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr);
/**
 * \brief Layer 2: computing handle
 */
struct megcoreComputingContext;
typedef struct megcoreComputingContext *megcoreComputingHandle_t;
typedef struct megcoreComputingContext* megcoreComputingHandle_t;
megcoreStatus_t megcoreCreateComputingHandle(
        megcoreComputingHandle_t *compHandle,
        megcoreDeviceHandle_t devHandle,
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags = 0);
megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher(
        megcoreComputingHandle_t *compHandle,
        megcoreDeviceHandle_t devHandle,
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher,
        unsigned int flags = 0);
megcoreStatus_t megcoreDestroyComputingHandle(
        megcoreComputingHandle_t handle);
megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle);
megcoreStatus_t megcoreGetDeviceHandle(
        megcoreComputingHandle_t compHandle,
        megcoreDeviceHandle_t *devHandle);
        megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle);
megcoreStatus_t megcoreGetComputingFlags(
        megcoreComputingHandle_t handle,
        unsigned int *flags);
        megcoreComputingHandle_t handle, unsigned int* flags);
MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle);
megcoreStatus_t megcoreMemcpy(
        megcoreComputingHandle_t handle,
        void *dst, const void *src, size_t sizeInBytes,
        megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes,
        megcoreMemcpyKind_t kind);
megcoreStatus_t megcoreMemset(
        megcoreComputingHandle_t handle,
        void *dst, int value, size_t sizeInBytes);
        megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes);
megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle);
/**
 * \brief Miscellaneous
 */
const char *megcoreGetErrorName(megcoreStatus_t status);
const char* megcoreGetErrorName(megcoreStatus_t status);
#include "megdnn/internal/visibility_epilogue.h"
@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, const AtlasContext& ctx);
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle,
                                AtlasContext* ctx);
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, AtlasContext* ctx);
namespace atlas {
//! convert acl error code to error string
@@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, aclrtStream stream) {
    megcore::AtlasContext ctx{stream};
    return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle,
                                                          flags, ctx);
    return megcore::createComputingHandleWithAtlasContext(
            compHandle, devHandle, flags, ctx);
}
inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle,
                                           aclrtStream* stream) {
inline megcoreStatus_t megcoreGetACLStream(
        megcoreComputingHandle_t handle, aclrtStream* stream) {
    megcore::AtlasContext ctx;
    auto ret = megcore::getAtlasContext(handle, &ctx);
    *stream = ctx.stream;
@@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, const CambriconContext& ctx);
megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle,
                                    CambriconContext* ctx);
megcoreStatus_t getCambriconContext(
        megcoreComputingHandle_t handle, CambriconContext* ctx);
} // namespace megcore
@@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue(
#include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen
@@ -40,7 +40,6 @@ typedef enum {
    megcoreErrorInternalError = 5,
} megcoreStatus_t;
/**
 * \brief Memcpy kind
 */
@@ -70,6 +69,6 @@ struct AsyncErrorInfo {
    char msg[228];
    int msg_args[4];
};
} // namespace megcore
} // namespace megcore
// vim: syntax=cpp.doxygen
@@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, const CudaContext& ctx);
megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle,
                               CudaContext* ctx);
megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx);
} // namespace megcore
@@ -43,8 +42,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream(
        unsigned int flags, cudaStream_t stream) {
    megcore::CudaContext ctx;
    ctx.stream = stream;
    return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle,
                                                         flags, ctx);
    return megcore::createComputingHandleWithCUDAContext(
            compHandle, devHandle, flags, ctx);
}
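For reference, wrapping a caller-owned CUDA stream could look like this sketch (the header name megcore_cuda.h and the megcorePlatformCUDA constant are assumptions):

#include <cuda_runtime_api.h>
#include "megcore_cuda.h"

void wrap_stream(cudaStream_t stream) {
    megcoreDeviceHandle_t dev;
    megcoreCreateDeviceHandle(&dev, megcorePlatformCUDA, /*deviceID=*/0);
    megcoreComputingHandle_t comp;
    // attach the existing stream; the handle does not take ownership of it
    megcoreCreateComputingHandleWithCUDAStream(&comp, dev, /*flags=*/0, stream);
    // ... run MegDNN work on comp ...
    megcoreDestroyComputingHandle(comp);
    megcoreDestroyDeviceHandle(dev);
}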
static inline megcoreStatus_t megcoreGetCUDAStream(
@@ -23,7 +23,9 @@ struct ROCMContext {
    hipStream_t stream = nullptr;
    static std::atomic_bool sm_miopen_algo_search;
    static inline bool enable_miopen_algo_search() { return sm_miopen_algo_search.load(); }
    static inline bool enable_miopen_algo_search() {
        return sm_miopen_algo_search.load();
    }
    static inline void enable_miopen_algo_search(bool enable_algo_search) {
        sm_miopen_algo_search.store(enable_algo_search);
    }
@@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, const ROCMContext& ctx);
megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle,
                               ROCMContext* ctx);
megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, ROCMContext* ctx);
// Set MIOpen algo search enabled or disabled
megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true);
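A hedged usage sketch of the toggle declared above (both the free function and the static ROCMContext flag it mirrors; the header name is assumed):

#include "megcore_rocm.h"

void set_miopen_search(bool on) {
    megcore::enableMIOpenAlgoSearch(on);  // process-wide switch
    bool enabled = megcore::ROCMContext::enable_miopen_algo_search();  // query back
    (void)enabled;
}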
@@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream(
        unsigned int flags, hipStream_t stream) {
    megcore::ROCMContext ctx;
    ctx.stream = stream;
    return megcore::createComputingHandleWithROCMContext(compHandle, devHandle,
                                                         flags, ctx);
    return megcore::createComputingHandleWithROCMContext(
            compHandle, devHandle, flags, ctx);
}
static inline megcoreStatus_t megcoreGetROCMStream(
@@ -10,7 +10,7 @@
 */
#pragma once
#include "megdnn/version.h"
#include "megdnn/oprs.h"
#include "megdnn/version.h"
// vim: syntax=cpp.doxygen
@@ -14,20 +14,20 @@
#include "megdnn/config/config.h"
#if defined(__GNUC__) || defined(__clang__)
#if !defined (__clang__)
// gcc specific
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#if GCC_VERSION < 40800
#error "GCC version should be at least 4.8.0."
#endif // GCC_VERSION < 40800
#endif // !defined(__clang__)
#if !defined(__clang__)
// gcc specific
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#if GCC_VERSION < 40800
#error "GCC version should be at least 4.8.0."
#endif // GCC_VERSION < 40800
#endif // !defined(__clang__)
#ifndef megdnn_trap
#define megdnn_trap() __builtin_trap()
#endif
#ifndef megdnn_trap
#define megdnn_trap() __builtin_trap()
#endif
#define megdnn_likely(v) __builtin_expect(bool(v), 1)
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0)
#define megdnn_likely(v) __builtin_expect(bool(v), 1)
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0)
#if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG)
//! Thumb2 limit code length
@@ -36,123 +36,122 @@
#define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__))
#endif
#define MEGDNN_DEPRECATED __attribute__((deprecated))
#define MEGDNN_PACKED __attribute__((packed))
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_NORETURN __attribute__((noreturn))
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#if defined(__clang_major__) && (__clang_major__ >= 7)
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#else
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]]
#endif
#define MEGDNN_NOINLINE __attribute__((noinline))
#define megdnn_isatty(x) isatty(x)
#define MEGDNN_DEPRECATED __attribute__((deprecated))
#define MEGDNN_PACKED __attribute__((packed))
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_NORETURN __attribute__((noreturn))
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#if defined(__clang_major__) && (__clang_major__ >= 7)
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#else
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]]
#endif
#define MEGDNN_NOINLINE __attribute__((noinline))
#define megdnn_isatty(x) isatty(x)
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#ifndef megdnn_trap
#define megdnn_trap() __debugbreak()
#endif
#define megdnn_likely(v) (bool(v))
#define megdnn_likely(v) (bool(v))
#define megdnn_unlikely(v) (bool(v))
#define MEGDNN_DEPRECATED
#define MEGDNN_PACKED
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_CONSTEXPR constexpr
#define MEGDNN_NOEXCEPT noexcept
#define MEGDNN_STATIC_ASSERT static_assert
#define MEGDNN_FINAL final
#define MEGDNN_FINAL final
#if defined(_MSC_VER)
#define MEGDNN_NORETURN __declspec(noreturn)
#define MEGDNN_NOINLINE __declspec(noinline)
#define MEGDNN_NORETURN __declspec(noreturn)
#define MEGDNN_NOINLINE __declspec(noinline)
#else
#define MEGDNN_NORETURN
#define MEGDNN_FORCE_NOINLINE
#endif // _MSC_VER
#define MEGDNN_NORETURN
#define MEGDNN_FORCE_NOINLINE
#endif // _MSC_VER
#define MEGDNN_WARN_UNUSED_RESULT
#define megdnn_isatty(x) _isatty(x)
#else
#error "unknown compiler"
#endif // __GNUC__
#error "unknown compiler"
#endif // __GNUC__
// __cpp_exceptions and __cpp_rtti are referred from
// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS,
// similar for __GXX_RTTI
// _CPPUNWIND and _CPPRTTI are used by MSVC, see
// https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=vs-2019
#ifndef MEGDNN_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS || \
        (defined(_MSC_VER) && defined(_CPPUNWIND))
#define MEGDNN_ENABLE_EXCEPTIONS 1
#else
#define MEGDNN_ENABLE_EXCEPTIONS 0
#endif
#if __cpp_exceptions || __EXCEPTIONS || (defined(_MSC_VER) && defined(_CPPUNWIND))
#define MEGDNN_ENABLE_EXCEPTIONS 1
#else
#define MEGDNN_ENABLE_EXCEPTIONS 0
#endif
#endif
#ifndef MEGDNN_ENABLE_RTTI
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI))
#define MEGDNN_ENABLE_RTTI 1
#else
#define MEGDNN_ENABLE_RTTI 0
#endif
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI))
#define MEGDNN_ENABLE_RTTI 1
#else
#define MEGDNN_ENABLE_RTTI 0
#endif
#endif
#ifdef __CUDACC__
#define MEGDNN_CC_CUDA 1
#undef MEGDNN_CONSTEXPR
#define MEGDNN_CONSTEXPR const
#define MEGDNN_CC_CUDA 1
#undef MEGDNN_CONSTEXPR
#define MEGDNN_CONSTEXPR const
#if defined(__CUDACC_VER_MAJOR__)
#if __CUDACC_VER_MAJOR__ >= 9
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg);
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg);
#else
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg)
#undef MEGDNN_STATIC_ASSERT
#define MEGDNN_STATIC_ASSERT(cond, msg)
#endif
#endif
#define nullptr NULL
#undef MEGDNN_FINAL
#define MEGDNN_FINAL
#define nullptr NULL
#undef MEGDNN_FINAL
#define MEGDNN_FINAL
#elif defined(__HIPCC__)
#define MEGDNN_CC_CUDA 1
#define MEGDNN_CC_CUDA 1
#else
#define MEGDNN_CC_HOST 1
#endif // __CUDACC__
#define MEGDNN_CC_HOST 1
#endif // __CUDACC__
// MEGDNN_HOST and MEGDNN_DEVICE
#if MEGDNN_CC_CUDA
#define MEGDNN_HOST __host__
#define MEGDNN_DEVICE __device__
#define MEGDNN_HOST __host__
#define MEGDNN_DEVICE __device__
#else
#define MEGDNN_HOST
#define MEGDNN_DEVICE
#define MEGDNN_HOST
#define MEGDNN_DEVICE
#endif
#if MEGDNN_CC_CUDA
#define MEGDNN_FORCE_INLINE __forceinline__
#define MEGDNN_FORCE_INLINE __forceinline__
#else
#if __GNUC__ || __has_attribute(always_inline)
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#else
#define MEGDNN_FORCE_INLINE inline
#define MEGDNN_FORCE_INLINE inline
#endif
#endif
#if defined(_MSC_VER) || defined(WIN32)
#define ATTR_ALIGNED(v) __declspec(align(v))
#define ATTR_ALIGNED(v) __declspec(align(v))
#else
#define ATTR_ALIGNED(v) __attribute__((aligned(v)))
#define ATTR_ALIGNED(v) __attribute__((aligned(v)))
#endif
// vim: syntax=cpp.doxygen
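To make the macro set above concrete, a small consumer might read as follows; this is a sketch, not code from the repository:

#include "megdnn/arch.h"

MEGDNN_NORETURN void fail_fast() {
    megdnn_trap();  // __builtin_trap() on gcc/clang, __debugbreak() on MSVC
}

MEGDNN_FORCE_INLINE int clamp_nonneg(int v) {
    if (megdnn_unlikely(v < 0))  // branch hint on gcc/clang, a plain bool() on MSVC
        return 0;
    return v;
}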
@@ -16,10 +16,10 @@
#include "megdnn/internal/defs.h"
#if MEGDNN_CC_HOST
#include <cstdarg>
#include <string>
#include <type_traits>
#include <vector>
#include <cstdarg>
#include "megdnn/thin/small_vector.h"
#endif // MEGDNN_CC_HOST
@@ -35,8 +35,7 @@ class ErrorHandler {
protected:
    MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0;
    MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(
            const std::string& msg) {
    MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(const std::string& msg) {
        on_megdnn_error(msg);
    }
@@ -70,8 +69,9 @@ public:
#if MEGDNN_CC_HOST
enum class LogLevel { DEBUG, INFO, WARN, ERROR };
typedef void (*LogHandler)(LogLevel level, const char* file, const char* func,
                           int line, const char* fmt, va_list ap);
typedef void (*LogHandler)(
        LogLevel level, const char* file, const char* func, int line, const char* fmt,
        va_list ap);
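A handler matching this typedef just has to forward the varargs; a sketch, assuming LogLevel and LogHandler live in namespace megdnn as the hunk context suggests:

#include <cstdarg>
#include <cstdio>

void stderr_log_handler(megdnn::LogLevel level, const char* file, const char* func,
                        int line, const char* fmt, va_list ap) {
    std::fprintf(stderr, "[%d] %s:%d (%s): ", static_cast<int>(level), file, line, func);
    std::vfprintf(stderr, fmt, ap);
    std::fputc('\n', stderr);
}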
/*!
 * \brief set the callback to receive all log messages
@@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape {
    ptrdiff_t low_elem, low_byte;
    size_t high_elem, high_byte;
    Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem,
         size_t high_byte)
    Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte)
            : low_elem(low_elem),
              low_byte(low_byte),
              high_elem(high_elem),
@@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape {
    TensorLayout(const TensorShape& shape, DType dtype, Format format);
    //! creating layout with user-specified shape and stride.
    TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
                 DType dtype);
    TensorLayout(
            const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
            DType dtype);
    TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride,
                 DType dtype, Format format);
    TensorLayout(
            const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype,
            Format format);
    /* =================== inplace modifiers =================== */
@@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape {
     *
     * \throw TensorReshapeError if no stride exists for target shape.
     */
    TensorLayout reshape(const TensorShape& shape) const
            MEGDNN_WARN_UNUSED_RESULT;
    TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
    /*!
     * \brief try to reshape to another view; return whether these two shapes
@@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape {
     * \return true iff there exists target stride so this layout can be
     * converted to target shape and the elements can match.
     */
    bool try_reshape(TensorLayout& output,
                     const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
    bool try_reshape(TensorLayout& output, const TensorShape& shape) const
            MEGDNN_WARN_UNUSED_RESULT;
    /*!
     * \brief Broadcast on dims with shape == 1 to match target *shape*.
     * \throw TensorReshapeError if could not be satisfied
     */
    TensorLayout broadcast(const TensorShape& shape) const
            MEGDNN_WARN_UNUSED_RESULT;
    TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT;
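In code, the difference between the throwing and non-throwing forms is roughly this sketch (dtype::Float32 is assumed from megdnn/dtype.h):

#include "megdnn/basic_types.h"

void layout_demo() {
    using megdnn::TensorLayout;
    TensorLayout contig({4, 6}, megdnn::dtype::Float32());
    TensorLayout flat = contig.reshape({24});   // contiguous: a stride always exists
    TensorLayout out;
    bool ok = contig.try_reshape(out, {3, 8});  // non-throwing variant
    TensorLayout row({1, 6}, megdnn::dtype::Float32());
    TensorLayout bc = row.broadcast({4, 6});    // broadcast axis gets stride 0
    (void)flat; (void)ok; (void)bc;
}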
    /*!
     * \brief Collapse consecutive axes with contiguous layout together
@@ -441,8 +440,7 @@ struct Workspace {
    Workspace() : raw_ptr(NULL), size(0) {}
    Workspace(dt_byte* raw_ptr_, size_t size_)
            : raw_ptr(raw_ptr_), size(size_) {}
    Workspace(dt_byte* raw_ptr_, size_t size_) : raw_ptr(raw_ptr_), size(size_) {}
    template <typename T>
    T* ptr(size_t offset_in_bytes = 0) const {
@@ -467,9 +465,8 @@ public:
     * \param shape requested output shape
     * \param user_data extra user data passed in DynOutMallocPolicyCall
     */
    virtual TensorND alloc_output(size_t id, DType dtype,
                                  const TensorShape& shape,
                                  void* user_data) = 0;
    virtual TensorND alloc_output(
            size_t id, DType dtype, const TensorShape& shape, void* user_data) = 0;
    /*!
     * \brief allocate workspace memory
@@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall {
     */
    template <typename T = void, typename elem = T>
    T* alloc_workspace(size_t nr_elem) {
        using real_elem =
                typename std::conditional<std::is_same<elem, void>::value,
                                          uint8_t, elem>::type;
        return static_cast<T*>(policy->alloc_workspace(
                nr_elem * sizeof(real_elem), user_data));
        using real_elem = typename std::conditional<
                std::is_same<elem, void>::value, uint8_t, elem>::type;
        return static_cast<T*>(
                policy->alloc_workspace(nr_elem * sizeof(real_elem), user_data));
    }
    void free_workspace(void* ptr) {
        return policy->free_workspace(ptr, user_data);
    }
    void free_workspace(void* ptr) { return policy->free_workspace(ptr, user_data); }
};
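Given a DynOutMallocPolicyCall obtained from an operator invocation, the typed helpers are used like this sketch:

void workspace_demo(megdnn::DynOutMallocPolicyCall call) {
    // T = elem = float, so 256 elements allocate 256 * sizeof(float) bytes
    float* buf = call.alloc_workspace<float>(256);
    // ... fill buf ...
    call.free_workspace(buf);
}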
template <typename T>
class EnumClassBit {
    std::underlying_type_t<T> m_val;
@@ -528,8 +521,7 @@ class EnumClassBit {
    constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {}
public:
    constexpr EnumClassBit(T v)
            : m_val(static_cast<std::underlying_type_t<T>>(v)) {}
    constexpr EnumClassBit(T v) : m_val(static_cast<std::underlying_type_t<T>>(v)) {}
    constexpr operator T() const { return static_cast<T>(m_val); }
@@ -542,7 +534,7 @@ public:
    DEF_OPR(&)
    DEF_OPR(|)
    DEF_OPR (^)
    DEF_OPR(^)
    constexpr EnumClassBit operator~() const { return ~m_val; }
@@ -553,14 +545,13 @@ public:
} // namespace megdnn
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
    inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
        return ::megdnn::EnumClassBit<cls>(x) \
                op ::megdnn::EnumClassBit<cls>(y); \
    } \
    inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
            ::megdnn::EnumClassBit<cls> x, cls y) { \
        return x op ::megdnn::EnumClassBit<cls>(y); \
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \
    inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \
        return ::megdnn::EnumClassBit<cls>(x) op ::megdnn::EnumClassBit<cls>(y); \
    } \
    inline constexpr ::megdnn::EnumClassBit<cls> operator op( \
            ::megdnn::EnumClassBit<cls> x, cls y) { \
        return x op ::megdnn::EnumClassBit<cls>(y); \
    }
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \
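What the macro expands to is easiest to see on a toy flag enum; Access below is illustrative, and real call sites go through a wrapper macro that instantiates every operator at once:

#include <cstdint>

enum class Access : uint32_t { READ = 1, WRITE = 2 };
_MEGDNN_DECBO_SINGLE_OPR(Access, |)

// EnumClassBit's constexpr operator T() converts the result back to Access
constexpr Access rw = Access::READ | Access::WRITE;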
@@ -14,14 +14,14 @@
#include "megbrain_build_config.h"
#if MGB_ENABLE_GETENV
#define MGB_GETENV ::std::getenv
#define MGB_GETENV ::std::getenv
#else
#define MGB_GETENV(_name) static_cast<char*>(nullptr)
#define MGB_GETENV(_name) static_cast<char*>(nullptr)
#endif
#ifdef WIN32
#define unsetenv(_name) _putenv_s(_name, "");
#define setenv(name,value,overwrite) _putenv_s(name,value)
#define unsetenv(_name) _putenv_s(_name, "");
#define setenv(name, value, overwrite) _putenv_s(name, value)
#endif
namespace megdnn {
@@ -32,8 +32,7 @@ namespace megdnn {
 */
template <class Opr, typename... Args>
bool has_available_algo(Opr* opr, Args&&... args) {
    const typename Opr::AlgoBase::SizeArgs size_args(
            opr, std::forward<Args>(args)...);
    const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
    for (auto i : Opr::algo_pack().all_algos) {
        if (i->is_available(size_args)) {
            return true;
@@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... args) {
    return false;
}
}
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -17,11 +17,11 @@
#include "megdnn/internal/visibility_prologue.h"
namespace megdnn {
std::unique_ptr<Handle> make_cuda_handle_with_stream(cudaStream_t stream,
                                                     int device_id = -1);
cudaStream_t get_cuda_stream(Handle *handle);
std::unique_ptr<Handle> make_cuda_handle_with_stream(
        cudaStream_t stream, int device_id = -1);
cudaStream_t get_cuda_stream(Handle* handle);
} // namespace megdnn
} // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h"
// vim: syntax=cpp.doxygen
@@ -3,17 +3,22 @@
 *
 * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
 * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
 * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
 * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this
 * software and associated documentation files (the "Software"), to deal in the Software
 * without restriction, including without limitation the rights to use, copy, modify,
 * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to the following
 * conditions:
 *
 * The above copyright notice and this permission notice shall be included in all copies
 * or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
 * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
 * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
 * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * Version 1.11.0
 * \file
@@ -41,8 +46,8 @@
#undef HALF_NOEXCEPT
#undef HALF_NOTHROW
#ifdef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#pragma warning(pop)
#undef HALF_POP_WARNINGS
#endif
// vim: syntax=cpp.doxygen
@@ -3,17 +3,22 @@
 *
 * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
 * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
 * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this
 * software and associated documentation files (the "Software"), to deal in the Software
 * without restriction, including without limitation the rights to use, copy, modify,
 * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to the following
 * conditions:
 *
 * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
 * The above copyright notice and this permission notice shall be included in all copies
 * or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
 * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
 * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
 * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
 * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * Version 1.11.0
 * \file
@@ -39,166 +44,164 @@
#include "megdnn/arch.h"
/// Combined gcc version number.
#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
#define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
//check C++11 language features
#if defined(__clang__) //clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER) //Intel C++
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ????????
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ????????
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ????????
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ????????
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif*/
#elif defined(__GNUC__) //gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) //Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 are "multiple copy/assignment operators specified"
#pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned
// check C++11 language features
#if defined(__clang__) // clang
#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \
        !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
/*#elif defined(__INTEL_COMPILER)
//Intel C++ #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
???????? #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >=
1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define
HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 &&
!defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define
HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 &&
!defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define
HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/
#elif defined(__GNUC__) // gcc
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
#define HALF_ENABLE_CPP11_CONSTEXPR 1
#endif
#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
#define HALF_ENABLE_CPP11_NOEXCEPT 1
#endif
#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
#define HALF_ENABLE_CPP11_USER_LITERALS 1
#endif
#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#endif
#elif defined(_MSC_VER) // Visual C++
#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
#endif
#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
#define HALF_ENABLE_CPP11_LONG_LONG 1
#endif
#define HALF_POP_WARNINGS 1
#pragma warning(push)
//! 4521 and 4522 are "multiple copy/assignment operators specified"
#pragma warning(disable : 4099 4127 4146 4521 4522) // struct vs class, constant in if,
                                                    // negative unsigned
#endif
//check C++11 library features
// check C++11 library features
#include <utility>
#if defined(_LIBCPP_VERSION) //libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) //libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#if defined(_LIBCPP_VERSION) // libc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#elif defined(__GLIBCXX__) // libstdc++
#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
#ifdef __clang__
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#else
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#endif
#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++
#if _CPPLIB_VER >= 520
#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
#endif
#ifndef HALF_ENABLE_CPP11_CSTDINT
#define HALF_ENABLE_CPP11_CSTDINT 1
#endif
#ifndef HALF_ENABLE_CPP11_HASH
#define HALF_ENABLE_CPP11_HASH 1
#endif
#endif
#if _CPPLIB_VER >= 610
#ifndef HALF_ENABLE_CPP11_CMATH
#define HALF_ENABLE_CPP11_CMATH 1
#endif
#endif
#endif
#undef HALF_GNUC_VERSION
//support constexpr
// support constexpr
#if HALF_ENABLE_CPP11_CONSTEXPR
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#define HALF_CONSTEXPR constexpr
#define HALF_CONSTEXPR_CONST constexpr
#else
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#define HALF_CONSTEXPR
#define HALF_CONSTEXPR_CONST const
#endif
//support noexcept
// support noexcept
#if HALF_ENABLE_CPP11_NOEXCEPT
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#define HALF_NOEXCEPT noexcept
#define HALF_NOTHROW noexcept
#else
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#define HALF_NOEXCEPT
#define HALF_NOTHROW throw()
#endif
#include <algorithm>
#include <limits>
#include <climits>
#include <cmath>
#include <cstring>
#include <ostream>
#include <istream>
#include <limits>
#include <ostream>
#if HALF_ENABLE_CPP11_TYPE_TRAITS
#include <type_traits>
#include <type_traits>
#endif
#if HALF_ENABLE_CPP11_CSTDINT
#include <cstdint>
#include <cstdint>
#endif
#if HALF_ENABLE_CPP11_HASH
#include <functional>
#include <functional>
#endif
// vim: syntax=cpp.doxygen
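Later in the header these macros decorate declarations; a sketch of the pattern (half_like is illustrative, not the real half class):

struct half_like {
    HALF_CONSTEXPR half_like() HALF_NOEXCEPT : bits_(0) {}  // constexpr iff detected
    static HALF_CONSTEXPR_CONST int radix = 2;              // plain const pre-C++11
    unsigned short bits_;
};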
@@ -12,8 +12,8 @@
#pragma once
#include "megcore.h"
#include "megdnn/config/config.h"
#include "megdnn/basic_types.h"
#include "megdnn/config/config.h"
#include <functional>
#include <memory>
@@ -24,150 +24,147 @@ namespace megdnn {
class OperatorBase;
class Handle {
public:
    enum class HandleType {
        NAIVE = 0,
        FALLBACK = 1,
        X86 = 2,
        ARM_COMMON = 3,
        ARMV7 = 4,
        AARCH64 = 5,
        CUDA = 6,
        ROCM = 11,
        ATLAS = 13,
        CAMBRICON = 12,
    };
    //! Device vendor
    enum class HandleVendorType : uint32_t {
        NOT_SPEC = 0,
        MALI = 1,
        ADRENO = 2,
        CUDA = 3,
        INTEL = 4,
        POWERVR = 5,
        AMD = 6,
    };
protected:
    Handle(megcoreComputingHandle_t computing_handle, HandleType type);
public:
    /**
     * \brief Create a MegDNN handle from a MegCore Computing handle.
     *
     * \param[in] computing_handle MegCore computing handle. Please note
     * that computing_handle would not be released when this Handle is
     * destructed
     * \param[in] debug_level
     * Applicable for CPU computing handle.
     * 0 means taking the fastest possible code path; it may contain
     * platform-specific instructions such as SSE for x86_64 or NEON for
     * armv7.
     * 1 means taking the fastest possible code path without
     * platform-specific instructions in C++ code. Note that the compiled
     * binary file still contains platform-specific code.
     * 2 means taking the naive code path. Performance is severely
     * hampered, but it is less error-prone since the internal
     * implementation is rather straightforward.
     *
     * **Debug levels 1 and 2 should not be used in production.**
     */
    static std::unique_ptr<Handle> make(
            megcoreComputingHandle_t computing_handle,
            int debug_level = 0);
public:
    enum class HandleType {
        NAIVE = 0,
        FALLBACK = 1,
        X86 = 2,
        ARM_COMMON = 3,
        ARMV7 = 4,
        AARCH64 = 5,
        CUDA = 6,
        ROCM = 11,
        ATLAS = 13,
        CAMBRICON = 12,
    };
    //! Device vendor
    enum class HandleVendorType : uint32_t {
        NOT_SPEC = 0,
        MALI = 1,
        ADRENO = 2,
        CUDA = 3,
        INTEL = 4,
        POWERVR = 5,
        AMD = 6,
    };
protected:
    Handle(megcoreComputingHandle_t computing_handle, HandleType type);
public:
    /**
     * \brief Create a MegDNN handle from a MegCore Computing handle.
     *
     * \param[in] computing_handle MegCore computing handle. Please note
     * that computing_handle would not be released when this Handle is
     * destructed
     * \param[in] debug_level
     * Applicable for CPU computing handle.
     * 0 means taking the fastest possible code path; it may contain
     * platform-specific instructions such as SSE for x86_64 or NEON for
     * armv7.
     * 1 means taking the fastest possible code path without
     * platform-specific instructions in C++ code. Note that the compiled
     * binary file still contains platform-specific code.
     * 2 means taking the naive code path. Performance is severely
     * hampered, but it is less error-prone since the internal
     * implementation is rather straightforward.
     *
     * **Debug levels 1 and 2 should not be used in production.**
     */
    static std::unique_ptr<Handle> make(
            megcoreComputingHandle_t computing_handle, int debug_level = 0);
#if MEGDNN_WITH_CUDA
    static std::unique_ptr<Handle> make_cuda_handle(
            megcoreComputingHandle_t computing_handle);
    template <typename opr>
    std::unique_ptr<opr> create_cuda_operator();
    static std::unique_ptr<Handle> make_cuda_handle(
            megcoreComputingHandle_t computing_handle);
    template <typename opr>
    std::unique_ptr<opr> create_cuda_operator();
#endif
#if MEGDNN_WITH_ROCM
    static std::unique_ptr<Handle> make_rocm_handle(
            megcoreComputingHandle_t computing_handle);
    template <typename opr>
    std::unique_ptr<opr> create_rocm_operator();
    static std::unique_ptr<Handle> make_rocm_handle(
            megcoreComputingHandle_t computing_handle);
    template <typename opr>
    std::unique_ptr<opr> create_rocm_operator();
#endif
    virtual ~Handle();
    /*!
     * \brief Get the underlying megcore computing handle.
     */
    megcoreComputingHandle_t megcore_computing_handle() const {
        return m_computing_handle;
    }
    /*!
     * \brief set a callback function to be invoked when this handle is
     * destructed, so associated resources can be released (e.g.
     * computing handle)
     *
     * This function can be called at most once.
     */
    void set_destructor(const thin_function<void()> &d);
    /*!
     * \brief set a callback to be invoked when an operator is destructed
     * \param[in,out] cb the callback function; it would be set to the
     * previous callback function
     */
    void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) {
        cb.swap(m_on_opr_destructed);
    }
    void on_opr_destructed(OperatorBase* opr);
    /**
     * \brief Create operator of Opr type.
     */
    template <typename Opr>
    std::unique_ptr<Opr> create_operator();
    /*
     * =============================================================
     * Users should call functions below to query memory requirement.
     * =============================================================
     */
    /**
     * \brief The internal data pointer of TensorND should be aligned to
     * alignment_requirement() in bytes.
     */
    virtual size_t alignment_requirement() const;
    //! get alignment in bytes for rows of image 2D tensor format
    virtual size_t image2d_pitch_alignment() const;
    //! get vendor type
    virtual HandleVendorType vendor_type() const;
    HandleType type() const {
        return m_handle_type;
    }
    /**
     * \brief Check whether the layout satisfies the cross-device copy constraint:
     * 1. the handles of src and dst are of the same kind;
     * 2. dst is contiguous.
     */
    virtual bool check_cross_dev_copy_constraint(const TensorLayout &src);
private:
    static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;
    volatile uint32_t m_alive_magic = ALIVE_MAGIC;
    megcoreComputingHandle_t m_computing_handle;
    const HandleType m_handle_type;
    thin_function<void()> m_destructor;
    thin_function<void(OperatorBase*)> m_on_opr_destructed;
    Handle() = delete;
    Handle(const Handle &rhs) = delete;
    Handle &operator=(const Handle &rhs) = delete;
    virtual ~Handle();
    /*!
     * \brief Get the underlying megcore computing handle.
     */
    megcoreComputingHandle_t megcore_computing_handle() const {
        return m_computing_handle;
    }
    /*!
     * \brief set a callback function to be invoked when this handle is
     * destructed, so associated resources can be released (e.g.
     * computing handle)
     *
     * This function can be called at most once.
     */
    void set_destructor(const thin_function<void()>& d);
    /*!
     * \brief set a callback to be invoked when an operator is destructed
     * \param[in,out] cb the callback function; it would be set to the
     * previous callback function
     */
    void set_opr_destruct_callback(thin_function<void(OperatorBase*)>& cb) {
        cb.swap(m_on_opr_destructed);
    }
    void on_opr_destructed(OperatorBase* opr);
    /**
     * \brief Create operator of Opr type.
     */
    template <typename Opr>
    std::unique_ptr<Opr> create_operator();
    /*
     * =============================================================
     * Users should call functions below to query memory requirement.
     * =============================================================
     */
    /**
     * \brief The internal data pointer of TensorND should be aligned to
     * alignment_requirement() in bytes.
     */
    virtual size_t alignment_requirement() const;
    //! get alignment in bytes for rows of image 2D tensor format
    virtual size_t image2d_pitch_alignment() const;
    //! get vendor type
    virtual HandleVendorType vendor_type() const;
    HandleType type() const { return m_handle_type; }
    /**
     * \brief Check whether the layout satisfies the cross-device copy constraint:
     * 1. the handles of src and dst are of the same kind;
     * 2. dst is contiguous.
     */
    virtual bool check_cross_dev_copy_constraint(const TensorLayout& src);
private:
    static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u;
    volatile uint32_t m_alive_magic = ALIVE_MAGIC;
    megcoreComputingHandle_t m_computing_handle;
    const HandleType m_handle_type;
    thin_function<void()> m_destructor;
    thin_function<void(OperatorBase*)> m_on_opr_destructed;
    Handle() = delete;
    Handle(const Handle& rhs) = delete;
    Handle& operator=(const Handle& rhs) = delete;
};
} // namespace megdnn
} // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h"
@@ -49,8 +49,9 @@ public:
    mutable std::string m_input;
public:
    Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr,
        size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0)
    Key(Handle* opr_handle, Algorithm::OprType opr_type,
        const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size,
        const void* param_ptr = nullptr, size_t param_size = 0)
            : m_handle{opr_handle},
              m_opr_type{static_cast<uint32_t>(opr_type)},
              m_inp_layouts_ptr{inp_layouts_ptr},
@@ -16,20 +16,19 @@
 * \brief iterate through small (usually used) ndim values
 */
#define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \
    cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__)
    cb(1, ##__VA_ARGS__) cb(2, ##__VA_ARGS__) cb(3, ##__VA_ARGS__)
/*!
 * \brief iterate through large (rarely used) ndim values
 */
#define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \
    cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \
    cb(7, ##__VA_ARGS__)
    cb(4, ##__VA_ARGS__) cb(5, ##__VA_ARGS__) cb(6, ##__VA_ARGS__) cb(7, ##__VA_ARGS__)
/*!
 * \brief iterate through all ndim values
 */
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \
    MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \
    MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__)
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \
    MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ##__VA_ARGS__) \
    MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ##__VA_ARGS__)
// vim: syntax=cpp.doxygen
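A typical consumer passes an instantiation callback to the foreach macro; a sketch (the process template is illustrative):

template <int ndim>
void process() { /* ... per-ndim implementation ... */ }

#define INST(n, ...) template void process<n>();
MEGDNN_FOREACH_TENSOR_NDIM(INST)  // explicitly instantiates ndim = 1..7
#undef INST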
@@ -11,14 +11,14 @@ | |||
// intentional no header guard here | |||
#include "megdnn/handle.h" | |||
#include "megdnn/oprs/base.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/opr_result_defs.h" | |||
#include "megdnn/oprs/base.h" | |||
#include "./visibility_prologue.h" | |||
#include <limits> | |||
#include <array> | |||
#include <limits> | |||
#ifndef _megdnn_in | |||
#define _megdnn_in | |||
@@ -29,36 +29,37 @@ | |||
#endif | |||
#ifndef _megdnn_tensor_in | |||
#define _megdnn_tensor_in const TensorND & | |||
#define _megdnn_tensor_in const TensorND& | |||
#endif | |||
#ifndef _megdnn_tensor_out | |||
#define _megdnn_tensor_out const TensorND & | |||
#define _megdnn_tensor_out const TensorND& | |||
#endif | |||
#ifndef _megdnn_tensor_inout | |||
#define _megdnn_tensor_inout const TensorND & | |||
#define _megdnn_tensor_inout const TensorND& | |||
#endif | |||
#ifndef _megdnn_workspace | |||
#define _megdnn_workspace const Workspace & | |||
#define _megdnn_workspace const Workspace& | |||
#endif | |||
#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
public: \ | |||
_opr_name(Handle *handle): _base_name(handle) {} \ | |||
#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
public: \ | |||
_opr_name(Handle* handle) : _base_name(handle) {} | |||
#define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ | |||
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \ | |||
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; | |||
#define DEF_OPR_PARAM(_pname) \ | |||
public: \ | |||
using Param = param::_pname; \ | |||
Param& param() { return m_param; } \ | |||
const Param& param() const { return m_param; } \ | |||
protected: \ | |||
Param m_param | |||
#define DEF_OPR_PARAM(_pname) \ | |||
public: \ | |||
using Param = param::_pname; \ | |||
Param& param() { return m_param; } \ | |||
const Param& param() const { return m_param; } \ | |||
\ | |||
protected: \ | |||
Param m_param | |||
// vim: syntax=cpp.doxygen |
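For reference, a hand-expanded version of a class that uses the macros above (an editorial sketch; DemoOpr is hypothetical, param::Empty is the empty param type used elsewhere in these headers):

namespace megdnn {
class DemoOpr : public OperatorBase {
    // DEF_OPR_IMPL(DemoOpr, OperatorBase, 1, 1) generates roughly:
public:
    DemoOpr(Handle* handle) : OperatorBase(handle) {}
    static MEGDNN_CONSTEXPR int NR_INPUTS = 1;
    static MEGDNN_CONSTEXPR int NR_OUTPUTS = 1;
    // DEF_OPR_PARAM(Empty) generates roughly:
public:
    using Param = param::Empty;
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    Param m_param;
};
}  // namespace megdnn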
@@ -20,4 +20,3 @@ | |||
#endif | |||
// vim: syntax=cpp.doxygen | |||
@@ -16,25 +16,21 @@ | |||
namespace megdnn { | |||
namespace opr_result { | |||
struct Checksum { | |||
uint32_t checksum; | |||
union { | |||
int32_t iv; | |||
float fv; | |||
} last_val; | |||
bool operator == (const Checksum &rhs) const { | |||
return checksum == rhs.checksum && | |||
last_val.iv == rhs.last_val.iv; | |||
} | |||
bool operator != (const Checksum &rhs) const { | |||
return !operator==(rhs); | |||
} | |||
}; | |||
} // namespace opr_result | |||
} // namespace megdnn | |||
struct Checksum { | |||
uint32_t checksum; | |||
union { | |||
int32_t iv; | |||
float fv; | |||
} last_val; | |||
bool operator==(const Checksum& rhs) const { | |||
return checksum == rhs.checksum && last_val.iv == rhs.last_val.iv; | |||
} | |||
bool operator!=(const Checksum& rhs) const { return !operator==(rhs); } | |||
}; | |||
} // namespace opr_result | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
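Worth noting about the cleaned-up Checksum: operator== reads last_val through its int32 view, so the float payload is compared bitwise. A minimal editorial sketch, not part of the diff:

#include "megdnn/opr_result_defs.h"

inline bool checksum_match_demo() {
    megdnn::opr_result::Checksum a{0x1234u, {42}};
    megdnn::opr_result::Checksum b{0x1234u, {42}};
    return a == b;  // true: same checksum and same last_val bit pattern
}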
@@ -12,11 +12,11 @@ | |||
#include "megdnn/oprs/cv.h" | |||
#include "megdnn/oprs/general.h" | |||
#include "megdnn/oprs/imgproc.h" | |||
#include "megdnn/oprs/linalg.h" | |||
#include "megdnn/oprs/nn.h" | |||
#include "megdnn/oprs/nn_int.h" | |||
#include "megdnn/oprs/imgproc.h" | |||
#include "megdnn/oprs/utils.h" | |||
#include "megdnn/oprs/linalg.h" | |||
template <typename Opr> | |||
struct OprArityTrait; | |||
@@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1); | |||
#undef INST_ARITY | |||
// vim: syntax=cpp.doxygen |
@@ -90,7 +90,7 @@ enum class AlgoDataType : uint32_t { | |||
INT8X8X16 = 1 << 4, | |||
INT16X16X32 = 1 << 5, | |||
INT4X4X16 = 1 << 6, | |||
QINT4x4x32 = 1 << 7, | |||
QINT4x4x32 = 1 << 7, | |||
}; | |||
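Since the enumerators are distinct bits, values compose with bitwise OR, letting a backend advertise several supported type combinations in one word. An editorial sketch using only members visible in this hunk (the enclosing megdnn::detail namespace is an assumption):

#include <cstdint>

using megdnn::detail::AlgoDataType;  // assumed enclosing namespace

constexpr uint32_t demo_supported =
        static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
        static_cast<uint32_t>(AlgoDataType::INT16X16X32);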
/*! | |||
@@ -195,16 +195,16 @@ public: | |||
Handle::HandleType handle_type() const { return m_handle_type; } | |||
Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } | |||
Info info() const { | |||
return {desc(), attribute()}; | |||
} | |||
Info info() const { return {desc(), attribute()}; } | |||
template <typename T> | |||
static void serialize_write_pod(const T& val, std::string& result) { | |||
static_assert(std::is_trivially_copyable<T>::value, | |||
"type should be trivially copyable"); | |||
static_assert(!std::is_pointer<T>::value, | |||
"serialize pointer is unsafe in eager execution mode"); | |||
static_assert( | |||
std::is_trivially_copyable<T>::value, | |||
"type should be trivially copyable"); | |||
static_assert( | |||
!std::is_pointer<T>::value, | |||
"serialize pointer is unsafe in eager execution mode"); | |||
result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | |||
} | |||
@@ -231,9 +231,8 @@ public: | |||
return ret; | |||
} | |||
static std::string deserialize_read_pod(const std::string& data, | |||
size_t offset = 0, | |||
size_t size = 0) { | |||
static std::string deserialize_read_pod( | |||
const std::string& data, size_t offset = 0, size_t size = 0) { | |||
return std::string(data.data() + offset, size); | |||
} | |||
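A round-trip sketch for the two POD helpers above (editorial; Algorithm here is megdnn::detail::Algorithm, matching the `using AlgoAttribute = detail::Algorithm::Attribute` lines later in this file):

#include <cstdint>
#include <string>

inline std::string pod_roundtrip_demo() {
    std::string buf;
    uint32_t tag = 7;
    // append the raw bytes of `tag`, then copy them back out as a string
    megdnn::detail::Algorithm::serialize_write_pod(tag, buf);
    return megdnn::detail::Algorithm::deserialize_read_pod(buf, 0, sizeof(tag));
}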
@@ -286,8 +285,8 @@ public: | |||
* \param layouts origin layouts of the parent opr | |||
* \param opr parent opr | |||
*/ | |||
virtual std::vector<SearchItem> get_subopr_list(const TensorLayoutArray&, | |||
const OperatorBase*) const { | |||
virtual std::vector<SearchItem> get_subopr_list( | |||
const TensorLayoutArray&, const OperatorBase*) const { | |||
return {}; | |||
} | |||
@@ -333,9 +332,7 @@ public: | |||
ExecutionPolicy& execution_policy() { return m_execution_policy; } | |||
const ExecutionPolicy& execution_policy() const { | |||
return m_execution_policy; | |||
} | |||
const ExecutionPolicy& execution_policy() const { return m_execution_policy; } | |||
virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | |||
@@ -355,8 +352,8 @@ public: | |||
using AlgoAttribute = detail::Algorithm::Attribute; | |||
//! get all possible algorithm descriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
const TensorLayout& p1) { | |||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
const TensorLayout& p0, const TensorLayout& p1) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1)) { | |||
ret.emplace_back(algo->info()); | |||
@@ -364,8 +361,8 @@ public: | |||
return ret; | |||
} | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
const TensorLayout& p1) { | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
const TensorLayout& p0, const TensorLayout& p1) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | |||
ret.emplace_back(algo->info()); | |||
@@ -382,12 +379,11 @@ public: | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
return get_algorithm_heuristic(p0, p1, workspace_limit_in_bytes, | |||
positive_attr, negative_attr) | |||
return get_algorithm_heuristic( | |||
p0, p1, workspace_limit_in_bytes, positive_attr, negative_attr) | |||
->info(); | |||
} | |||
@@ -408,8 +404,7 @@ protected: | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
}; | |||
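Typical call pattern for the two-layout interface above (an editorial sketch; Opr stands for any operator deriving from this MultiAlgoOpr specialization, and the 64 MiB cap is illustrative):

template <typename Opr>
void pick_algo_demo(Opr& opr, const megdnn::TensorLayout& p0,
                    const megdnn::TensorLayout& p1) {
    // enumerate every candidate, then take the heuristic pick under a
    // 64 MiB workspace cap with default attribute filters
    auto all = opr.get_all_algorithms_info_safe(p0, p1);
    auto best = opr.get_algorithm_info_heuristic(p0, p1, 64u << 20);
    (void)all;
    (void)best;
}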
@@ -423,9 +418,8 @@ public: | |||
using AlgoAttribute = detail::Algorithm::Attribute; | |||
//! get all possible algorithm descriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2) { | |||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | |||
ret.emplace_back(algo->info()); | |||
@@ -433,9 +427,8 @@ public: | |||
return ret; | |||
} | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2) { | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | |||
ret.emplace_back(algo->info()); | |||
@@ -451,14 +444,13 @@ public: | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | |||
positive_attr, negative_attr) | |||
return get_algorithm_heuristic( | |||
p0, p1, p2, workspace_limit_in_bytes, positive_attr, | |||
negative_attr) | |||
->info(); | |||
} | |||
@@ -467,11 +459,9 @@ protected: | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2) = 0; | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; | |||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2) = 0; | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; | |||
/** | |||
* \brief Returns the best algorithm by heuristic. | |||
@@ -480,10 +470,8 @@ protected: | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
}; | |||
@@ -497,10 +485,9 @@ public: | |||
using AlgoAttribute = detail::Algorithm::Attribute; | |||
//! get all possible algorithm descriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
const TensorLayout& p3) { | |||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | |||
ret.emplace_back(algo->info()); | |||
@@ -508,10 +495,9 @@ public: | |||
return ret; | |||
} | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
const TensorLayout& p3) { | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | |||
ret.emplace_back(algo->info()); | |||
@@ -527,14 +513,14 @@ public: | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | |||
positive_attr, negative_attr) | |||
return get_algorithm_heuristic( | |||
p0, p1, p2, p3, workspace_limit_in_bytes, positive_attr, | |||
negative_attr) | |||
->info(); | |||
} | |||
@@ -543,11 +529,11 @@ protected: | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3) = 0; | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3) = 0; | |||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3) = 0; | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3) = 0; | |||
/** | |||
* \brief Returns the best algorithm by heuristic. | |||
@@ -556,10 +542,9 @@ protected: | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
}; | |||
@@ -573,11 +558,9 @@ public: | |||
using AlgoAttribute = detail::Algorithm::Attribute; | |||
//! get all possible algorithm descriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
const TensorLayout& p3, | |||
const TensorLayout& p4) { | |||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | |||
ret.emplace_back(algo->info()); | |||
@@ -585,11 +568,9 @@ public: | |||
return ret; | |||
} | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
const TensorLayout& p3, | |||
const TensorLayout& p4) { | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | |||
ret.emplace_back(algo->info()); | |||
@@ -605,16 +586,14 @@ public: | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4, | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
return get_algorithm_heuristic(p0, p1, p2, p3, p4, | |||
workspace_limit_in_bytes, positive_attr, | |||
negative_attr) | |||
return get_algorithm_heuristic( | |||
p0, p1, p2, p3, p4, workspace_limit_in_bytes, positive_attr, | |||
negative_attr) | |||
->info(); | |||
} | |||
@@ -622,14 +601,12 @@ protected: | |||
~MultiAlgoOpr() = default; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4) = 0; | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4) = 0; | |||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4) = 0; | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4) = 0; | |||
/** | |||
* \brief Returns the best algorithm by heuristic. | |||
@@ -638,11 +615,9 @@ protected: | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4, | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
}; | |||
@@ -657,9 +632,8 @@ public: | |||
//! get all possible algorithm descriptions for the specified layouts | |||
std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||
@@ -669,9 +643,8 @@ public: | |||
} | |||
std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7) { | |||
std::vector<AlgorithmInfo> ret; | |||
for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||
@@ -687,17 +660,15 @@ public: | |||
* The selected algorithm should not use workspace more than | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
AlgorithmInfo get_algorithm_info_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | |||
workspace_limit_in_bytes, positive_attr, | |||
negative_attr) | |||
return get_algorithm_heuristic( | |||
p0, p1, p2, p3, p4, p5, p6, p7, workspace_limit_in_bytes, | |||
positive_attr, negative_attr) | |||
->info(); | |||
} | |||
@@ -705,15 +676,13 @@ protected: | |||
~MultiAlgoOpr() = default; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7) = 0; | |||
virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7) = 0; | |||
/** | |||
@@ -723,12 +692,10 @@ protected: | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
}; | |||
@@ -31,15 +31,17 @@ class FlipForward : public FlipBase { | |||
DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using Flip = FlipForward; | |||
@@ -56,15 +58,17 @@ class RotateForward : public RotateBase { | |||
DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using Rotate = RotateForward; | |||
@@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase { | |||
DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using ROICopy = ROICopyForward; | |||
@@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase { | |||
DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using CvtColor = CvtColorForward; | |||
@@ -130,8 +138,9 @@ public: | |||
using BorderMode = Param::BorderMode; | |||
protected: | |||
void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans, | |||
const TensorLayout& dst); | |||
void check_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& trans, | |||
const TensorLayout& dst); | |||
std::string param_msg() const; | |||
int get_real_coord(int p, int len); | |||
}; | |||
@@ -148,15 +157,17 @@ public: | |||
* \warning src, trans, border_value, dst should be contiguous | |||
* The size of trans is N * 2 * 3 | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& trans, | |||
const TensorLayout& dst) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& trans, | |||
const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& trans, | |||
const TensorLayout& dst, size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using WarpAffine = WarpAffineForward; | |||
@@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase { | |||
DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using GaussianBlur = GaussianBlurForward; | |||
@@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase { | |||
DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using Resize = ResizeForward; | |||
@@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase { | |||
DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
const TensorLayout& mat) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& diff, const TensorLayout& mat) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& diff, const TensorLayout& mat, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& diff, const TensorLayout& mat, | |||
size_t workspace_in_bytes); | |||
}; | |||
/** | |||
@@ -251,29 +268,32 @@ public: | |||
using BorderMode = Param::BorderMode; | |||
protected: | |||
void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, | |||
const TensorLayout& dst); | |||
void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, | |||
TensorLayout& dst); | |||
void check_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& map_xy, | |||
const TensorLayout& dst); | |||
void deduce_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); | |||
}; | |||
class RemapForward : public RemapBase { | |||
DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, | |||
TensorLayout& dst); | |||
void deduce_layout( | |||
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& map_xy, | |||
const TensorLayout& dst) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& map_xy, | |||
const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||
const TensorLayout& dst, size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& map_xy, | |||
const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
using Remap = RemapForward; | |||
@@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase { | |||
DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& map_xy, const TensorLayout& diff, | |||
const TensorLayout& grad) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& map_xy, const TensorLayout& diff, | |||
const TensorLayout& grad, size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& map_xy, const TensorLayout& diff, | |||
const TensorLayout& grad, size_t workspace_in_bytes); | |||
}; | |||
class RemapBackwardMat : public RemapBase { | |||
DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& map_xy, | |||
const TensorLayout& diff, | |||
const TensorLayout& grad) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& map_xy, | |||
const TensorLayout& diff, const TensorLayout& grad) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||
const TensorLayout& diff, const TensorLayout& grad, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& map_xy, | |||
const TensorLayout& diff, const TensorLayout& grad, | |||
size_t workspace_in_bytes); | |||
}; | |||
class SeparableFilterBase : public OperatorBase { | |||
@@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase { | |||
DEF_OPR_PARAM(SeparableFilter); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout& src, | |||
const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, TensorLayout& dst); | |||
void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, | |||
const TensorLayout& dst); | |||
void deduce_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, TensorLayout& dst); | |||
void check_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, const TensorLayout& dst); | |||
}; | |||
class SeparableFilterForward : public SeparableFilterBase { | |||
DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, | |||
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, | |||
const TensorLayout& dst) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in filter_x, | |||
_megdnn_tensor_in filter_y, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& src, const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& filter_x, | |||
const TensorLayout& filter_y, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using SeparableFilter = SeparableFilterForward; | |||
@@ -13,173 +13,162 @@ | |||
namespace megdnn { | |||
class WarpPerspectiveBase: public OperatorBase { | |||
class WarpPerspectiveBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); | |||
DEF_OPR_PARAM(WarpPerspective); | |||
public: | |||
using InterpolationMode = Param::InterpolationMode; | |||
using BorderMode = Param::BorderMode; | |||
protected: | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||
const TensorLayout &dst) { | |||
check_layout_fwd(src, mat, {}, dst); | |||
} | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||
const TensorLayout &mat_idx, const TensorLayout &dst); | |||
std::string param_msg() const; | |||
int get_real_coord(int p, int len); | |||
public: | |||
using InterpolationMode = Param::InterpolationMode; | |||
using BorderMode = Param::BorderMode; | |||
protected: | |||
void check_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | |||
check_layout_fwd(src, mat, {}, dst); | |||
} | |||
void check_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& mat, | |||
const TensorLayout& mat_idx, const TensorLayout& dst); | |||
std::string param_msg() const; | |||
int get_real_coord(int p, int len); | |||
}; | |||
class WarpPerspectiveForward: public WarpPerspectiveBase { | |||
class WarpPerspectiveForward : public WarpPerspectiveBase { | |||
DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); | |||
public: | |||
/** | |||
* \param[in] src (n, channel, in_height, in_width) | |||
* \param[in] mat (n, 3, 3) | |||
* \param[out] dst (n, channel, out_height, out_width) | |||
* | |||
* \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||
* | |||
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||
* | |||
* src and dst can have different shapes, as long as their n and c agree. | |||
* src, mat and dst should be contiguous. | |||
*/ | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mat, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
exec(src, mat, {}, dst, workspace); | |||
} | |||
/** | |||
* \p src should have batch size m, and \p mat and \p mat_idx should | |||
* both have batch size n. Each item in \p mat_idx must be in the range | |||
* of [0, m-1]. | |||
* | |||
* \param mat_idx the indices of input image that each matrix in \p mat | |||
* should act on. It can also be empty and in such case \p mat | |||
* should have the same batch size as \p src. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mat, | |||
_megdnn_tensor_in mat_idx, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &dst) { | |||
return get_workspace_in_bytes(src, mat, {}, dst); | |||
} | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
void check_exec_allow_nhwc_mat_idx(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
public: | |||
/** | |||
* \param[in] src (n, channel, in_height, in_width) | |||
* \param[in] mat (n, 3, 3) | |||
* \param[out] dst (n, channel, out_height, out_width) | |||
* | |||
* \see | |||
* http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||
* | |||
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||
* | |||
* src and dst can have different shapes, as long as their n and c agree. | |||
* src, mat and dst should be contiguous. | |||
*/ | |||
void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
exec(src, mat, {}, dst, workspace); | |||
} | |||
/** | |||
* \p src should have batch size m, and \p mat and \p mat_idx should | |||
* both have batch size n. Each item in \p mat_idx must be in the range | |||
* of [0, m-1]. | |||
* | |||
* \param mat_idx the indices of input image that each matrix in \p mat | |||
* should act on. It can also be empty and in such case \p mat | |||
* should have the same batch size as \p src. | |||
*/ | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | |||
return get_workspace_in_bytes(src, mat, {}, dst); | |||
} | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& mat, | |||
const TensorLayout& mat_idx, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& mat, | |||
const TensorLayout& mat_idx, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
void check_exec_allow_nhwc_mat_idx( | |||
const TensorLayout& src, const TensorLayout& mat, | |||
const TensorLayout& mat_idx, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using WarpPerspective = WarpPerspectiveForward; | |||
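The sampling rule documented in the class above, written out for one output pixel (an editorial sketch; real implementations additionally apply the interpolation and border modes):

inline void warp_src_coord_demo(const float mat[3][3], float w, float h,
                                float& src_w, float& src_h) {
    // denominator = mat[2][0]*w + mat[2][1]*h + mat[2][2], per the doc comment
    float denom = mat[2][0] * w + mat[2][1] * h + mat[2][2];
    src_h = (mat[1][0] * w + mat[1][1] * h + mat[1][2]) / denom;
    src_w = (mat[0][0] * w + mat[0][1] * h + mat[0][2]) / denom;
}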
class WarpPerspectiveBackwardData: public WarpPerspectiveBase { | |||
class WarpPerspectiveBackwardData : public WarpPerspectiveBase { | |||
DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); | |||
public: | |||
/** | |||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
* \param[in] diff the backpropagated gradient wrt. dst | |||
* \param[out] grad the backpropagated gradient wrt. src | |||
* \param[out] workspace temporary workspace to perform backward | |||
*/ | |||
void exec(_megdnn_tensor_in mat, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) { | |||
exec(mat, {}, diff, grad, workspace); | |||
} | |||
virtual void exec(_megdnn_tensor_in mat, | |||
_megdnn_tensor_in mat_idx, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
size_t get_workspace_in_bytes(const TensorLayout &mat, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) { | |||
return get_workspace_in_bytes(mat, {}, diff, grad); | |||
} | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad, | |||
size_t workspace_in_bytes); | |||
public: | |||
/** | |||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
* \param[in] diff the backpropagated gradient wrt. dst | |||
* \param[out] grad the backpropagated gradient wrt. src | |||
* \param[out] workspace temporary workspace to perform backward | |||
*/ | |||
void exec( | |||
_megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) { | |||
exec(mat, {}, diff, grad, workspace); | |||
} | |||
virtual void exec( | |||
_megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& mat, const TensorLayout& diff, | |||
const TensorLayout& grad) { | |||
return get_workspace_in_bytes(mat, {}, diff, grad); | |||
} | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& mat, const TensorLayout& mat_idx, | |||
const TensorLayout& diff, const TensorLayout& grad) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& mat, const TensorLayout& mat_idx, | |||
const TensorLayout& diff, const TensorLayout& grad, | |||
size_t workspace_in_bytes); | |||
}; | |||
class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { | |||
class WarpPerspectiveBackwardMat : public WarpPerspectiveBase { | |||
DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); | |||
public: | |||
/** | |||
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
* \param[in] diff the backpropagated gradient wrt. dst | |||
* \param[out] grad the backpropagated gradient wrt. mat | |||
* \param[out] workspace temporary workspace to perform backward | |||
*/ | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mat, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) { | |||
exec(src, mat, {}, diff, grad, workspace); | |||
} | |||
virtual void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mat, | |||
_megdnn_tensor_in mat_idx, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) { | |||
return get_workspace_in_bytes(src, mat, {}, diff, grad); | |||
} | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad, | |||
size_t workspace_in_bytes); | |||
public: | |||
/** | |||
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
* \param[in] diff the backpropagated gradient wrt. dst | |||
* \param[out] grad the backpropagated gradient wrt. mat | |||
* \param[out] workspace temporary workspace to perform backward | |||
*/ | |||
void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
exec(src, mat, {}, diff, grad, workspace); | |||
} | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff, | |||
const TensorLayout& grad) { | |||
return get_workspace_in_bytes(src, mat, {}, diff, grad); | |||
} | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& mat, | |||
const TensorLayout& mat_idx, const TensorLayout& diff, | |||
const TensorLayout& grad) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& mat, | |||
const TensorLayout& mat_idx, const TensorLayout& diff, | |||
const TensorLayout& grad, size_t workspace_in_bytes); | |||
}; | |||
class DctChannelSelectForward : public OperatorBase { | |||
@@ -194,37 +183,32 @@ public: | |||
* \param[out] dst DctChannelSelectForward output, default fp32 nchw tensor | |||
* \param[out] workspace temporary workspace to perform forward | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mask_offset, | |||
_megdnn_tensor_in mask_val, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
const TensorLayout& dst) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in mask_offset, | |||
_megdnn_tensor_in mask_val, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& src, const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_layout_fwd(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
const TensorLayout& dst); | |||
void deduce_layout_fwd(const TensorLayout& src, | |||
const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, | |||
TensorLayout& dst); | |||
void check_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, const TensorLayout& dst); | |||
void deduce_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& mask_offset, | |||
const TensorLayout& mask_val, TensorLayout& dst); | |||
std::string param_msg() const; | |||
}; | |||
} // namespace megdnn | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
@@ -33,22 +33,22 @@ public: | |||
* op(A) = A if transposeA is false, otherwise op(A) = A^t. | |||
* op(B) = B if transposeB is false, otherwise op(B) = B^t. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
void deduce_dtype(DType A, DType B, DType &C); | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_dtype(DType A, DType B, DType& C); | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||
static Algorithm::OprType get_opr_type() { | |||
return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; | |||
} | |||
protected: | |||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
const TensorLayout& C, size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
size_t workspace_in_bytes); | |||
}; | |||
using BatchedMatrixMul = BatchedMatrixMulForward; | |||
@@ -70,24 +70,24 @@ public: | |||
* op(A) = A if transposeA is false, otherwise op(A) = A^t. | |||
* op(B) = B if transposeB is false, otherwise op(B) = B^t. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_dtype(DType A, DType B, DType& C); | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) = 0; | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||
static size_t pack_size (const Param::Format format); | |||
static size_t pack_size(const Param::Format format); | |||
static Algorithm::OprType get_opr_type() { | |||
return Algorithm::OprType::MATRIX_MUL_FORWARD; | |||
} | |||
protected: | |||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
const TensorLayout& C, size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
size_t workspace_in_bytes); | |||
}; | |||
using MatrixMul = MatrixMulForward; | |||
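The op(A)/op(B) convention in the comment above fixes the output shape; a small editorial sketch of the shape rule (deduce_layout performs the real checks):

#include <cstddef>

inline void matmul_out_shape_demo(size_t a0, size_t a1, bool transposeA,
                                  size_t b0, size_t b1, bool transposeB,
                                  size_t& m, size_t& n) {
    m = transposeA ? a1 : a0;  // rows of op(A)
    n = transposeB ? b0 : b1;  // cols of op(B)
}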
@@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase { | |||
DEF_OPR_PARAM(Empty); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst); | |||
protected: | |||
/*! | |||
@@ -116,8 +116,7 @@ protected: | |||
* | |||
* Note that \p batch and \p n can be null | |||
*/ | |||
static void canonize_params(const TensorLayout& layout, size_t* batch, | |||
size_t* n); | |||
static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n); | |||
/*! | |||
* \brief canonize and validate input params for exec() impls | |||
@@ -125,11 +124,12 @@ protected: | |||
* Since get_workspace_in_bytes() would be called, \p batch and \p n cannot | |||
* be null | |||
*/ | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
_megdnn_workspace workspace, size_t* batch, size_t* n); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
_megdnn_workspace workspace, size_t* batch, size_t* n); | |||
virtual size_t get_workspace_in_bytes(size_t batch, size_t n, | |||
size_t dtype_size) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
size_t batch, size_t n, size_t dtype_size) = 0; | |||
}; | |||
//! inner product of two vectors | |||
@@ -147,17 +147,17 @@ public: | |||
* A, B, C must be contiguous. A and B must have the same 1-dimensional | |||
* shape and non-negative strides. C must be scalar. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
const TensorLayout& C, size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
size_t workspace_in_bytes); | |||
}; | |||
using Dot = DotForward; | |||
@@ -193,23 +193,24 @@ public: | |||
* if compute_uv is false (defaults to true). | |||
* | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u, | |||
_megdnn_tensor_out s, _megdnn_tensor_out vt, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& u, | |||
TensorLayout& s, TensorLayout& vt); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& u, const TensorLayout& s, | |||
const TensorLayout& vt); | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s, | |||
_megdnn_tensor_out vt, _megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& src, TensorLayout& u, TensorLayout& s, | |||
TensorLayout& vt); | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, | |||
const TensorLayout& vt); | |||
protected: | |||
static void canonize_params(const TensorLayout& layout, size_t* batch, | |||
size_t* m, size_t* n); | |||
virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n, | |||
size_t dtype_size) = 0; | |||
void check_exec(const TensorLayout& src, const TensorLayout& u, | |||
const TensorLayout& s, const TensorLayout& vt, | |||
size_t workspace_in_bytes); | |||
static void canonize_params( | |||
const TensorLayout& layout, size_t* batch, size_t* m, size_t* n); | |||
virtual size_t get_workspace_in_bytes( | |||
size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0; | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, | |||
const TensorLayout& vt, size_t workspace_in_bytes); | |||
}; | |||
using SVD = SVDForward; | |||
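For orientation, SVDForward computes the standard factorization src = u * diag(s) * vt. An editorial sketch that reconstructs one element from row-major u (m x r), s (r), and vt (r x n), with r = min(m, n) assumed; the shapes are illustrative, not taken from this diff:

#include <cstddef>

inline float svd_reconstruct_elem_demo(const float* u, const float* s,
                                       const float* vt, size_t r, size_t n,
                                       size_t i, size_t j) {
    float acc = 0.f;
    for (size_t k = 0; k < r; ++k)
        acc += u[i * r + k] * s[k] * vt[k * n + j];  // (U diag(s) V^T)_{ij}
    return acc;
}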
@@ -36,7 +36,7 @@ public: | |||
struct ModeTrait { | |||
uint32_t arity = 0; //!< number of inputs needed | |||
CheckDtypeFunc check_inp[MAX_ARITY]; | |||
SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||
SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||
bool need_specify_out_dtype = | |||
false;  //!< the dtype should be set up externally; otherwise it | |||
//!< would be inferred by check_out(dtype, false) | |||
@@ -46,13 +46,10 @@ public: | |||
static const ModeTrait& from_mode(Mode mode); | |||
}; | |||
virtual void exec(_megdnn_in const TensorNDArray& src, | |||
_megdnn_tensor_out dst) = 0; | |||
virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0; | |||
//! get trait of current mode | |||
const ModeTrait& mode_trait() const { | |||
return ModeTrait::from_mode(m_param.mode); | |||
} | |||
const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); } | |||
//! deduce output layout | |||
void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); | |||
@@ -60,8 +57,8 @@ public: | |||
protected: | |||
//! throw an exception if the layout is incorrect; broadcast input shape | |||
//! to output shape | |||
void check_layout_and_broadcast(const TensorLayoutPtrArray& src, | |||
const TensorLayout& dst); | |||
void check_layout_and_broadcast( | |||
const TensorLayoutPtrArray& src, const TensorLayout& dst); | |||
}; | |||
} // namespace megdnn | |||
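
// Editorial sketch: the ModeTrait block above is a per-mode validation table
// -- an arity, one dtype check per input, and an output dtype that is either
// set externally or inferred. A self-contained illustration of how such a
// table drives checking; the function types are stand-ins, not the real
// CheckDtypeFunc/SetOrCheckDtypeFunc signatures.
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

struct TraitSketch {
    uint32_t arity = 0;                   // number of inputs needed
    std::function<bool(int)> check_inp;   // dtype predicate per input
    bool need_specify_out_dtype = false;  // true: caller must set out dtype
};

static void validate_inputs(const TraitSketch& t, const std::vector<int>& dtypes) {
    assert(dtypes.size() == t.arity);     // arity must match the mode
    for (int dt : dtypes)
        assert(t.check_inp(dt));          // every input dtype is checked
}
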
@@ -15,84 +15,97 @@ | |||
namespace megdnn { | |||
//! base class for random number generators | |||
class RNGBase: public OperatorBase { | |||
class RNGBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); | |||
public: | |||
virtual void exec(_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; | |||
protected: | |||
virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0; | |||
public: | |||
virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; | |||
protected: | |||
virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0; | |||
}; | |||
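
// Editorial sketch: every RNG subclass below repeats the same
// exec / get_workspace_in_bytes / check_exec triple. A hypothetical
// host-side stand-in for one of them (UniformRNG's (0, 1] interval), with
// the tensor replaced by a raw buffer and no workspace needed:
#include <cstddef>
#include <random>

struct UniformSketch {
    std::mt19937 gen{42};  // a real operator would seed from its param

    size_t get_workspace_in_bytes(size_t /*n*/) const { return 0; }

    void exec(float* dst, size_t n) {
        std::uniform_real_distribution<float> dist(0.f, 1.f);  // [0, 1)
        for (size_t i = 0; i < n; ++i)
            dst[i] = 1.f - dist(gen);  // flip [0, 1) onto (0, 1]
    }
};
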
//! sample from poisson distribution | |||
class PoissonRNG: public OperatorBase { | |||
class PoissonRNG : public OperatorBase { | |||
DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); | |||
DEF_OPR_PARAM(PoissonRNG); | |||
public: | |||
virtual void exec(_megdnn_tensor_in lam, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &lam, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &lam, const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in lam, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& lam, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& lam, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
//! sample from beta distribution | |||
class BetaRNG: public OperatorBase { | |||
class BetaRNG : public OperatorBase { | |||
DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); | |||
DEF_OPR_PARAM(BetaRNG); | |||
public: | |||
virtual void exec(_megdnn_tensor_in alpha, | |||
_megdnn_tensor_in beta, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &alpha, | |||
const TensorLayout &beta, const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &alpha, const TensorLayout &beta, | |||
const TensorLayout &dst, size_t workspace_in_bytes); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& alpha, const TensorLayout& beta, | |||
const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& alpha, const TensorLayout& beta, | |||
const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
//! sample from gamma distribution | |||
class GammaRNG: public OperatorBase { | |||
class GammaRNG : public OperatorBase { | |||
DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); | |||
DEF_OPR_PARAM(GammaRNG); | |||
public: | |||
virtual void exec(_megdnn_tensor_in shape, | |||
_megdnn_tensor_in scale, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &shape, | |||
const TensorLayout &scale, const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &shape, const TensorLayout &scale, | |||
const TensorLayout &dst, size_t workspace_in_bytes); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& shape, const TensorLayout& scale, | |||
const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& shape, const TensorLayout& scale, | |||
const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
//! sample from uniform distribution on the interval (0, 1] | |||
class UniformRNG: public RNGBase { | |||
class UniformRNG : public RNGBase { | |||
DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); | |||
DEF_OPR_PARAM(UniformRNG); | |||
protected: | |||
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
protected: | |||
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
//! sample from gaussian distribution | |||
class GaussianRNG: public RNGBase { | |||
class GaussianRNG : public RNGBase { | |||
DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); | |||
DEF_OPR_PARAM(GaussianRNG); | |||
protected: | |||
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
protected: | |||
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
class PermutationRNG: public RNGBase { | |||
class PermutationRNG : public RNGBase { | |||
DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); | |||
DEF_OPR_PARAM(PermutationRNG); | |||
protected: | |||
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
protected: | |||
void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
class ShuffleRNGForward : public OperatorBase { | |||
@@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase { | |||
DEF_OPR_PARAM(ShuffleRNG); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_tensor_out indices, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst, | |||
TensorLayout& indices); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst, | |||
const TensorLayout& indices) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& src, TensorLayout& dst, TensorLayout& indices); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
const TensorLayout& indices) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
const TensorLayout& indices, size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
const TensorLayout& indices, size_t workspace_in_bytes); | |||
}; | |||
using ShuffleRNG = ShuffleRNGForward; | |||
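
// Editorial sketch: ShuffleRNGForward emits both the shuffled data and the
// index tensor precisely so the backward opr can invert the permutation.
// A plain-array rendering, assuming an elementwise 1-D shuffle (the real
// operator permutes along the leading axis):
#include <cstddef>

static void shuffle_fwd(
        const float* src, float* dst, const size_t* indices, size_t n) {
    for (size_t i = 0; i < n; ++i)
        dst[i] = src[indices[i]];    // gather through the sampled permutation
}

static void shuffle_bwd(
        const float* diff, float* grad, const size_t* indices, size_t n) {
    for (size_t i = 0; i < n; ++i)
        grad[indices[i]] = diff[i];  // scatter: exact inverse of the gather
}
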
@@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase { | |||
DEF_OPR_PARAM(ShuffleRNG); | |||
public: | |||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
const TensorLayout& indices, | |||
const TensorLayout& grad) = 0; | |||
virtual void exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& diff, const TensorLayout& indices, | |||
const TensorLayout& grad) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& diff, const TensorLayout& indices, | |||
const TensorLayout& grad, size_t workspace_in_bytes); | |||
void check_exec( | |||
const TensorLayout& diff, const TensorLayout& indices, | |||
const TensorLayout& grad, size_t workspace_in_bytes); | |||
}; | |||
/*! | |||
 * \brief sleep for a specific time on the computing device; useful for testing | |||
* async problems | |||
*/ | |||
class SleepForward: public OperatorBase { | |||
class SleepForward : public OperatorBase { | |||
DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); | |||
DEF_OPR_PARAM(Sleep); | |||
public: | |||
virtual void exec() = 0; | |||
public: | |||
virtual void exec() = 0; | |||
}; | |||
using Sleep = SleepForward; | |||
@@ -149,20 +165,19 @@ using Sleep = SleepForward; | |||
* | |||
* data must be a one-dimensional contiguous tensor with dtype byte | |||
*/ | |||
class ChecksumForward: public OperatorBase { | |||
class ChecksumForward : public OperatorBase { | |||
DEF_OPR_PARAM(Empty); | |||
DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); | |||
public: | |||
using Result = opr_result::Checksum; | |||
public: | |||
using Result = opr_result::Checksum; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0; | |||
virtual Result exec(_megdnn_tensor_in data, | |||
_megdnn_workspace workspace) = 0; | |||
virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &layout, size_t workspace_in_bytes); | |||
protected: | |||
void check_exec(const TensorLayout& layout, size_t workspace_in_bytes); | |||
}; | |||
using Checksum = ChecksumForward; | |||
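
// Editorial sketch: what ChecksumForward packs into opr_result::Checksum is
// backend-defined. Purely as an illustration of "checksum over a contiguous
// byte tensor", a simple additive stand-in:
#include <cstddef>
#include <cstdint>

static uint32_t checksum_bytes(const uint8_t* data, size_t n) {
    uint32_t acc = 0;
    for (size_t i = 0; i < n; ++i)
        acc += data[i];  // additive checksum; real kernels may differ
    return acc;
}
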
@@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase { | |||
DEF_OPR_PARAM(Empty); | |||
DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); | |||
public: | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& layout1, | |||
const TensorLayout& layout2) = 0; | |||
public: | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& layout1, const TensorLayout& layout2) = 0; | |||
virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||
_megdnn_workspace workspace) = 0; | |||
virtual float exec( | |||
_megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||
_megdnn_workspace workspace) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& layout1, | |||
const TensorLayout& layout2, size_t workspace_in_bytes); | |||
protected: | |||
void check_exec( | |||
const TensorLayout& layout1, const TensorLayout& layout2, | |||
size_t workspace_in_bytes); | |||
}; | |||
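
// Editorial sketch: MaxTensorDiff::exec reduces two tensors to one float.
// Assuming the natural reading of the name (the implementation may instead
// use a relative error), the reference would be the maximum absolute
// elementwise difference:
#include <algorithm>
#include <cmath>
#include <cstddef>

static float max_tensor_diff(const float* a, const float* b, size_t n) {
    float worst = 0.f;
    for (size_t i = 0; i < n; ++i)
        worst = std::max(worst, std::fabs(a[i] - b[i]));
    return worst;
}
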
bool check_bias_share_in_channel(const TensorLayout& bias, | |||
const param::ConvBias::Format format); | |||
bool check_bias_share_in_channel( | |||
const TensorLayout& bias, const param::ConvBias::Format format); | |||
} // namespace megdnn | |||
@@ -18,9 +18,9 @@ | |||
namespace megdnn { | |||
enum class TensorFormat::Type { | |||
DEFAULT = 0, //!< see DefaultTensorFormat | |||
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||
LOWBITS_ALIGNED_TO_BYTE = 2, //!< | |||
DEFAULT = 0, //!< see DefaultTensorFormat | |||
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||
LOWBITS_ALIGNED_TO_BYTE = 2, //!< | |||
}; | |||
class TensorFormat::ImplBase { | |||
@@ -33,8 +33,7 @@ public: | |||
virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | |||
virtual TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const = 0; | |||
virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0; | |||
virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; | |||
@@ -79,8 +78,7 @@ public: | |||
*/ | |||
bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | |||
@@ -88,8 +86,7 @@ public: | |||
void serialize_append(std::string& result) const override; | |||
static TensorFormat make(); | |||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||
size_t size); | |||
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||
}; | |||
namespace detail { | |||
@@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase { | |||
size_t m_align_axis, m_align_size_in_elements_log2; | |||
protected: | |||
Image2DTensorFormatBase(Type type, size_t align_axis, | |||
size_t align_size_in_elements); | |||
Image2DTensorFormatBase( | |||
Type type, size_t align_axis, size_t align_size_in_elements); | |||
virtual ~Image2DTensorFormatBase() = default; | |||
public: | |||
@@ -129,9 +126,7 @@ public: | |||
size_t align_axis() const { return m_align_axis; } | |||
size_t align_size_in_elements_log2() const { | |||
return m_align_size_in_elements_log2; | |||
} | |||
size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; } | |||
std::string to_string() const override; | |||
@@ -145,6 +140,7 @@ public: | |||
size_t image_height(const TensorLayout& layout) const; | |||
void serialize_append(std::string& result) const override; | |||
protected: | |||
struct SerializePack { | |||
uint8_t align_axis; | |||
@@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase { | |||
 * align COUNT, but mdl needs the align size in bytes, which equals | |||
 * (image_width align count) * sizeof(data_type) * pixel_size | |||
*/ | |||
size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements, | |||
const TensorLayout& layout) const; | |||
size_t image_pitch_alignment_in_bytes( | |||
size_t align_size_in_elements, const TensorLayout& layout) const; | |||
protected: | |||
Image2DPackedTensorFormatBase(Type type, size_t align_axis, | |||
size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type) | |||
: detail::Image2DTensorFormatBase(type, align_axis, | |||
align_size_in_elements), | |||
Image2DPackedTensorFormatBase( | |||
Type type, size_t align_axis, size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type) | |||
: detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements), | |||
m_vendor_type(vendor_type) {} | |||
virtual ~Image2DPackedTensorFormatBase() = default; | |||
@@ -197,13 +192,12 @@ public: | |||
bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||
}; | |||
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | |||
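
// Editorial sketch: the comment above image_pitch_alignment_in_bytes() gives
// the conversion outright -- an alignment kept as an element COUNT becomes a
// byte alignment once scaled by the element size and the pixel width. As a
// formula in code (pixel_size would be 4 for the PACK4 formats):
#include <cstddef>

static size_t pitch_alignment_in_bytes(
        size_t align_in_elements, size_t sizeof_dtype, size_t pixel_size) {
    // (image_width align count) * sizeof(data_type) * pixel_size
    return align_in_elements * sizeof_dtype * pixel_size;
}
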
/*! | |||
* \brief used for tensors storing lowbit data | |||
* \brief used for tensors storing lowbit data | |||
* | |||
* \param m_size_nbits size in bits of elements in the tensor | |||
* \param m_align_size_in_bits aligned size in bits | |||
@@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase { | |||
size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; | |||
protected: //? | |||
LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits, | |||
size_t align_size_in_bits); | |||
LowbitsAlignedTensorFormatBase( | |||
Type type, size_t size_nbits, size_t align_size_in_bits); | |||
virtual ~LowbitsAlignedTensorFormatBase() = default; | |||
public: | |||
size_t align_size_in_bits() const { return m_align_size_in_bits; } | |||
size_t size_nbits() const { return m_size_nbits; } | |||
std::string to_string() const override; | |||
@@ -238,8 +232,8 @@ public: | |||
bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||
protected: | |||
struct SerializePack { | |||
uint8_t size_nbits; | |||
@@ -254,16 +248,14 @@ protected: | |||
* | |||
* This is used for OpenCL. | |||
*/ | |||
class Image2DPack4TensorFormat final | |||
: public detail::Image2DPack4TensorFormatBase { | |||
class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase { | |||
public: | |||
static constexpr Type TYPE = Type::IMAGE2D_PACK4; | |||
//! for internal usage or test purposes | |||
static TensorFormat make_raw(size_t align_axis, | |||
size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type = | |||
Handle::HandleVendorType::NOT_SPEC); | |||
static TensorFormat make_raw( | |||
size_t align_axis, size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC); | |||
static TensorFormat make(size_t align_axis, const Handle* handle); | |||
@@ -273,13 +265,11 @@ public: | |||
* Note that the alignment may be different if deserialized on another | |||
* handle | |||
*/ | |||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||
size_t size); | |||
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||
static bool is_valid_image(const TensorLayout& layout) { | |||
if (layout.format.type() == TYPE) { | |||
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid( | |||
layout); | |||
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout); | |||
return true; | |||
} | |||
return false; | |||
@@ -288,8 +278,9 @@ public: | |||
TensorFormat change_axis(size_t axis) const override; | |||
private: | |||
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type) | |||
Image2DPack4TensorFormat( | |||
size_t align_axis, size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type) | |||
: detail::Image2DPack4TensorFormatBase( | |||
TYPE, align_axis, align_size_in_elements, vendor_type) {} | |||
}; | |||
@@ -306,13 +297,12 @@ public: | |||
static TensorFormat make(size_t size_nbits); | |||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||
size_t size); | |||
static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||
static bool is_valid_layout(const TensorLayout& layout) { | |||
if (layout.format.type() == TYPE) { | |||
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>() | |||
.assert_valid(layout); | |||
layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid( | |||
layout); | |||
return true; | |||
} | |||
return false; | |||
@@ -320,8 +310,7 @@ public: | |||
private: | |||
LowbitsAlignedToBytesTensorFormat(size_t size_nbits) | |||
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, | |||
BYTE_IN_BITS) {} | |||
: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {} | |||
}; | |||
} // namespace megdnn | |||
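
// Editorial sketch: LowbitsAlignedToBytesTensorFormat pads sub-byte rows so
// each starts on a byte boundary. Assuming the usual round-up rule, the byte
// footprint of one row of n elements of size_nbits bits each is:
#include <cstddef>

static size_t lowbits_row_bytes(size_t n, size_t size_nbits) {
    const size_t BYTE_IN_BITS = 8;
    return (n * size_nbits + BYTE_IN_BITS - 1) / BYTE_IN_BITS;  // ceil-div
}
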
@@ -167,13 +167,11 @@ public: | |||
TensorIter(const TensorND& tensor) : m_tensor(tensor) {} | |||
Iter begin() const { | |||
return Iter::make(const_cast<TensorND&>(m_tensor), 0); | |||
} | |||
Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); } | |||
Iter end() const { | |||
return Iter::make(const_cast<TensorND&>(m_tensor), | |||
m_tensor.layout.total_nr_elems()); | |||
return Iter::make( | |||
const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems()); | |||
} | |||
}; | |||
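
// Editorial sketch: because TensorIter exposes only begin()/end() built from
// the same underlying tensor, it drops straight into a range-for. A
// self-contained analogue of the pattern, with a plain array standing in for
// TensorND:
#include <cstddef>

struct SpanSketch {
    float* data;
    size_t n;
    float* begin() const { return data; }
    float* end() const { return data + n; }  // plays total_nr_elems()'s role
};

static float sum_all(SpanSketch s) {
    float acc = 0.f;
    for (float v : s)  // range-for works off begin()/end() alone
        acc += v;
    return acc;
}
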
/*! | |||
@@ -11,19 +11,19 @@ | |||
#pragma once | |||
#include <type_traits> | |||
#include <cstdlib> | |||
#include <functional> | |||
#include <utility> | |||
#include <memory> | |||
#include <cstdlib> | |||
#include <type_traits> | |||
#include <utility> | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
template<typename Signature> | |||
template <typename Signature> | |||
using thin_function = ::std::function<Signature>; | |||
} // namespace megdnn | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
@@ -58,18 +58,16 @@ protected: | |||
m_end_ptr(first_elm), | |||
m_capacity_ptr(static_cast<char*>(first_elm) + size) {} | |||
void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, | |||
size_t type_size); | |||
void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size); | |||
public: | |||
size_t size_in_bytes() const { | |||
return size_t(static_cast<char*>(m_end_ptr) - | |||
static_cast<char*>(m_begin_ptr)); | |||
return size_t(static_cast<char*>(m_end_ptr) - static_cast<char*>(m_begin_ptr)); | |||
} | |||
size_t capacity_in_bytes() const { | |||
return size_t(static_cast<char*>(m_capacity_ptr) - | |||
static_cast<char*>(m_begin_ptr)); | |||
return size_t( | |||
static_cast<char*>(m_capacity_ptr) - static_cast<char*>(m_begin_ptr)); | |||
} | |||
bool empty() const { return m_begin_ptr == m_end_ptr; } | |||
@@ -85,20 +83,15 @@ private: | |||
U m_first_elm; | |||
protected: | |||
SmallVectorTemplateCommon(size_t size) | |||
: SmallVectorBase(&m_first_elm, size) {} | |||
SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {} | |||
void grow_pod(size_t min_sz_in_bytes, size_t type_size) { | |||
SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); | |||
} | |||
bool is_small() { | |||
return m_begin_ptr == static_cast<const void*>(&m_first_elm); | |||
} | |||
bool is_small() { return m_begin_ptr == static_cast<const void*>(&m_first_elm); } | |||
void reset_to_small() { | |||
m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; | |||
} | |||
void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; } | |||
void set_end(T* p) { m_end_ptr = p; } | |||
@@ -128,20 +121,12 @@ protected: | |||
public: | |||
// forwarding iterator creation | |||
iterator begin() { return static_cast<iterator>(m_begin_ptr); } | |||
const_iterator begin() const { | |||
return static_cast<const_iterator>(m_begin_ptr); | |||
} | |||
const_iterator cbegin() const { | |||
return static_cast<const_iterator>(m_begin_ptr); | |||
} | |||
const_iterator begin() const { return static_cast<const_iterator>(m_begin_ptr); } | |||
const_iterator cbegin() const { return static_cast<const_iterator>(m_begin_ptr); } | |||
iterator end() { return static_cast<iterator>(m_end_ptr); } | |||
const_iterator end() const { | |||
return static_cast<const_iterator>(m_end_ptr); | |||
} | |||
const_iterator cend() const { | |||
return static_cast<const_iterator>(m_end_ptr); | |||
} | |||
const_iterator end() const { return static_cast<const_iterator>(m_end_ptr); } | |||
const_iterator cend() const { return static_cast<const_iterator>(m_end_ptr); } | |||
reference at(size_type idx) { | |||
if (idx >= size()) { | |||
@@ -167,13 +152,9 @@ public: | |||
// reverse iterator creation method. | |||
reverse_iterator rbegin() { return reverse_iterator(end()); } | |||
const_reverse_iterator rbegin() const { | |||
return const_reverse_iterator(end()); | |||
} | |||
const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); } | |||
reverse_iterator rend() { return reverse_iterator(begin()); } | |||
const_reverse_iterator rend() const { | |||
return const_reverse_iterator(begin()); | |||
} | |||
const_reverse_iterator rend() const { return const_reverse_iterator(begin()); } | |||
pointer data() { return pointer(begin()); } | |||
const_pointer data() const { return const_pointer(begin()); } | |||
@@ -207,8 +188,8 @@ protected: | |||
template <typename It1, typename It2> | |||
static void uninitialized_move(It1 first, It1 last, It2 dest) { | |||
std::uninitialized_copy(std::make_move_iterator(first), | |||
std::make_move_iterator(last), dest); | |||
std::uninitialized_copy( | |||
std::make_move_iterator(first), std::make_move_iterator(last), dest); | |||
} | |||
template <typename It1, typename It2> | |||
@@ -293,9 +274,7 @@ protected: | |||
memcpy(dest, first, (last - first) * sizeof(T)); | |||
} | |||
void grow(size_t min_sz = 0) { | |||
this->grow_pod(min_sz * sizeof(T), sizeof(T)); | |||
} | |||
void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); } | |||
public: | |||
void push_back(const T& _elm) { | |||
@@ -318,8 +297,7 @@ public: | |||
* SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N | |||
*/ | |||
template <typename T> | |||
class SmallVectorImpl | |||
: public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||
class SmallVectorImpl : public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||
using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; | |||
public: | |||
@@ -329,8 +307,7 @@ public: | |||
protected: | |||
explicit SmallVectorImpl(unsigned n) | |||
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) { | |||
} | |||
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {} | |||
public: | |||
SmallVectorImpl(const SmallVectorImpl&) = delete; | |||
@@ -354,8 +331,7 @@ public: | |||
} else if (n > this->size()) { | |||
if (this->capacity() < n) | |||
this->grow(n); | |||
for (auto it = this->end(), end = this->begin() + n; it != end; | |||
++it) | |||
for (auto it = this->end(), end = this->begin() + n; it != end; ++it) | |||
new (&*it) T(); | |||
this->set_end(this->begin() + n); | |||
} | |||
@@ -389,10 +365,11 @@ public: | |||
void swap(SmallVectorImpl<T>& rhs); | |||
/// Add the specified range to the end of the SmallVector. | |||
template <typename in_iter, | |||
typename = typename std::enable_if<std::is_convertible< | |||
typename std::iterator_traits<in_iter>::iterator_category, | |||
std::input_iterator_tag>::value>::type> | |||
template < | |||
typename in_iter, | |||
typename = typename std::enable_if<std::is_convertible< | |||
typename std::iterator_traits<in_iter>::iterator_category, | |||
std::input_iterator_tag>::value>::type> | |||
void append(in_iter in_start, in_iter in_end) { | |||
size_type num_inputs = std::distance(in_start, in_end); | |||
// Grow allocated space if needed. | |||
@@ -432,10 +409,11 @@ public: | |||
std::uninitialized_fill(this->begin(), this->end(), elm); | |||
} | |||
template <typename in_iter, | |||
typename = typename std::enable_if<std::is_convertible< | |||
typename std::iterator_traits<in_iter>::iterator_category, | |||
std::input_iterator_tag>::value>::type> | |||
template < | |||
typename in_iter, | |||
typename = typename std::enable_if<std::is_convertible< | |||
typename std::iterator_traits<in_iter>::iterator_category, | |||
std::input_iterator_tag>::value>::type> | |||
void assign(in_iter in_start, in_iter in_end) { | |||
clear(); | |||
append(in_start, in_end); | |||
@@ -571,8 +549,7 @@ public: | |||
std::fill_n(it, num_overwritten, elm); | |||
// Insert the non-overwritten middle part. | |||
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, | |||
elm); | |||
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm); | |||
return it; | |||
} | |||
@@ -646,8 +623,7 @@ public: | |||
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | |||
this->grow(); | |||
} | |||
new (static_cast<void*>(this->end())) | |||
T(std::forward<ArgTypes>(args)...); | |||
new (static_cast<void*>(this->end())) T(std::forward<ArgTypes>(args)...); | |||
this->set_end(this->end() + 1); | |||
} | |||
@@ -661,13 +637,11 @@ public: | |||
return std::equal(this->begin(), this->end(), rhs.begin()); | |||
} | |||
bool operator!=(const SmallVectorImpl<T>& rhs) const { | |||
return !(*this == rhs); | |||
} | |||
bool operator!=(const SmallVectorImpl<T>& rhs) const { return !(*this == rhs); } | |||
bool operator<(const SmallVectorImpl<T>& rhs) const { | |||
return std::lexicographical_compare(this->begin(), this->end(), | |||
rhs.begin(), rhs.end()); | |||
return std::lexicographical_compare( | |||
this->begin(), this->end(), rhs.begin(), rhs.end()); | |||
} | |||
}; | |||
@@ -698,15 +672,13 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||
// Copy over the extra elms. | |||
if (this->size() > rhs.size()) { | |||
size_t elm_diff = this->size() - rhs.size(); | |||
this->uninitialized_move(this->begin() + num_shared, this->end(), | |||
rhs.end()); | |||
this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end()); | |||
rhs.set_end(rhs.end() + elm_diff); | |||
this->destroy_range(this->begin() + num_shared, this->end()); | |||
this->set_end(this->begin() + num_shared); | |||
} else if (rhs.size() > this->size()) { | |||
size_t elm_diff = rhs.size() - this->size(); | |||
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), | |||
this->end()); | |||
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end()); | |||
this->set_end(this->end() + elm_diff); | |||
this->destroy_range(rhs.begin() + num_shared, rhs.end()); | |||
rhs.set_end(rhs.begin() + num_shared); | |||
@@ -714,8 +686,7 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||
} | |||
template <typename T> | |||
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||
const SmallVectorImpl<T>& rhs) { | |||
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(const SmallVectorImpl<T>& rhs) { | |||
if (this == &rhs) | |||
return *this; | |||
size_t rhs_sz = rhs.size(); | |||
@@ -740,8 +711,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||
} else if (cur_sz) { | |||
std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | |||
} | |||
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), | |||
this->begin() + cur_sz); | |||
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); | |||
this->set_end(this->begin() + rhs_sz); | |||
return *this; | |||
} | |||
@@ -785,8 +755,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) { | |||
std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | |||
} | |||
this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), | |||
this->begin() + cur_sz); | |||
this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); | |||
this->set_end(this->begin() + rhs_sz); | |||
@@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl<T> { | |||
public: | |||
SmallVector() : SmallVectorImpl<T>(N) {} | |||
explicit SmallVector(size_t size, const T& value = T()) | |||
: SmallVectorImpl<T>(N) { | |||
explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl<T>(N) { | |||
this->assign(size, value); | |||
} | |||
@@ -901,15 +869,13 @@ namespace std { | |||
/// Implement std::swap in terms of SmallVector swap. | |||
template <typename T> | |||
inline void swap(megdnn::SmallVectorImpl<T>& lhs, | |||
megdnn::SmallVectorImpl<T>& rhs) { | |||
inline void swap(megdnn::SmallVectorImpl<T>& lhs, megdnn::SmallVectorImpl<T>& rhs) { | |||
lhs.swap(rhs); | |||
} | |||
/// Implement std::swap in terms of SmallVector swap. | |||
template <typename T, unsigned N> | |||
inline void swap(megdnn::SmallVector<T, N>& lhs, | |||
megdnn::SmallVector<T, N>& rhs) { | |||
inline void swap(megdnn::SmallVector<T, N>& lhs, megdnn::SmallVector<T, N>& rhs) { | |||
lhs.swap(rhs); | |||
} | |||
} // end namespace std | |||
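
// Editorial sketch: the is_small()/grow_pod machinery above is the classic
// inline-buffer optimization -- the first N elements live inside the object
// and the heap is touched only on overflow. A usage illustration; the
// include path is an assumption and may differ in-tree.
#include "megdnn/thin/small_vector.h"

static void small_vector_demo() {
    megdnn::SmallVector<int, 4> v;  // inline storage for 4 ints
    for (int i = 0; i < 4; ++i)
        v.push_back(i);             // no heap allocation yet
    v.push_back(4);                 // exceeds N: grow() relocates to the heap
}
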
@@ -17,13 +17,13 @@ | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
struct Version { | |||
int major, minor, patch; | |||
}; | |||
struct Version { | |||
int major, minor, patch; | |||
}; | |||
//! get megdnn version of the binary | |||
Version get_version(); | |||
} | |||
//! get megdnn version of the binary | |||
Version get_version(); | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
@@ -22,18 +22,17 @@ using namespace aarch64; | |||
/* ===================== stride-2 algo ===================== */ | |||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | |||
bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
return param.filter_meta.format == param::Convolution::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float16 && | |||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && | |||
param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
} | |||
MIDOUT_END(); | |||
@@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { | |||
return get_kimpls(param); | |||
@@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||
return {}; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
@@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | |||
size_t, size_t, size_t, size_t, size_t)>; | |||
using Func = std::function<void( | |||
const __fp16*, const __fp16*, __fp16*, size_t, size_t, size_t, size_t, | |||
size_t)>; | |||
Func conv = nullptr; | |||
if (FH == 2) { | |||
conv = fp16::conv_stride2::do_conv_2x2_stride2; | |||
@@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
bundle.set(kern_param.workspace_ptr); | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
copy_padding_kern_stride( | |||
bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||
{ncb_index.thread_id, 0, oc}); | |||
do_conv_kern_stride( | |||
bundle, kern_param, ncb_index, conv, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
auto copy_padding = [bundle]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
bundle.set(kern_param.workspace_ptr); | |||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id); | |||
copy_padding_kern_stride( | |||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
auto do_conv = [bundle, conv]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
bundle.set(kern_param.workspace_ptr); | |||
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||
ncb_index.ndrange_id); | |||
do_conv_kern_stride( | |||
bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
@@ -18,13 +18,13 @@ namespace aarch64 { | |||
/* ===================== stride-2 algo ===================== */ | |||
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "ARMV8F16STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
bool usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
@@ -20,9 +20,9 @@ namespace aarch64 { | |||
namespace fp16 { | |||
namespace conv_stride2 { | |||
static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||
__fp16* dst, size_t IH, size_t IW, size_t OH, | |||
size_t OW, size_t IC) { | |||
static void do_conv_2x2_stride2( | |||
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
size_t width = OW >> 3; | |||
size_t mod4_left = width & 3; | |||
@@ -162,10 +162,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||
"5: \n" | |||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | |||
: "r"(mod4_left), "w"(_k0123) | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", | |||
"v31"); | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v19", "v28", "v29", "v30", "v31"); | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
@@ -175,9 +174,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||
} | |||
} | |||
static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||
__fp16* dst, size_t IH, size_t IW, size_t OH, | |||
size_t OW, size_t IC) { | |||
static void do_conv_3x3_stride2( | |||
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
size_t width = OW >> 3; | |||
size_t mod3_left = width % 3; | |||
@@ -352,10 +351,10 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||
"3: \n" | |||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | |||
: "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", | |||
"v25", "v26", "v27", "v28", "v29"); | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29"); | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
@@ -366,9 +365,9 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||
} | |||
} | |||
static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||
__fp16* dst, size_t IH, size_t IW, size_t OH, | |||
size_t OW, size_t IC) { | |||
static void do_conv_5x5_stride2( | |||
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
size_t width = OW >> 3; | |||
size_t mod2_left = width & 1; | |||
@@ -384,18 +383,12 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||
const __fp16* r4 = src_ptr + IW * 4; | |||
register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | |||
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = | |||
MEGDNN_SIMD_LOADU(filter + 4); | |||
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = | |||
MEGDNN_SIMD_LOADU(filter + 8); | |||
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = | |||
MEGDNN_SIMD_LOADU(filter + 12); | |||
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = | |||
MEGDNN_SIMD_LOADU(filter + 16); | |||
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = | |||
MEGDNN_SIMD_LOADU(filter + 20); | |||
register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = | |||
MEGDNN_SIMD_SET1(filter[24]); | |||
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); | |||
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); | |||
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); | |||
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); | |||
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); | |||
register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = MEGDNN_SIMD_SET1(filter[24]); | |||
for (size_t i = 0; i < OH; i++) { | |||
asm volatile( | |||
@@ -592,15 +585,14 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||
"bne 2b \n" | |||
"3: \n" | |||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), | |||
"+r"(r3), "+r"(r4) | |||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||
"+r"(r4) | |||
: "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | |||
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), | |||
"r"(mod2_left) | |||
: "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31"); | |||
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) | |||
: "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31"); | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
@@ -613,9 +605,9 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||
} | |||
} | |||
static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||
__fp16* dst, size_t IH, size_t IW, size_t OH, | |||
size_t OW, size_t IC) { | |||
static void do_conv_7x7_stride2( | |||
const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
size_t width = OW >> 3; | |||
@@ -632,30 +624,20 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||
const __fp16* r6 = src_ptr + IW * 6; | |||
register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | |||
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = | |||
MEGDNN_SIMD_LOADU(filter + 4); | |||
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = | |||
MEGDNN_SIMD_LOADU(filter + 8); | |||
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = | |||
MEGDNN_SIMD_LOADU(filter + 12); | |||
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = | |||
MEGDNN_SIMD_LOADU(filter + 16); | |||
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = | |||
MEGDNN_SIMD_LOADU(filter + 20); | |||
register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = | |||
MEGDNN_SIMD_LOADU(filter + 24); | |||
register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = | |||
MEGDNN_SIMD_LOADU(filter + 28); | |||
register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = | |||
MEGDNN_SIMD_LOADU(filter + 32); | |||
register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = | |||
MEGDNN_SIMD_LOADU(filter + 36); | |||
register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); | |||
register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); | |||
register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); | |||
register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); | |||
register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); | |||
register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = MEGDNN_SIMD_LOADU(filter + 24); | |||
register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = MEGDNN_SIMD_LOADU(filter + 28); | |||
register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = MEGDNN_SIMD_LOADU(filter + 32); | |||
register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = MEGDNN_SIMD_LOADU(filter + 36); | |||
register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = | |||
MEGDNN_SIMD_LOADU(filter + 40); | |||
register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = | |||
MEGDNN_SIMD_LOADU(filter + 44); | |||
register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = | |||
MEGDNN_SIMD_SET1(filter[48]); | |||
register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = MEGDNN_SIMD_SET1(filter[48]); | |||
for (size_t i = 0; i < OH; i++) { | |||
asm volatile( | |||
@@ -1005,16 +987,15 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||
"bne 2b \n" | |||
"3: \n" | |||
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||
"+r"(r4), "+r"(r5), "+r"(r6) | |||
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), | |||
"+r"(r5), "+r"(r6) | |||
: "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | |||
"w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | |||
"w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | |||
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), | |||
"w"(_k48484848) | |||
: "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", | |||
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||
"v26", "v27", "v28", "v29", "v30", "v31"); | |||
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) | |||
: "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31"); | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
@@ -21,18 +21,17 @@ using namespace megdnn; | |||
using namespace aarch64; | |||
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.src_type.enumv() == DTypeEnum::Float32 && | |||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && | |||
!fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
FH == fm.spatial[1] && | |||
param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
(FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
} | |||
MIDOUT_END(); | |||
@@ -50,8 +49,7 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | |||
return get_kimpls(param); | |||
@@ -60,8 +58,7 @@ ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
return {}; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
const NCBKernSizeParam& param) const { | |||
auto fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
@@ -70,8 +67,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
size_t OC = param.filter_meta.ocpg; | |||
size_t group = fm.group; | |||
bool large_group = group >= param.nr_threads; | |||
using Func = std::function<void(const float*, const float*, float*, size_t, | |||
size_t, size_t, size_t, size_t)>; | |||
using Func = std::function<void( | |||
const float*, const float*, float*, size_t, size_t, size_t, size_t, | |||
size_t)>; | |||
Func conv = nullptr; | |||
if (FH == 2) { | |||
conv = fp32::conv_stride2::do_conv_2x2_stride2; | |||
@@ -83,8 +81,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
conv = fp32::conv_stride2::do_conv_7x7_stride2; | |||
} | |||
WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | |||
float, float>::get_bundle_stride(param, large_group); | |||
WorkspaceBundle bundle = | |||
arm_common::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
param, large_group); | |||
SmallVector<NCBKern> ret_kerns; | |||
//! Dense conv and small group | |||
@@ -99,34 +98,34 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
bundle.set(kern_param.workspace_ptr); | |||
for (size_t ic = 0; ic < IC; ic++) { | |||
arm_common::MultithreadDirectConvCommon<float, float>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
copy_padding_kern_stride( | |||
bundle, kern_param, ncb_index, | |||
{ncb_index.thread_id, 0, ic}); | |||
} | |||
for (size_t oc = 0; oc < OC; oc++) { | |||
arm_common::MultithreadDirectConvCommon< | |||
float, float>::do_conv_kern_stride(bundle, kern_param, | |||
ncb_index, conv, | |||
{ncb_index.thread_id, | |||
0, oc}); | |||
arm_common::MultithreadDirectConvCommon<float, float>:: | |||
do_conv_kern_stride( | |||
bundle, kern_param, ncb_index, conv, | |||
{ncb_index.thread_id, 0, oc}); | |||
} | |||
}; | |||
ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
} else { | |||
auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
auto copy_padding = [bundle]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
bundle.set(kern_param.workspace_ptr); | |||
arm_common::MultithreadDirectConvCommon<float, float>:: | |||
copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
ncb_index.ndrange_id); | |||
copy_padding_kern_stride( | |||
bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
auto do_conv = [bundle, conv]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
bundle.set(kern_param.workspace_ptr); | |||
arm_common::MultithreadDirectConvCommon< | |||
float, float>::do_conv_kern_stride(bundle, kern_param, | |||
ncb_index, conv, | |||
ncb_index.ndrange_id); | |||
arm_common::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride( | |||
bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); | |||
}; | |||
ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
} | |||
@@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "ARMV8F32STRD2"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
bool usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
@@ -16,16 +16,15 @@ | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace fp32{ | |||
namespace fp32 { | |||
namespace conv_stride2 { | |||
//! For the detailed tuning process, refer to `expr/conv_aarch64_stride2/main.cpp` | |||
// refer to function do_conv_2x2_stride2_asm_unroll4 | |||
static void do_conv_2x2_stride2(const float* src, const float* filter, | |||
float* dst, size_t IH, size_t IW, size_t OH, | |||
size_t OW, size_t IC) { | |||
static void do_conv_2x2_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
size_t width = OW >> 2; | |||
size_t mod4_left = width & 3; | |||
@@ -165,10 +164,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, | |||
"5: \n" | |||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | |||
: "r"(mod4_left), "w"(_k0123) | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", | |||
"v31"); | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v19", "v28", "v29", "v30", "v31"); | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
@@ -179,9 +177,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, | |||
} | |||
// refer to function do_conv_3x3_stride2_asm_unroll3 | |||
static void do_conv_3x3_stride2(const float* src, const float* filter, | |||
float* dst, size_t IH, size_t IW, size_t OH, | |||
size_t OW, size_t IC) { | |||
static void do_conv_3x3_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
size_t width = OW >> 2; | |||
size_t mod3_left = width % 3; | |||
@@ -269,7 +267,7 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||
"ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 | |||
"ld2 {v5.4s, v6.4s}, [%3], #32 \n" | |||
"ld1 {v3.4s}, [%2] \n" // load src 8 12 ... | |||
"ld1 {v3.4s}, [%2] \n" // load src 8 12 ... | |||
"fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] | |||
"ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 | |||
"fmla v0.4s, v2.4s, v22.4s \n" | |||
@@ -356,10 +354,10 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||
"3: \n" | |||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | |||
: "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", | |||
"v25", "v26", "v27", "v28", "v29"); | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29"); | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
@@ -371,9 +369,9 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||
} | |||
// refer to function do_conv_5x5_stride2_asm_unroll2 | |||
static void do_conv_5x5_stride2(const float* src, const float* filter, | |||
float* dst, size_t IH, size_t IW, size_t OH, | |||
size_t OW, size_t IC) { | |||
static void do_conv_5x5_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
size_t width = OW >> 2; | |||
size_t mod2_left = width & 1; | |||
@@ -591,15 +589,13 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, | |||
"bne 2b \n" | |||
"3: \n" | |||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), | |||
"+r"(r3), "+r"(r4) | |||
: "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||
"+r"(r4) | |||
: "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | |||
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), | |||
"r"(mod2_left) | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
"v23", "v24"); | |||
"w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24"); | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
@@ -613,9 +609,9 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, | |||
} | |||
// refer to function do_conv_7x7_stride2_asm_unroll2 | |||
static void do_conv_7x7_stride2(const float* src, const float* filter, | |||
float* dst, size_t IH, size_t IW, size_t OH, | |||
size_t OW, size_t IC) { | |||
static void do_conv_7x7_stride2( | |||
const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
size_t OH, size_t OW, size_t IC) { | |||
const size_t tail_step = IW - 2 * OW + IW; | |||
size_t width = OW >> 2; | |||
@@ -993,16 +989,15 @@ static void do_conv_7x7_stride2(const float* src, const float* filter, | |||
"bne 2b \n" | |||
"3: \n" | |||
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||
"+r"(r4), "+r"(r5), "+r"(r6) | |||
: "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), | |||
"+r"(r5), "+r"(r6) | |||
: "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | |||
"w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | |||
"w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | |||
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), | |||
"w"(_k48484848) | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18"); | |||
"w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) | |||
: "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18"); | |||
r0 += tail_step; | |||
r1 += tail_step; | |||
@@ -68,9 +68,9 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||
size_t N = OH * OW; | |||
#if MGB_ENABLE_DOT | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
#define DISPATCH_GEMM_STRATEGY( \ | |||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
part2 = megdnn::matmul::GemmInterleaved< \ | |||
@@ -84,11 +84,12 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||
} | |||
#else | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
#define DISPATCH_GEMM_STRATEGY( \ | |||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
part2 = megdnn::matmul::GemmInterleaved< \ | |||
@@ -104,8 +105,8 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||
return {nullptr, {part0, part1, part2}}; | |||
} | |||
void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
void ConvBiasImpl::AlgoS8MatrixMul::kimpl( | |||
const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||
auto is_xcorr = !param.filter_meta.should_flip; | |||
UNPACK_CONV_NCB_KERN_SIZES(param); | |||
auto bundle = get_bundle(param); | |||
@@ -157,29 +158,28 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||
} else { | |||
if (is_xcorr) | |||
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
FW, SH, SW); | |||
img2col_stride<true>( | |||
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||
else | |||
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
FW, SH, SW); | |||
img2col_stride<false>( | |||
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||
} | |||
} | |||
{ | |||
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||
bundle.get_size(2)); | |||
Workspace workspace( | |||
static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2)); | |||
size_t M = OC; | |||
size_t K = IC * FH * FW; | |||
size_t N = OH * OW; | |||
#if MGB_ENABLE_DOT | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved< \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
#define DISPATCH_GEMM_STRATEGY( \ | |||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | |||
if (cpuinfo_has_arm_neon_dot()) { | |||
@@ -188,19 +188,18 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||
} | |||
#else | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved< \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||
bias); \ | |||
} \ | |||
#define DISPATCH_GEMM_STRATEGY( \ | |||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ | |||
} \ | |||
MIDOUT_END() | |||
DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||
#endif | |||
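For orientation, the kimpl above follows the classic im2col-plus-GEMM recipe: the (possibly padded) input is unrolled so that every (IC, FH, FW) receptive field becomes one column of a K x N matrix B, after which the convolution is a single M x K by K x N matrix multiply with the filter; img2col_stride is the strided variant of the same unrolling. A minimal sketch of the unit-stride case (naming is mine, not MegDNN's):

    #include <cstddef>

    // im2col for unit stride: B is (IC*FH*FW) x (OH*OW), row-major.
    static void im2col_ref(
            const float* src, float* B, size_t IC, size_t IH, size_t IW, size_t OH,
            size_t OW, size_t FH, size_t FW) {
        (void)IH;  // bounds are implied by OH + FH
        for (size_t ic = 0; ic < IC; ++ic)
            for (size_t fh = 0; fh < FH; ++fh)
                for (size_t fw = 0; fw < FW; ++fw)
                    for (size_t oh = 0; oh < OH; ++oh)
                        for (size_t ow = 0; ow < OW; ++ow) {
                            size_t row = (ic * FH + fh) * FW + fw;
                            size_t col = oh * OW + ow;
                            B[row * (OH * OW) + col] =
                                    src[(ic * (OH + FH - 1) + oh + fh) * IW + ow + fw];
                        }
    }
    // Then dst(M x N) = filter(M x K) * B(K x N),
    // with M = OC, K = IC * FH * FW, N = OH * OW.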
@@ -12,8 +12,8 @@ | |||
#pragma once | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { | |||
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "S8MATMUL"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
bool usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override { | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override { | |||
size_t group = param.filter_meta.group; | |||
return {{kimpl, {group, 1_z, 1_z}}}; | |||
} | |||
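For orientation: the returned kernel's workload shape {group, 1_z, 1_z} tells the fallback runtime to invoke kimpl once per convolution group, with the group id delivered through ncb_index. Schematically (a hypothetical driver loop; the real scheduler lives in the fallback layer and may thread these calls, and the field names below are illustrative, not MegDNN's):

    // Hypothetical sketch of how a {group, 1, 1} NCBKern is driven.
    for (size_t g = 0; g < group; ++g) {
        NCBKernIndex ncb_index = {/*ndrange_id=*/{g, 0, 0}, /*thread_id=*/0};
        kimpl(param, ncb_index);  // processes group g
    }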
@@ -29,9 +29,10 @@ struct KernCaller; | |||
#if MGB_ENABLE_DOT | |||
template <BiasMode bmode, typename Op> | |||
struct KernCaller<bmode, Op, 8, 12> { | |||
static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||
Op op, const dt_int32* bias, dt_int32* workspace) { | |||
static void run( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||
dt_int32* workspace) { | |||
megdnn_assert(is_first_k); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
@@ -49,19 +50,19 @@ struct KernCaller<bmode, Op, 8, 12> { | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12, | |||
is_first_k); | |||
matmul_8x12x4::kern_8x12( | |||
packA, cur_packB, K, workspace, 12, is_first_k); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, | |||
12>::postprocess(bias, workspace, | |||
output, LDC, op); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, 12>:: | |||
postprocess(bias, workspace, output, LDC, op); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
matmul_8x12x4::kern_8x4( | |||
packA, cur_packB, K, workspace, 4, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \ | |||
@@ -83,9 +84,9 @@ struct KernCaller<bmode, Op, 8, 12> { | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12, | |||
is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
matmul_8x12x4::kern_4x12( | |||
packA, cur_packB, K, workspace, 12, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
@@ -97,14 +98,13 @@ struct KernCaller<bmode, Op, 8, 12> { | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
matmul_8x12x4::kern_4x4( | |||
packA, cur_packB, K, workspace, 4, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
@@ -122,9 +122,10 @@ struct KernCaller<bmode, Op, 8, 12> { | |||
template <BiasMode bmode, typename Op> | |||
struct KernCaller<bmode, Op, 4, 4> { | |||
static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||
Op op, const dt_int32* bias, dt_int32* workspace) { | |||
static void run( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||
dt_int32* workspace) { | |||
megdnn_assert(is_first_k); | |||
constexpr size_t A_INTERLEAVE = 4; | |||
@@ -140,20 +141,18 @@ struct KernCaller<bmode, Op, 4, 4> { | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, | |||
4>::postprocess(bias, workspace, | |||
output, LDC, op); | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, is_first_k); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, 4>::postprocess( | |||
bias, workspace, output, LDC, op); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace, | |||
4, is_first_k, 4, | |||
std::min<size_t>(N - n, 4)); | |||
matmul_4x4x16::kern_4x4_remain( | |||
packA, cur_packB, K, workspace, 4, is_first_k, 4, | |||
std::min<size_t>(N - n, 4)); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
@@ -182,8 +181,7 @@ struct KernCaller<bmode, Op, 4, 4> { | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
@@ -200,21 +198,19 @@ struct KernCaller<bmode, Op, 4, 4> { | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | |||
void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
void gemm_s8_4x4_nobias_identity::pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool transpose) const { | |||
void gemm_s8_4x4_nobias_identity::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
@@ -229,23 +225,21 @@ size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { | |||
#if MGB_ENABLE_DOT | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | |||
void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
void gemm_s8_8x12_nobias_identity::pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); | |||
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); | |||
if (transpose) { | |||
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool transpose) const { | |||
void gemm_s8_8x12_nobias_identity::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
@@ -259,18 +253,17 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { | |||
#endif | |||
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ | |||
size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ | |||
const dt_int32* bias, dt_int32* workspace) const { \ | |||
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||
DEFINE_OP(_OP); \ | |||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||
workspace); \ | |||
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, \ | |||
dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ | |||
dt_int32* workspace) const { \ | |||
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||
DEFINE_OP(_OP); \ | |||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace); \ | |||
} | |||
#define DEFINE_OP(_Op) \ | |||
@@ -286,18 +279,16 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||
#endif | |||
#undef DEFINE_OP | |||
#define DEFINE_OP(_Op) \ | |||
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \ | |||
scale_A* scale_B, scale_C); | |||
#define DEFINE_OP(_Op) \ | |||
arm_common::_Op<dt_qint32, dt_qint8> op( \ | |||
scale_A* scale_B, scale_A* scale_B, scale_C); | |||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
FuseAddHSwishOp) | |||
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||
#if MGB_ENABLE_DOT | |||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
FuseAddHSwishOp) | |||
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||
#endif | |||
#undef DEFINE_OP | |||
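To make the macro layer concrete, here is KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) expanded by hand under the bias_channel DEFINE_OP shown just above (a readability aid; the preprocessor output is authoritative):

    void gemm_s8_4x4_bias_channel_relu::kern(
            const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K,
            dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias,
            dt_int32* workspace) const {
        float scale_A = A_dtype.param<dtype::QuantizedS8>().scale;
        float scale_B = B_dtype.param<dtype::QuantizedS8>().scale;
        float scale_C = C_dtype.param<dtype::QuantizedS8>().scale;
        // DEFINE_OP(FuseAddReluOp) under the bias_channel definition:
        arm_common::FuseAddReluOp<dt_qint32, dt_qint8> op(
                scale_A * scale_B, scale_A * scale_B, scale_C);
        impl::KernCaller<BiasMode::BROADCAST_CHANNEL_BIAS, decltype(op), 4, 4>::run(
                packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace);
    }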
@@ -20,43 +20,42 @@ namespace matmul { | |||
* | |||
* \name gemm_<type>_<block>_biasmode_nolinemode | |||
*/ | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16, | |||
false, true, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||
dt_int8, dt_int8, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_4x4_nobias_relu, gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_4x4_nobias_hswish, gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_4x4_bias_channel_identity, gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_4x4_bias_channel_relu, gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, | |||
gemm_s8_4x4_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_4x4_bias_channel_hswish, gemm_s8_4x4_nobias_identity); | |||
#if MGB_ENABLE_DOT | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, | |||
false, true, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||
dt_int8, dt_int8, dt_int32, 8, 12, 4, false, true, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_8x12_nobias_relu, gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_8x12_nobias_hswish, gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_8x12_bias_channel_identity, gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_8x12_bias_channel_relu, gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, | |||
gemm_s8_8x12_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_s8_8x12_bias_channel_hswish, gemm_s8_8x12_nobias_identity); | |||
#endif | |||
} // namespace matmul | |||
@@ -13,13 +13,13 @@ | |||
#include "src/aarch64/conv_bias/int8/algos.h" | |||
#include "src/aarch64/conv_bias/quint8/algos.h" | |||
#include "src/naive/handle.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
#include "src/fallback/convolution/opr_impl.h" | |||
#include "src/aarch64/conv_bias/fp32/algos.h" | |||
#include "src/aarch64/conv_bias/fp16/algos.h" | |||
#include "src/aarch64/conv_bias/fp32/algos.h" | |||
#include "src/fallback/convolution/opr_impl.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
@@ -56,12 +56,10 @@ public: | |||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | |||
return m_direct_algos; | |||
} | |||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() | |||
const { | |||
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() const { | |||
return m_matmul_algos; | |||
} | |||
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
}; | |||
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
@@ -71,15 +69,16 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||
ConvBiasImpl::get_all_packed_algo() { | |||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() { | |||
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | |||
algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||
algo_pack().direct_algos().end()); | |||
algos.insert( | |||
algos.begin(), algo_pack().direct_algos().begin(), | |||
algo_pack().direct_algos().end()); | |||
    //! We put matmul algos at the beginning, because matmul gets privilege when | |||
    //! prefer returns true. See | |||
algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), | |||
algo_pack().matmul_algos().end()); | |||
algos.insert( | |||
algos.begin(), algo_pack().matmul_algos().begin(), | |||
algo_pack().matmul_algos().end()); | |||
return std::move(algos); | |||
} | |||
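Both insert calls above target algos.begin(), so the range inserted last ends up in front: after this function the order is matmul algos, then direct algos, then the algos inherited from arm_common. A tiny self-contained illustration of that ordering property:

    #include <vector>

    std::vector<int> demo_order() {
        std::vector<int> algos = {3};                 // stand-in: arm_common algos
        std::vector<int> direct = {2}, matmul = {1};  // stand-ins
        algos.insert(algos.begin(), direct.begin(), direct.end());  // {2, 3}
        algos.insert(algos.begin(), matmul.begin(), matmul.end());  // {1, 2, 3}
        return algos;  // matmul first
    }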
@@ -9,8 +9,8 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/common/utils.h" | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
@@ -70,9 +70,9 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||
size_t N = OH * OW; | |||
#if MGB_ENABLE_DOT | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
#define DISPATCH_GEMM_STRATEGY( \ | |||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
part2 = megdnn::matmul::GemmInterleaved< \ | |||
@@ -86,11 +86,12 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); | |||
} | |||
#else | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
#define DISPATCH_GEMM_STRATEGY( \ | |||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
part2 = megdnn::matmul::GemmInterleaved< \ | |||
@@ -106,8 +107,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||
return {nullptr, {part0, part1, part2}}; | |||
} | |||
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl( | |||
const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||
auto is_xcorr = !param.filter_meta.should_flip; | |||
UNPACK_CONV_NCB_KERN_SIZES(param); | |||
auto bundle = get_bundle(param); | |||
@@ -160,29 +161,28 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||
} else { | |||
if (is_xcorr) | |||
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
FW, SH, SW); | |||
img2col_stride<true>( | |||
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||
else | |||
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
FW, SH, SW); | |||
img2col_stride<false>( | |||
src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||
} | |||
} | |||
{ | |||
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||
bundle.get_size(2)); | |||
Workspace workspace( | |||
static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2)); | |||
size_t M = OC; | |||
size_t K = IC * FH * FW; | |||
size_t N = OH * OW; | |||
#if MGB_ENABLE_DOT | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved< \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
#define DISPATCH_GEMM_STRATEGY( \ | |||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | |||
if (cpuinfo_has_arm_neon_dot()) { | |||
@@ -191,19 +191,18 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | |||
} | |||
#else | |||
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
_bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved< \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||
bias); \ | |||
} \ | |||
#define DISPATCH_GEMM_STRATEGY( \ | |||
_gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
_nonline_midout_enum) \ | |||
MIDOUT_BEGIN( \ | |||
megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||
_bias_midout_enum, _nonline_midout_enum) { \ | |||
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
gemm_interleaved(M, N, K, false, false, strategy); \ | |||
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ | |||
} \ | |||
MIDOUT_END() | |||
DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | |||
@@ -12,8 +12,8 @@ | |||
#pragma once | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
@@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { | |||
static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "QU8MATMUL"; } | |||
bool usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
bool usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override { | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override { | |||
size_t group = param.filter_meta.group; | |||
return {{kimpl, {group, 1_z, 1_z}}}; | |||
} | |||
@@ -14,8 +14,8 @@ | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||
#include "src/arm_common/conv_bias/matmul_postprocess.h" | |||
using namespace megdnn; | |||
@@ -29,10 +29,10 @@ struct KernCaller; | |||
#if MGB_ENABLE_DOT | |||
template <BiasMode bmode, typename Op> | |||
struct KernCaller<bmode, Op, 8, 8, true> { | |||
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||
size_t N, size_t K, dt_uint8* C, size_t LDC, | |||
bool is_first_k, Op op, const dt_int32* bias, | |||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
static void run( | |||
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, | |||
dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
megdnn_assert(is_first_k); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t B_INTERLEAVE = 8; | |||
@@ -50,20 +50,19 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||
size_t n = 0; | |||
const dt_uint8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8, | |||
is_first_k, zp_A, zp_B, zAB); | |||
matmul_8x8x4::kern_8x8( | |||
packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B, zAB); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||
8>::postprocess(bias, workspace, | |||
output, LDC, op); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>:: | |||
postprocess(bias, workspace, output, LDC, op); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(N - n, 4), | |||
zp_A, zp_B, zAB); | |||
matmul_8x8x4::kern_8x4( | |||
packA, cur_packB, K, workspace, 4, is_first_k, | |||
std::min<size_t>(N - n, 4), zp_A, zp_B, zAB); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
@@ -84,9 +83,9 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||
const dt_uint8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
zp_A, zp_B, zAB); | |||
matmul_8x8x4::kern_4x8( | |||
packA, cur_packB, K, workspace, 8, is_first_k, | |||
std::min<size_t>(M - m, 4), zp_A, zp_B, zAB); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
@@ -98,15 +97,14 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4), zp_A, zp_B, | |||
zAB); | |||
matmul_8x8x4::kern_4x4( | |||
packA, cur_packB, K, workspace, 4, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A, | |||
zp_B, zAB); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
@@ -124,10 +122,10 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||
template <BiasMode bmode, typename Op> | |||
struct KernCaller<bmode, Op, 8, 8, false> { | |||
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||
size_t N, size_t K, dt_uint8* C, size_t LDC, | |||
bool is_first_k, Op op, const dt_int32* bias, | |||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
static void run( | |||
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, | |||
dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
megdnn_assert(is_first_k); | |||
constexpr size_t A_INTERLEAVE = 8; | |||
@@ -144,27 +142,25 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||
size_t n = 0; | |||
const dt_uint8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8, | |||
is_first_k, zp_A, zp_B); | |||
matmul_8x8x8::kern_8x8( | |||
packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||
8>::postprocess(bias, workspace, | |||
output, LDC, op); | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>:: | |||
postprocess(bias, workspace, output, LDC, op); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(N - n, 4), | |||
zp_A, zp_B); | |||
matmul_8x8x8::kern_8x4( | |||
packA, cur_packB, K, workspace, 4, is_first_k, | |||
std::min<size_t>(N - n, 4), zp_A, zp_B); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -179,9 +175,9 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||
const dt_uint8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
zp_A, zp_B); | |||
matmul_8x8x8::kern_4x8( | |||
packA, cur_packB, K, workspace, 8, is_first_k, | |||
std::min<size_t>(M - m, 4), zp_A, zp_B); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
@@ -193,17 +189,16 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4), zp_A, zp_B); | |||
matmul_8x8x8::kern_4x4( | |||
packA, cur_packB, K, workspace, 4, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A, | |||
zp_B); | |||
#define cb(m, n) \ | |||
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | |||
bias, workspace, output, LDC, op); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
#undef cb | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -219,27 +214,27 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||
#if MGB_ENABLE_DOT | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) | |||
void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
void gemm_u8_8x8_dot_nobias_identity::pack_A( | |||
uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, | |||
ymax, k0, kmax); | |||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( | |||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, | |||
y0, ymax, k0, kmax); | |||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( | |||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool transpose) const { | |||
void gemm_u8_8x8_dot_nobias_identity::pack_B( | |||
uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( | |||
out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, | |||
k0, kmax); | |||
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( | |||
out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
@@ -249,30 +244,27 @@ size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const { | |||
#endif | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) | |||
void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr, | |||
const dt_uint8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_u8_8x8_nodot_nobias_identity::pack_A( | |||
dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
if (transpose) { | |||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||
ymax, k0, kmax, zA); | |||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n( | |||
outptr, inptr, ldin, y0, ymax, k0, kmax, zA); | |||
} else { | |||
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax, zA); | |||
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA); | |||
} | |||
} | |||
void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool transpose) const { | |||
void gemm_u8_8x8_nodot_nobias_identity::pack_B( | |||
dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
if (transpose) { | |||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||
k0, kmax, zB); | |||
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n( | |||
out, in, ldin, x0, xmax, k0, kmax, zB); | |||
} else { | |||
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, | |||
zB); | |||
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB); | |||
} | |||
} | |||
@@ -280,22 +272,21 @@ size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const { | |||
return 8 * 8 * sizeof(dt_int32); | |||
} | |||
#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \ | |||
_OP) \ | |||
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \ | |||
kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \ | |||
size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ | |||
const dt_int32* bias, dt_int32* workspace) const { \ | |||
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
DEFINE_OP(_OP); \ | |||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||
workspace, zp_A, zp_B); \ | |||
#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, _OP) \ | |||
void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline::kern( \ | |||
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ | |||
size_t K, dt_uint8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ | |||
dt_int32* workspace) const { \ | |||
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
DEFINE_OP(_OP); \ | |||
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace, zp_A, \ | |||
zp_B); \ | |||
} | |||
#define DEFINE_OP(_Op) \ | |||
@@ -311,17 +302,22 @@ KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||
KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||
#undef DEFINE_OP | |||
#define DEFINE_OP(_Op) \ | |||
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \ | |||
scale_A* scale_B, scale_C, zp_C); | |||
#define DEFINE_OP(_Op) \ | |||
arm_common::_Op<dt_qint32, dt_quint8> op( \ | |||
scale_A* scale_B, scale_A* scale_B, scale_C, zp_C); | |||
#if MGB_ENABLE_DOT | |||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, | |||
FuseAddReluOp) | |||
KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
FuseAddHSwishOp) | |||
#endif | |||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, | |||
AddOp) | |||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, | |||
FuseAddReluOp) | |||
KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
FuseAddHSwishOp) | |||
#undef DEFINE_OP | |||
#undef KERN | |||
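The zp_A/zp_B plumbing in these kernels exists because Quantized8Asymm values carry a zero point: the true product is (a - zA)(b - zB) = a*b - zA*b - zB*a + zA*zB, so a kernel can accumulate raw u8*u8 dot products and correct with the operand sums afterwards; the zAB value passed to the dot kernels plausibly carries the constant K*zA*zB term (an inference from the signatures, not verified). A scalar model of the correction:

    #include <cstddef>
    #include <cstdint>

    // Dot product of two length-K Quantized8Asymm vectors, expanded so the
    // zero-point correction terms are explicit.
    int32_t qdot_asymm(
            const uint8_t* a, const uint8_t* b, size_t K, uint8_t zA, uint8_t zB) {
        int32_t acc = 0, sum_a = 0, sum_b = 0;
        for (size_t k = 0; k < K; ++k) {
            acc += int32_t(a[k]) * int32_t(b[k]);
            sum_a += a[k];
            sum_b += b[k];
        }
        // (a - zA) . (b - zB) = a.b - zA * sum(b) - zB * sum(a) + K * zA * zB
        return acc - int32_t(zA) * sum_b - int32_t(zB) * sum_a +
               int32_t(K) * int32_t(zA) * int32_t(zB);
    }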
@@ -16,46 +16,44 @@ namespace aarch64 { | |||
namespace matmul { | |||
#if MGB_ENABLE_DOT | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, | |||
false, true, | |||
gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||
dt_uint8, dt_uint8, dt_int32, 8, 8, 4, false, true, | |||
gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu, | |||
gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_dot_nobias_relu, gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish, | |||
gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_dot_nobias_hswish, gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity, | |||
gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_dot_bias_channel_identity, gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu, | |||
gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish, | |||
gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_dot_bias_channel_relu, gemm_u8_8x8_dot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_dot_bias_channel_hswish, gemm_u8_8x8_dot_nobias_identity); | |||
#endif | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, | |||
false, true, | |||
gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu, | |||
gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||
dt_uint8, dt_uint8, dt_int32, 8, 8, 8, false, true, | |||
gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish, | |||
gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_nodot_nobias_relu, gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity, | |||
gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_nodot_nobias_hswish, gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu, | |||
gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_nodot_bias_channel_identity, gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish, | |||
gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_nodot_bias_channel_relu, gemm_u8_8x8_nodot_nobias_identity); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
gemm_u8_8x8_nodot_bias_channel_hswish, gemm_u8_8x8_nodot_nobias_identity); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
@@ -11,11 +11,11 @@ | |||
#include "src/common/handle_impl.h" | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/aarch64/handle.h" | |||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||
#include "src/aarch64/rotate/opr_impl.h" | |||
#include "src/aarch64/relayout/opr_impl.h" | |||
#include "src/aarch64/conv_bias/opr_impl.h" | |||
#include "src/aarch64/rotate/opr_impl.h" | |||
#include "src/aarch64/warp_perspective/opr_impl.h" | |||
namespace megdnn { | |||
@@ -38,7 +38,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective) | |||
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | |||
#pragma GCC diagnostic pop | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
@@ -14,20 +14,18 @@ | |||
namespace megdnn { | |||
namespace aarch64 { | |||
class HandleImpl: public arm_common::HandleImpl { | |||
public: | |||
HandleImpl(megcoreComputingHandle_t computing_handle, | |||
HandleType type = HandleType::AARCH64): | |||
arm_common::HandleImpl::HandleImpl(computing_handle, type) | |||
{} | |||
class HandleImpl : public arm_common::HandleImpl { | |||
public: | |||
HandleImpl( | |||
megcoreComputingHandle_t computing_handle, | |||
HandleType type = HandleType::AARCH64) | |||
: arm_common::HandleImpl::HandleImpl(computing_handle, type) {} | |||
template <typename Opr> | |||
std::unique_ptr<Opr> create_operator(); | |||
template <typename Opr> | |||
std::unique_ptr<Opr> create_operator(); | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen | |||
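The create_operator<Opr>() member above is declared without a generic definition; each operator gets an explicit specialization via the MEGDNN_SPECIALIZE_CREATE_OPERATOR / MEGDNN_FOREACH_OPR_CLASS macros seen earlier. Stripped of the macros, the idiom is plain explicit specialization of a member template (placeholder types, for illustration only):

    #include <memory>

    struct Handle {
        template <typename Opr>
        std::unique_ptr<Opr> create_operator();  // declared, never defined generically
    };

    struct MatrixMul {};  // placeholder operator type

    template <>
    std::unique_ptr<MatrixMul> Handle::create_operator<MatrixMul>() {
        return std::make_unique<MatrixMul>();  // one specialization per operator
    }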
@@ -21,9 +21,7 @@ namespace aarch64 { | |||
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_F32K8X12X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -35,8 +33,7 @@ public: | |||
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | | |||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
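attribute() composes flags such as REPRODUCIBLE | USABLE_DEPEND_ON_SHAPE with operator|, which only compiles for a scoped enum when the library overloads the bitwise operators; a generic sketch of that common idiom (names here are illustrative, not MegDNN's actual definitions):

    #include <cstdint>

    enum class Attr : uint32_t {
        REPRODUCIBLE = 1u << 0,
        USABLE_DEPEND_ON_SHAPE = 1u << 1,
    };
    constexpr Attr operator|(Attr a, Attr b) {
        return Attr(uint32_t(a) | uint32_t(b));
    }
    constexpr bool has_attr(Attr v, Attr f) {
        return (uint32_t(v) & uint32_t(f)) != 0;
    }
    static_assert(
            has_attr(Attr::REPRODUCIBLE | Attr::USABLE_DEPEND_ON_SHAPE,
                     Attr::USABLE_DEPEND_ON_SHAPE),
            "combined flags retain each bit");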
@@ -48,9 +45,7 @@ public: | |||
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_F32K4X16X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -61,9 +56,7 @@ public: | |||
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -73,8 +66,7 @@ public: | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | |||
}; | |||
class MatrixMulImpl::AlgoF32Gemv final | |||
: public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||
class MatrixMulImpl::AlgoF32Gemv final : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||
public: | |||
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | |||
m_handle_type = Handle::HandleType::AARCH64; | |||
@@ -85,9 +77,7 @@ public: | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_F16_K8X24X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -98,9 +88,7 @@ public: | |||
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -115,12 +103,8 @@ public: | |||
#if MGB_ENABLE_DOT | |||
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
const char* name() const override { | |||
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
@@ -130,12 +114,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
const char* name() const override { | |||
return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
@@ -147,8 +127,7 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | | |||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -163,9 +142,7 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
@@ -178,9 +155,7 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
@@ -192,9 +167,7 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
@@ -207,9 +180,7 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
@@ -222,8 +193,7 @@ public: | |||
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | | |||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -238,12 +208,9 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | | |||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { | |||
return "AARCH64_INT8X8X16_MK4_16X12X4"; | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_16X12X4"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -257,12 +224,9 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | | |||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { | |||
return "AARCH64_INT8X8X16_MK4_K8X8X8"; | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -276,8 +240,7 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | | |||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -292,9 +255,7 @@ public: | |||
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
@@ -306,9 +267,7 @@ public: | |||
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -321,12 +280,8 @@ public: | |||
#if MGB_ENABLE_DOT | |||
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
const char* name() const override { | |||
return "AARCH64_QUINT8_K8X8X4_DOTPROD"; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_QUINT8_K8X8X4_DOTPROD"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
@@ -336,8 +291,7 @@ public: | |||
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | | |||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -352,9 +306,7 @@ public: | |||
#endif | |||
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE; | |||
} | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
@@ -16,11 +16,11 @@ namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false, | |||
true, hgemm_8x24); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
dt_float16, dt_float16, dt_float16, 8, 24, 1, false, true, hgemm_8x24); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1, | |||
false, true, gemm_nopack_f16_8x8); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true, gemm_nopack_f16_8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
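// A hedged gloss of the two registrations above, reading the numeric macro
// arguments positionally as block_m, block_n, block_k (an assumption from the
// argument order): hgemm_8x24 is a packed fp16 strategy on an 8x24 tile with
// K step 1; gemm_nopack_f16_8x8 is a no-pack fp16 strategy on an 8x8 tile.
struct TileGeom { size_t m, n, k; };
constexpr TileGeom hgemm_8x24_tile{8, 24, 1};
constexpr TileGeom gemm_nopack_f16_8x8_tile{8, 8, 1};
static_assert(hgemm_8x24_tile.m == gemm_nopack_f16_8x8_tile.m,
              "both fp16 strategies work on 8-row panels");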
@@ -9,8 +9,8 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/fp16/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/fp16/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
@@ -21,8 +21,9 @@ using namespace aarch64::matmul; | |||
namespace { | |||
void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
void kern_8x1( | |||
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
LDB *= sizeof(dt_float16); | |||
asm volatile( | |||
".arch armv8.2-a+fp16\n" | |||
@@ -86,9 +87,8 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||
"memory"); | |||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||
"v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||
} | |||
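// A scalar model of what an 8x1 no-pack micro-kernel of this shape computes;
// a sketch only: it assumes A is packed as 8 consecutive fp16 per K step,
// ignores the LDB block stride handled above, and accumulates in fp32 for
// clarity where the asm stays in fp16 (armv8.2 fp16 arithmetic).
void kern_8x1_scalar(const dt_float16* a_ptr, const dt_float16* b_ptr, int K,
                     dt_float16* output) {
    float acc[8] = {0};
    for (int k = 0; k < K; ++k)
        for (int i = 0; i < 8; ++i)
            acc[i] += float(a_ptr[k * 8 + i]) * float(b_ptr[k]);
    for (int i = 0; i < 8; ++i)
        output[i] = dt_float16(acc[i]);
}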
// Overview of register layout: | |||
@@ -115,8 +115,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
// |v23[0-7]| |v27[0-7]| | |||
// +--------+ +--------+ | |||
// Accumulator | |||
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
void kern_8x4( | |||
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
//! LDB is the number of elements in one block of B. We read 24 elements | |||
//! first, so subtract 24 * 2 bytes here. | |||
LDB = (LDB - 24) * sizeof(dt_float16); | |||
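// Spelled out, the adjustment above makes the next load land exactly one
// block later: the post-increment ld1s consume 24 elements, and stepping by
// the adjusted LDB covers the rest. A sanity sketch (helper name hypothetical):
inline const dt_float16* next_b_block(const dt_float16* b, ptrdiff_t ldb_elems) {
    const dt_float16* after_loads = b + 24;  // consumed by post-increment loads
    return after_loads + (ldb_elems - 24);   // net advance == ldb_elems
}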
@@ -263,8 +264,8 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
"v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -295,8 +296,9 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
// | v7[0-7]| |v31[0-7]| | |||
// +--------+ +--------+ | |||
// Accumulator | |||
void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
void kern_8x8( | |||
const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
dt_float16* output) { | |||
//! Each iteration loads 128 elements from B, but the position only advances | |||
//! by 112 * 2 bytes, so we subtract 112 here. | |||
LDB = (LDB - 32) * sizeof(dt_float16); | |||
@@ -467,20 +469,19 @@ void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", "v28", "v29", | |||
"v30", "v31", "cc", "memory"); | |||
} | |||
} // anonymous namespace | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); | |||
void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||
const dt_float16* B, size_t LDB, dt_float16* C, | |||
size_t LDC, size_t M, size_t K, size_t N, | |||
const dt_float16*, void*, bool trA, | |||
bool trB) const { | |||
void gemm_nopack_f16_8x8::kern( | |||
const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C, | |||
size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA, | |||
bool trB) const { | |||
constexpr static size_t MB = 8; | |||
constexpr static size_t KB = 8; | |||
constexpr static size_t NB = 8; | |||
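// With MB = KB = NB = 8 fixed above, the body (outside this hunk) presumably
// walks N in 8-wide panels and falls back to the narrower kernels for the
// tail. The dispatch below is an assumption about that structure, and the
// B/C offset arithmetic is illustrative only, not the verbatim body:
void dispatch_n_sketch(const dt_float16* A, const dt_float16* B, int LDB, int K,
                       size_t N, dt_float16* C) {
    size_t n = 0;
    for (; n + 8 <= N; n += 8)  // full 8-column panels
        kern_8x8(A, B + n * LDB, LDB, K, C + n * 8);
    for (; n + 4 <= N; n += 4)  // 4-wide tail
        kern_8x4(A, B + n * LDB, LDB, K, C + n * 8);
    for (; n < N; ++n)          // single columns
        kern_8x1(A, B + n * LDB, LDB, K, C + n * 8);
}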
@@ -17,21 +17,23 @@ | |||
namespace megdnn { | |||
namespace aarch64 { | |||
MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M, | |||
size_t K, size_t LDA, const float* alpha); | |||
MEGDNN_NOINLINE void sgemm_packA_n( | |||
const float* A, float* Apacked, size_t M, size_t K, size_t LDA, | |||
const float* alpha); | |||
MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M, | |||
size_t K, size_t LDA, const float* alpha); | |||
MEGDNN_NOINLINE void sgemm_packA_t( | |||
const float* A, float* Apacked, size_t M, size_t K, size_t LDA, | |||
const float* alpha); | |||
MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K, | |||
size_t N, size_t LDB); | |||
MEGDNN_NOINLINE void sgemm_packB_n( | |||
const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); | |||
MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K, | |||
size_t N, size_t LDB); | |||
MEGDNN_NOINLINE void sgemm_packB_t( | |||
const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); | |||
MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C, | |||
size_t LDC, size_t M, size_t N, size_t K, | |||
int type, const float* beta); | |||
MEGDNN_NOINLINE void sgemm_kernel12x8( | |||
const float* A, const float* B, float* C, size_t LDC, size_t M, size_t N, | |||
size_t K, int type, const float* beta); | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
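// These declarations describe a classic pack-then-compute pipeline. The
// driver below is a usage sketch only: the packed-buffer sizes, the alpha/beta
// placement, and the meaning of the `type` argument are assumptions for
// illustration, not the library's documented contract.
#include <vector>
void sgemm_sketch(const float* A, size_t LDA, const float* B, size_t LDB,
                  float* C, size_t LDC, size_t M, size_t N, size_t K) {
    std::vector<float> Apacked(M * K), Bpacked(K * N);  // sizes assumed
    const float alpha = 1.0f, beta = 0.0f;
    megdnn::aarch64::sgemm_packA_n(A, Apacked.data(), M, K, LDA, &alpha);
    megdnn::aarch64::sgemm_packB_n(B, Bpacked.data(), K, N, LDB);
    megdnn::aarch64::sgemm_kernel12x8(Apacked.data(), Bpacked.data(), C, LDC,
                                      M, N, K, /*type=*/0, &beta);
}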
@@ -12,7 +12,6 @@ | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul_general_4x16 { | |||
@@ -39,8 +38,9 @@ namespace matmul_general_4x16 { | |||
// +--+ - - - - +--------+--------+--------+--------+ | |||
// | |||
// Accumulator | |||
void kern_4x16(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, int m_remain) { | |||
void kern_4x16( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int m_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -224,14 +224,14 @@ void kern_4x16(const float* packA, const float* packB, int K, | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[m_remain] "+r"(m_remain), [outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", | |||
"x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", "x10", "cc", | |||
"memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -263,8 +263,9 @@ void kern_4x16(const float* packA, const float* packB, int K, | |||
// +--+--+ - - - - +--------+ | |||
// | |||
// Accumulator | |||
void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||
int LDC, bool is_first_k, int m_remain, int n_remain) { | |||
void kern_4x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -330,99 +331,100 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||
STORE_LINE("6", "2") \ | |||
STORE_LINE("7", "3") \ | |||
"105:\n" | |||
// clang-format on | |||
asm volatile( | |||
// load accumulator C | |||
"add x1, x0, %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 1f\n" LOAD_C | |||
"b 2f\n" | |||
"1:\n" | |||
"eor v4.16b, v4.16b, v4.16b\n" | |||
"eor v5.16b, v5.16b, v5.16b\n" | |||
"eor v6.16b, v6.16b, v6.16b\n" | |||
"eor v7.16b, v7.16b, v7.16b\n" | |||
"2: \n" | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"cmp %w[K], #0\n" | |||
"beq 4f\n" | |||
"3:\n" | |||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||
"subs %w[K], %w[K], #1\n" | |||
"bne 3b\n" | |||
"4:\n" | |||
"cmp %w[oddk], #1\n" | |||
"beq 5f\n" | |||
// Even tail | |||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||
"b 6f\n" | |||
// odd tail | |||
"5:\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [m_remain] "+r"(m_remain), | |||
[n_remain] "+r"(n_remain), [outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", | |||
"x2", "x3", "x10", "cc", "memory"); | |||
// clang-format on | |||
asm volatile( | |||
// load accumulator C | |||
"add x1, x0, %x[LDC]\n" | |||
"add x2, x1, %x[LDC]\n" | |||
"add x3, x2, %x[LDC]\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 1f\n" LOAD_C | |||
"b 2f\n" | |||
"1:\n" | |||
"eor v4.16b, v4.16b, v4.16b\n" | |||
"eor v5.16b, v5.16b, v5.16b\n" | |||
"eor v6.16b, v6.16b, v6.16b\n" | |||
"eor v7.16b, v7.16b, v7.16b\n" | |||
"2: \n" | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"cmp %w[K], #0\n" | |||
"beq 4f\n" | |||
"3:\n" | |||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||
"subs %w[K], %w[K], #1\n" | |||
"bne 3b\n" | |||
"4:\n" | |||
"cmp %w[oddk], #1\n" | |||
"beq 5f\n" | |||
// Even tail | |||
"ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
"ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"fmla v4.4s, v3.4s, v1.s[0]\n" | |||
"fmla v5.4s, v3.4s, v1.s[1]\n" | |||
"fmla v6.4s, v3.4s, v1.s[2]\n" | |||
"fmla v7.4s, v3.4s, v1.s[3]\n" | |||
"b 6f\n" | |||
// odd tail | |||
"5:\n" | |||
"fmla v4.4s, v2.4s, v0.s[0]\n" | |||
"fmla v5.4s, v2.4s, v0.s[1]\n" | |||
"fmla v6.4s, v2.4s, v0.s[2]\n" | |||
"fmla v7.4s, v2.4s, v0.s[3]\n" | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), | |||
[outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", "x2", "x3", "x10", | |||
"cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
#undef STORE_C | |||
} | |||
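// All of these kernels share the same K-loop shape: peel oddk = K & 1, run
// the body unrolled by two, then one tail step. A scalar model of that
// structure (a/b/acc stand in for the v-registers; n is the tile width):
void k_loop_model(const float* a, const float* b, int K, float* acc, int n) {
    int oddk = K & 1;
    for (int k = 0; k + 1 < K; k += 2)  // unrolled-by-two main loop
        for (int i = 0; i < n; ++i)
            acc[i] += a[k] * b[k * n + i] + a[k + 1] * b[(k + 1) * n + i];
    if (oddk)                           // odd tail, mirroring label 5 above
        for (int i = 0; i < n; ++i)
            acc[i] += a[K - 1] * b[(K - 1) * n + i];
}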
void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||
int ymax, int k0, int kmax) { | |||
void sgemm_4x16_pack_A_n( | |||
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
float zerobuff[4]; | |||
std::memset(zerobuff, 0, sizeof(float) * 4); | |||
constexpr int PACK_SIZE = 4*4; | |||
constexpr int PACK_SIZE = 4 * 4; | |||
int y = y0; | |||
for (; y + 3 < ymax; y += 4) { | |||
// printf("main loop pack_a_n %p \n",outptr); | |||
// printf("main loop pack_a_n %p \n",outptr); | |||
const float* inptr0 = inptr + y * ldin + k0; | |||
const float* inptr1 = inptr0 + ldin; | |||
const float* inptr2 = inptr1 + ldin; | |||
@@ -459,9 +461,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||
switch ((y + 3) - ymax) { | |||
/* Everything falls through in here */ | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -478,9 +482,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -493,8 +499,8 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||
} | |||
} | |||
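// The zerobuff/MEGDNN_FALLTHRU cascade above is row padding: when fewer than
// four rows remain, the surplus input pointers are redirected to a shared
// zero buffer so the fixed 4-row interleave can run unchanged. A standalone
// illustration of the same pattern (names hypothetical):
void pad_rows_demo(const float* rows[4], int valid_rows, const float* zerobuff) {
    switch (4 - valid_rows) { /* everything falls through, as above */
        case 3: rows[1] = zerobuff; MEGDNN_FALLTHRU
        case 2: rows[2] = zerobuff; MEGDNN_FALLTHRU
        case 1: rows[3] = zerobuff; break;
        case 0: break;  // all four rows valid
    }
}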
void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||
int xmax, int k0, int kmax) { | |||
void sgemm_4x16_pack_A_t( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int ksize = kmax - k0; | |||
int ksize4 = (ksize << 2); | |||
float* outptr_base = out; | |||
@@ -515,8 +521,7 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||
auto outptr = outptr_base; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
@@ -546,8 +551,8 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||
} | |||
} | |||
void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
void sgemm_4x16_pack_B_n( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int ksize = kmax - k0; | |||
int ksize16 = ksize * 16; | |||
int ksize4 = (ksize << 2); | |||
@@ -570,15 +575,13 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||
auto outptr = outptr_base; | |||
for (; x + 16 <= xmax; x += 16) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
outptr += ksize16; | |||
} | |||
outptr = outptr_base4; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
@@ -616,8 +619,8 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||
} | |||
} | |||
void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
int y0, int ymax, int k0, int kmax) { | |||
void sgemm_4x16_pack_B_t( | |||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { | |||
float* outptr = out; | |||
const float* inptr = in; | |||
float zerobuff[4]; | |||
@@ -642,8 +645,7 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
int x = (kmax - k0); | |||
for (; x > 3; x -= 4) { | |||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, | |||
64); | |||
transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 64); | |||
outptr_inner += 64; | |||
} | |||
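// For reference, a scalar equivalent of one 4x4 transpose step; a sketch
// only: the real transpose_4x4_1_s is NEON-based and takes an output stride
// (64 above), which this simplified version fixes at a contiguous tile.
inline void transpose_4x4_scalar(const float*& r0, const float*& r1,
                                 const float*& r2, const float*& r3,
                                 float* out) {
    const float* rows[4] = {r0, r1, r2, r3};
    for (int c = 0; c < 4; ++c)
        for (int r = 0; r < 4; ++r)
            out[c * 4 + r] = rows[r][c];
    r0 += 4; r1 += 4; r2 += 4; r3 += 4;  // inputs advance, as in the intrinsic
}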
for (; x > 0; x--) { | |||
@@ -676,9 +678,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
switch ((y + 3) - ymax) { | |||
/* Everything falls through in here */ | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -696,9 +700,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
switch ((y + 3) - ymax) { | |||
/* Everything falls through in here */ | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -711,8 +717,8 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
} | |||
} | |||
} // matmul_general_4x16 | |||
} // aarch64 | |||
} // megdnn | |||
} // namespace matmul_general_4x16 | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -43,8 +43,9 @@ struct matmul_general_8x12 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -306,14 +307,13 @@ struct matmul_general_8x12 { | |||
"6:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||
"x6", "x7", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -348,9 +348,9 @@ struct matmul_general_8x12 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_8x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -520,13 +520,12 @@ struct matmul_general_8x12 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[n_remain] "+r"(n_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
"cc", "memory"); | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", | |||
"memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -557,9 +556,9 @@ struct matmul_general_8x12 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int m_remain) { | |||
static void kern_4x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int m_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -717,13 +716,12 @@ struct matmul_general_8x12 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[m_remain] "+r"(m_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "x1", "x2", "x3", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
"x2", "x3", "x10", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -754,9 +752,9 @@ struct matmul_general_8x12 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_4x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -895,20 +893,21 @@ struct matmul_general_8x12 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||
[m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||
"x3", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||
"x10", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
#undef STORE_C | |||
} | |||
static void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax) { | |||
static void sgemm_8x12_pack_A_n( | |||
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
float zerobuff[8]; | |||
std::memset(zerobuff, 0, sizeof(float) * 8); | |||
constexpr int PACK_SIZE_32 = 4 * 8; | |||
@@ -933,8 +932,9 @@ struct matmul_general_8x12 { | |||
prefetch_2x(inptr7); | |||
int x = (kmax - k0); | |||
for (; x > 3; x -= 4) { | |||
transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
transpose_8x4_1_s( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
outptr += PACK_SIZE_32; | |||
} | |||
for (; x > 0; x--) { | |||
@@ -1004,8 +1004,8 @@ struct matmul_general_8x12 { | |||
} | |||
} | |||
static void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void sgemm_8x12_pack_A_t( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int ksize = kmax - k0; | |||
int ksize8 = (ksize << 3); | |||
int ksize4 = (ksize << 2); | |||
@@ -1028,20 +1028,17 @@ struct matmul_general_8x12 { | |||
auto outptr = outptr_base; | |||
for (; x + 8 <= xmax; x += 8) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
outptr += ksize8; | |||
} | |||
outptr = outptr_base4; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
if (x < xmax) { | |||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, | |||
xmax - x); | |||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||
} | |||
outptr_base += 4 * 8; | |||
outptr_base4 += 4 * 4; | |||
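// The two bases above implement a two-tier output layout: every 8-wide panel
// is emitted first, then all 4-wide (and remainder) panels. Where the second
// tier starts can be modeled as below; the formula is an assumption about the
// overall layout, since the bases' initialization is outside this hunk.
inline float* base4_start(float* out, int x0, int xmax, int ksize) {
    int full8 = (xmax - x0) / 8;        // complete 8-wide panels
    return out + full8 * (ksize << 3);  // each occupies ksize8 floats
}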
@@ -1071,8 +1068,8 @@ struct matmul_general_8x12 { | |||
} | |||
} | |||
static void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void sgemm_8x12_pack_B_n( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int ksize = kmax - k0; | |||
int ksize12 = ksize * 12; | |||
int ksize4 = (ksize << 2); | |||
@@ -1095,20 +1092,17 @@ struct matmul_general_8x12 { | |||
auto outptr = outptr_base; | |||
for (; x + 12 <= xmax; x += 12) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
outptr += ksize12; | |||
} | |||
outptr = outptr_base4; | |||
for (; x + 4 <= xmax; x += 4) { | |||
auto outptr_interleave = outptr; | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
outptr_interleave); | |||
interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
outptr += ksize4; | |||
} | |||
if (x < xmax) { | |||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, | |||
xmax - x); | |||
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||
} | |||
outptr_base += 12 * 4; | |||
outptr_base4 += 4 * 4; | |||
@@ -1138,8 +1132,8 @@ struct matmul_general_8x12 { | |||
} | |||
} | |||
static void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin, | |||
int y0, int ymax, int k0, int kmax) { | |||
static void sgemm_8x12_pack_B_t( | |||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { | |||
float* outptr = out; | |||
const float* inptr = in; | |||
float zerobuff[12]; | |||
@@ -1172,9 +1166,9 @@ struct matmul_general_8x12 { | |||
prefetch_2x(inptr11); | |||
int x = (kmax - k0); | |||
for (; x > 3; x -= 4) { | |||
transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, inptr8, inptr9, | |||
inptr10, inptr11, outptr); | |||
transpose_12x4_1_s( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
inptr8, inptr9, inptr10, inptr11, outptr); | |||
outptr += 48; | |||
} | |||
for (; x > 0; x--) { | |||
@@ -43,8 +43,9 @@ struct matmul_general_8x12_a53 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -575,15 +576,14 @@ struct matmul_general_8x12_a53 { | |||
"6:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||
"x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", | |||
"x11", "x12", "x13", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
} | |||
@@ -615,9 +615,9 @@ struct matmul_general_8x12_a53 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_8x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -856,13 +856,12 @@ struct matmul_general_8x12_a53 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[n_remain] "+r"(n_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
"x8", "x9", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", | |||
"x10", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -893,9 +892,9 @@ struct matmul_general_8x12_a53 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int m_remain) { | |||
static void kern_4x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int m_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -1133,14 +1132,12 @@ struct matmul_general_8x12_a53 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[m_remain] "+r"(m_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "x1", "x2", "x3", "x8", "x9", "x10", "x20", "x21", | |||
"x22", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
"x2", "x3", "x8", "x9", "x10", "x20", "x21", "x22", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -1171,9 +1168,9 @@ struct matmul_general_8x12_a53 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_4x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -1312,12 +1309,12 @@ struct matmul_general_8x12_a53 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||
[m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||
"x3", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||
"x10", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
@@ -43,8 +43,9 @@ struct matmul_general_8x12_a55 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -525,15 +526,14 @@ struct matmul_general_8x12_a55 { | |||
"6:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||
"x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", | |||
"x11", "x12", "x13", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
} | |||
@@ -565,9 +565,9 @@ struct matmul_general_8x12_a55 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_8x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -742,13 +742,12 @@ struct matmul_general_8x12_a55 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[n_remain] "+r"(n_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||
"v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
"x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||
"v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x10", "cc", | |||
"memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -779,9 +778,9 @@ struct matmul_general_8x12_a55 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int m_remain) { | |||
static void kern_4x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int m_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -972,14 +971,12 @@ struct matmul_general_8x12_a55 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[m_remain] "+r"(m_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "x1", "x2", "x3", "x10", "x20", "x21", "x22", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
"x2", "x3", "x10", "x20", "x21", "x22", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -1010,9 +1007,9 @@ struct matmul_general_8x12_a55 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_4x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
int oddk = (K & 1); | |||
@@ -1151,12 +1148,12 @@ struct matmul_general_8x12_a55 { | |||
"6:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
[n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||
[m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||
"x3", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||
"x10", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
@@ -44,8 +44,9 @@ struct matmul_mk4_8x12 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
float* output0 = output; | |||
@@ -307,10 +308,10 @@ struct matmul_mk4_8x12 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0), [output1] "+r"(output1) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31", "x1", "x2", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -340,9 +341,9 @@ struct matmul_mk4_8x12 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_8x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
float* output0 = output; | |||
@@ -500,8 +501,8 @@ struct matmul_mk4_8x12 { | |||
[output0] "+r"(output0), [output1] "+r"(output1), | |||
[n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "cc", "memory"); | |||
#undef LOAD_C | |||
#undef STORE_C | |||
@@ -531,8 +532,9 @@ struct matmul_mk4_8x12 { | |||
// | |||
// Accumulator | |||
static void kern_4x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_4x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
MEGDNN_MARK_USED_VAR(LDC); | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
@@ -669,9 +671,9 @@ struct matmul_mk4_8x12 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "x1", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
"cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -697,9 +699,9 @@ struct matmul_mk4_8x12 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_4x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
MEGDNN_MARK_USED_VAR(LDC); | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
@@ -818,15 +820,15 @@ struct matmul_mk4_8x12 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||
#undef LOAD_C | |||
#undef STORE_C | |||
} | |||
static void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax) { | |||
static void sgemm_8x12_pack_A( | |||
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be a multiple of 4"); | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||
constexpr int PACK_SIZE_32 = 4 * 8; | |||
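// The asserts above pin down the MK4 constraint: M and K must be multiples
// of 4 because the layout interleaves 4 consecutive M-rows. The index formula
// below is an assumption based on that 4-row blocking, for illustration:
inline size_t mk4_index(size_t m, size_t k, size_t K) {
    // element (m, k) of an M x K matrix stored MK4
    return ((m / 4) * K + k) * 4 + (m % 4);
}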
@@ -855,8 +857,8 @@ struct matmul_mk4_8x12 { | |||
} | |||
} | |||
static void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0, | |||
int xmax, int k0, int kmax) { | |||
static void sgemm_8x12_pack_B( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||
float tmpbuff[16] = {0.0f}; | |||
@@ -886,8 +888,7 @@ struct matmul_mk4_8x12 { | |||
outptr += ksize4; | |||
} | |||
if (x < xmax) { | |||
std::memcpy(tmpbuff, inptr, | |||
sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||
std::memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||
auto outptr_interleave = outptr; | |||
const float* tmp_ptr = &tmpbuff[0]; | |||
transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); | |||
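// This tail mirrors the zerobuff trick for columns: tmpbuff is zero-filled,
// the valid remainder columns are copied in, and the fixed-width transpose
// runs on padded data. The same idea, standalone (remain_cols < 4 assumed):
void pack_b_tail_demo(const float* inptr, int remain_cols, float* outptr) {
    constexpr int PACK_C_SIZE = 4;  // as in the surrounding code
    float tmpbuff[16] = {0.0f};
    std::memcpy(tmpbuff, inptr, sizeof(float) * remain_cols * PACK_C_SIZE);
    const float* tmp_ptr = &tmpbuff[0];
    transpose_1x4_4_s<float>(tmp_ptr, outptr);  // same call as above
}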
@@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a53 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
float* output0 = output; | |||
@@ -553,11 +554,11 @@ struct matmul_mk4_8x12_a53 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0), [output1] "+r"(output1) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", | |||
"x11", "x12", "x13", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||
"memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -587,9 +588,9 @@ struct matmul_mk4_8x12_a53 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_8x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
float* output0 = output; | |||
@@ -831,8 +832,8 @@ struct matmul_mk4_8x12_a53 { | |||
[output0] "+r"(output0), [output1] "+r"(output1), | |||
[n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "x8", "x9", "x10", "cc", "memory"); | |||
#undef LOAD_C | |||
#undef STORE_C | |||
@@ -862,8 +863,9 @@ struct matmul_mk4_8x12_a53 { | |||
// | |||
// Accumulator | |||
static void kern_4x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_4x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
MEGDNN_MARK_USED_VAR(LDC); | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
@@ -1098,9 +1100,9 @@ struct matmul_mk4_8x12_a53 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
"x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -1126,9 +1128,9 @@ struct matmul_mk4_8x12_a53 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_4x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
MEGDNN_MARK_USED_VAR(LDC); | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
@@ -1246,8 +1248,7 @@ struct matmul_mk4_8x12_a53 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||
#undef LOAD_C | |||
#undef STORE_C | |||
@@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a55 { | |||
// +--+ --- - +--------+--------+--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
float* output0 = output; | |||
@@ -519,11 +520,11 @@ struct matmul_mk4_8x12_a55 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0), [output1] "+r"(output1) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
"v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", | |||
"x11", "x12", "x13", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||
"memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -553,9 +554,9 @@ struct matmul_mk4_8x12_a55 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_8x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_8x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
float* output0 = output; | |||
@@ -749,8 +750,8 @@ struct matmul_mk4_8x12_a55 { | |||
[output0] "+r"(output0), [output1] "+r"(output1), | |||
[n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "x8", "x9", "x10", "cc", "memory"); | |||
#undef LOAD_C | |||
#undef STORE_C | |||
@@ -780,8 +781,9 @@ struct matmul_mk4_8x12_a55 { | |||
// | |||
// Accumulator | |||
static void kern_4x12(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k) { | |||
static void kern_4x12( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k) { | |||
MEGDNN_MARK_USED_VAR(LDC); | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
@@ -997,9 +999,9 @@ struct matmul_mk4_8x12_a55 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
"x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -1025,9 +1027,9 @@ struct matmul_mk4_8x12_a55 { | |||
// +--+ --- - +--------+ | |||
// | |||
// Accumulator | |||
static void kern_4x4(const float* packA, const float* packB, int K, | |||
float* output, int LDC, bool is_first_k, | |||
int n_remain) { | |||
static void kern_4x4( | |||
const float* packA, const float* packB, int K, float* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
MEGDNN_MARK_USED_VAR(LDC); | |||
const float* a_ptr = packA; | |||
const float* b_ptr = packB; | |||
@@ -1146,8 +1148,7 @@ struct matmul_mk4_8x12_a55 { | |||
[is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
[output0] "+r"(output0), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||
#undef LOAD_C | |||
#undef STORE_C | |||
@@ -10,6 +10,7 @@ | |||
* implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | |||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | |||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" | |||
@@ -17,44 +18,40 @@ | |||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | |||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" | |||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" | |||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
using namespace aarch64::matmul; | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | |||
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||
int k0, int kmax, bool transpose_A) const { | |||
void sgemm_4x16::pack_A( | |||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool transpose_A) const { | |||
if (transpose_A) { | |||
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||
int k0, int kmax, bool transpose_B) const { | |||
void sgemm_4x16::pack_B( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose_B) const { | |||
if (transpose_B) { | |||
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||
size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||
const float*, float*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
void sgemm_4x16::kern( | |||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
size_t LDC, bool is_first_k, const float*, float*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
@@ -71,9 +68,9 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||
size_t n = 0; | |||
const float* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, | |||
is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
matmul_general_4x16::kern_4x16( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K16; | |||
} | |||
@@ -92,32 +89,30 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | |||
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||
int k0, int kmax, bool transpose_A) const { | |||
void sgemm_8x12::pack_A( | |||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool transpose_A) const { | |||
if (transpose_A) { | |||
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||
int k0, int kmax, bool transpose_B) const { | |||
void sgemm_8x12::pack_B( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose_B) const { | |||
if (transpose_B) { | |||
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
template <typename gemm_class> | |||
static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
size_t M, size_t N, size_t K, float* C, | |||
size_t LDC, bool is_first_k) { | |||
static inline void sgemm_8x12_helper( | |||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
size_t LDC, bool is_first_k) { | |||
constexpr size_t A_INTERLEAVE = 8; | |||
constexpr size_t A_INTERLEAVE4 = 4; | |||
constexpr size_t B_INTERLEAVE = 12; | |||
@@ -138,8 +133,9 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
} | |||
for (; n < N; n += 4) { | |||
gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
gemm_class::kern_8x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -150,16 +146,17 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
size_t n = 0; | |||
const float* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
gemm_class::kern_4x12( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
gemm_class::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
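Stripped of the formatting churn, sgemm_8x12_helper is the classic two-level tiling loop: M in strips of 8 rows (then one tail strip of at most 4, matching how pack_A grouped them), N in panels of 12 columns (then 4), with std::min(...) telling the edge kernels how many rows or columns are real. A stand-alone sketch with stub kernels (all names here invented):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Stubs standing in for the asm micro-kernels.
void kern_8x12(std::size_t m, std::size_t n) { std::printf("8x12 @(%zu,%zu)\n", m, n); }
void kern_8x4(std::size_t m, std::size_t n, std::size_t nr) {
    std::printf("8x4  @(%zu,%zu) nr=%zu\n", m, n, nr);
}
void kern_4x12(std::size_t m, std::size_t n, std::size_t mr) {
    std::printf("4x12 @(%zu,%zu) mr=%zu\n", m, n, mr);
}
void kern_4x4(std::size_t m, std::size_t n, std::size_t mr, std::size_t nr) {
    std::printf("4x4  @(%zu,%zu) mr=%zu nr=%zu\n", m, n, mr, nr);
}

void tile_8x12(std::size_t M, std::size_t N) {
    std::size_t m = 0;
    for (; m + 7 < M; m += 8) {  // full 8-row strips
        std::size_t n = 0;
        for (; n + 11 < N; n += 12) kern_8x12(m, n);
        for (; n < N; n += 4)  // right edge: up to 3 masked columns
            kern_8x4(m, n, std::min<std::size_t>(N - n, 4));
    }
    for (; m < M; m += 4) {  // bottom strips of at most 4 (possibly partial) rows
        std::size_t n = 0;
        for (; n + 11 < N; n += 12) kern_4x12(m, n, std::min<std::size_t>(M - m, 4));
        for (; n < N; n += 4)
            kern_4x4(m, n, std::min<std::size_t>(M - m, 4),
                     std::min<std::size_t>(N - n, 4));
    }
}
```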
@@ -167,56 +164,55 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
} | |||
} | |||
void sgemm_8x12::kern(const float* packA, const float* packB, size_t M, | |||
size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||
const float*, float*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
void sgemm_8x12::kern( | |||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
size_t LDC, bool is_first_k, const float*, float*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
#if !MGB_ENABLE_CPUINFO | |||
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||
is_first_k); | |||
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, is_first_k); | |||
#else | |||
auto arch = cpuinfo_get_current_core()->uarch; | |||
#ifdef __IN_TEE_ENV__ | |||
arch = cpuinfo_uarch_unknown; | |||
#endif | |||
if (arch == cpuinfo_uarch_cortex_a53) { | |||
sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C, | |||
LDC, is_first_k); | |||
sgemm_8x12_helper<matmul_general_8x12_a53>( | |||
packA, packB, M, N, K, C, LDC, is_first_k); | |||
} else if (arch == cpuinfo_uarch_cortex_a55) { | |||
sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C, | |||
LDC, is_first_k); | |||
sgemm_8x12_helper<matmul_general_8x12_a55>( | |||
packA, packB, M, N, K, C, LDC, is_first_k); | |||
} else { | |||
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||
is_first_k); | |||
sgemm_8x12_helper<matmul_general_8x12>( | |||
packA, packB, M, N, K, C, LDC, is_first_k); | |||
} | |||
#endif | |||
} | |||
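Formatting aside, this kern is the one place in the file where a runtime decision happens: the micro-architecture of the current core selects the A53-tuned, A55-tuned, or generic micro-kernel. A stand-alone sketch of the same pattern follows; the stub kernels are invented, and the cpuinfo calls are the ones used above plus cpuinfo_initialize(), which megdnn performs elsewhere:

```cpp
#include <cpuinfo.h>
#include <cstdio>

using KernFn = void (*)();

static void kern_generic() { std::puts("generic 8x12 kernel"); }  // stub
static void kern_a53()     { std::puts("Cortex-A53 kernel"); }    // stub
static void kern_a55()     { std::puts("Cortex-A55 kernel"); }    // stub

static KernFn select_kernel() {
    // cpuinfo must be initialized once before topology queries.
    if (!cpuinfo_initialize()) return kern_generic;
    switch (cpuinfo_get_current_core()->uarch) {
        case cpuinfo_uarch_cortex_a53: return kern_a53;
        case cpuinfo_uarch_cortex_a55: return kern_a55;
        default:                       return kern_generic;
    }
}

int main() { select_kernel()(); }
```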
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | |||
void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, | |||
bool transpose_A) const { | |||
void sgemm_mk4_8x12::pack_A( | |||
float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool transpose_A) const { | |||
megdnn_assert(!transpose_A, "mk4 float matmul does not support transpose A");
matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, | |||
bool transpose_B) const { | |||
void sgemm_mk4_8x12::pack_B( | |||
float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose_B) const { | |||
megdnn_assert(!transpose_B, "mk4 float matmul does not support transpose B");
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
template <typename gemm_name> | |||
static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||
size_t M, size_t N, size_t K, float* C, | |||
size_t LDC, bool is_first_k) { | |||
static inline void sgemm_mk4_8x12_helper( | |||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
size_t LDC, bool is_first_k) { | |||
const int K12 = K * 12; | |||
const int K8 = K * 8; | |||
const int K4 = K * 4; | |||
@@ -237,8 +233,9 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||
} | |||
for (; n < N; n += 4) { | |||
gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
gemm_name::kern_8x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4 * PACK_C_SIZE; | |||
cur_packB += K4; | |||
} | |||
@@ -254,41 +251,41 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
gemm_name::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4 * PACK_C_SIZE; | |||
cur_packB += K4; | |||
} | |||
packA += K4; | |||
} | |||
} | |||
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M, | |||
size_t N, size_t K, float* C, size_t LDC, | |||
bool is_first_k, const float*, float*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
void sgemm_mk4_8x12::kern( | |||
const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
size_t LDC, bool is_first_k, const float*, float*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||
A_dtype.enumv() == DTypeEnum::Float32); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4");
#if !MGB_ENABLE_CPUINFO | |||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||
is_first_k); | |||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, is_first_k); | |||
#else | |||
auto arch = cpuinfo_get_current_core()->uarch; | |||
#ifdef __IN_TEE_ENV__ | |||
arch = cpuinfo_uarch_unknown; | |||
#endif | |||
if (arch == cpuinfo_uarch_cortex_a53) { | |||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C, | |||
LDC, is_first_k); | |||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>( | |||
packA, packB, M, N, K, C, LDC, is_first_k); | |||
} else if (arch == cpuinfo_uarch_cortex_a55) { | |||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C, | |||
LDC, is_first_k); | |||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>( | |||
packA, packB, M, N, K, C, LDC, is_first_k); | |||
} else { | |||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||
is_first_k); | |||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>( | |||
packA, packB, M, N, K, C, LDC, is_first_k); | |||
} | |||
#endif | |||
} | |||
@@ -15,17 +15,14 @@ | |||
namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, | |||
sgemm_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, | |||
sgemm_4x16); | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, sgemm_4x16); | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, | |||
sgemm_mk4_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, sgemm_mk4_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, | |||
sgemm_nopack_4x16); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
float, float, float, 4, 16, 1, false, true, sgemm_nopack_4x16); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
@@ -20,8 +20,8 @@ using namespace aarch64::matmul; | |||
namespace { | |||
void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
float* output) { | |||
void kern_4x1( | |||
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||
LDB *= sizeof(float); | |||
asm volatile( | |||
"subs %w[K], %w[K], #4\n" | |||
@@ -64,8 +64,7 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", | |||
"memory"); | |||
: "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -89,8 +88,8 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
// +--------+ - - - - -+--------+ | |||
// Accumulator | |||
void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
float* output) { | |||
void kern_4x4( | |||
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||
//! Each iteration loads 16 floats from B while the pointer advances by only
//! 12 * 4 bytes, so we subtract 12 from LDB here.
LDB = (LDB - 12) * sizeof(float); | |||
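A quick check of that adjustment, under my reading (not spelled out in the source) that the in-loop post-increments advance b_ptr by 12 of the 16 floats loaded and the final post-indexed load adds the corrected LDB:

```cpp
#include <cstddef>

// Net per-row advance of b_ptr in kern_4x4 under this model: 12 floats via
// the in-loop post-increments plus (LDB - 12) floats via the register
// post-index, i.e. exactly one row stride again.
constexpr std::size_t net_advance_bytes(std::size_t ldb_floats) {
    return 12 * sizeof(float) + (ldb_floats - 12) * sizeof(float);
}
static_assert(net_advance_bytes(16) == 16 * sizeof(float), "stride restored");
```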
@@ -165,8 +164,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||
"v18", "v19", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", | |||
"v19", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -195,8 +194,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
// +--------+ - - - - -+--------+ | |||
// Accumulator | |||
void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
float* output) { | |||
void kern_4x8( | |||
const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||
//! Each iteration loads 32 floats from B while the pointer advances by only
//! 24 * 4 bytes, so we subtract 24 from LDB here.
LDB = (LDB - 24) * sizeof(float); | |||
@@ -304,9 +303,9 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||
"v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", | |||
"v27", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc", | |||
"memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -342,8 +341,7 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
// +--------+ | |||
// Accumulator | |||
void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||
float* output) { | |||
void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, float* output) { | |||
//! Each iteration loads 64 floats from B while the pointer advances by only
//! 56 * 4 bytes, so we subtract 56 from LDB here.
LDB = (LDB - 56) * sizeof(float); | |||
@@ -565,20 +563,18 @@ void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||
} | |||
} // namespace | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); | |||
void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||
size_t LDB, float* C, size_t LDC, size_t M, | |||
size_t K, size_t N, const float*, void*, bool trA, | |||
bool trB) const { | |||
void sgemm_nopack_4x16::kern( | |||
const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, | |||
size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { | |||
constexpr static size_t MB = 4; | |||
constexpr static size_t KB = 4; | |||
constexpr static size_t NB = 16; | |||
@@ -46,8 +46,9 @@ namespace matmul_12x8x1 { | |||
* Accumulator | |||
*/ | |||
static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
static void kern_12x8( | |||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k) { | |||
const int16_t* a_ptr = packA; | |||
const int16_t* b_ptr = packB; | |||
@@ -155,15 +156,13 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||
"stp q25, q26, [x9]\n" | |||
"stp q27, q28, [x10]\n" | |||
"stp q29, q30, [x11]\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[output] "+r"(output) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||
: | |||
: "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", | |||
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "x1", | |||
"x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", | |||
"cc", "memory"); | |||
: "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||
"v25", "v26", "v27", "v28", "v29", "v30", "x1", "x2", "x3", "x4", "x5", | |||
"x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
@@ -196,8 +195,9 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
static void kern_8x8( | |||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k) { | |||
const int16_t* a_ptr = packA; | |||
const int16_t* b_ptr = packB; | |||
@@ -276,13 +276,12 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||
"stp q17, q18, [x5]\n" | |||
"stp q19, q20, [x6]\n" | |||
"stp q21, q22, [x7]\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[output] "+r"(output) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||
: | |||
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", | |||
"x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", "x2", "x3", "x4", | |||
"x5", "x6", "x7", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
@@ -311,9 +310,9 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, | |||
size_t m_remain) { | |||
static void kern_4x8( | |||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, size_t m_remain) { | |||
const int16_t* a_ptr = packA; | |||
const int16_t* b_ptr = packB; | |||
@@ -388,14 +387,13 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||
[m_remain] "+r"(m_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"cc", "memory"); | |||
: "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "cc", | |||
"memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
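One detail these rewrapped operand lists preserve exactly: outptr0 is "+r" (read and written by the asm), while outptr1..outptr3 are "=r" — write-only outputs that the asm derives from outptr0 and LDC itself. A minimal illustration of that constraint mix (an invented helper, aarch64 only; the store is only there to justify the "memory" clobber):

```cpp
#include <cstddef>
#include <cstdint>

void derive_row_pointer(int32_t* outptr0, std::size_t LDC_bytes) {
    int32_t* outptr1;  // computed inside the asm, hence "=r"
    asm volatile(
        "add %[o1], %[o0], %[ldc]\n"  // outptr1 = outptr0 + LDC
        "str wzr, [%[o1]]\n"          // store -> "memory" clobber required
        : [o0] "+r"(outptr0), [o1] "=r"(outptr1)
        : [ldc] "r"(LDC_bytes)
        : "memory");
}
```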
@@ -432,9 +430,9 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, | |||
size_t n_remain) { | |||
static void kern_12x4( | |||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, size_t n_remain) { | |||
const int16_t* a_ptr = packA; | |||
const int16_t* b_ptr = packB; | |||
@@ -573,18 +571,16 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), | |||
[outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), | |||
[outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), | |||
[x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
[outptr7] "=r"(outptr7), [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), | |||
[outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), [x0] "+r"(x0), | |||
[n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v19", "cc", "memory"); | |||
: "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||
"v16", "v17", "v18", "v19", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -618,9 +614,9 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, | |||
size_t n_remain) { | |||
static void kern_8x4( | |||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, size_t n_remain) { | |||
const int16_t* a_ptr = packA; | |||
const int16_t* b_ptr = packB; | |||
@@ -734,16 +730,14 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||
[n_remain] "+r"(n_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||
"cc", "memory"); | |||
: "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", | |||
"memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -773,9 +767,9 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, size_t m_remain, | |||
size_t n_remain) { | |||
static void kern_4x4( | |||
const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, size_t m_remain, size_t n_remain) { | |||
const int16_t* a_ptr = packA; | |||
const int16_t* b_ptr = packB; | |||
@@ -874,11 +868,10 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||
[m_remain] "+r"(m_remain), [x1] "+r"(x1), | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[x0] "+r"(x0), [m_remain] "+r"(m_remain), [x1] "+r"(x1), | |||
[n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); | |||
@@ -889,9 +882,9 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||
#undef STORE_C | |||
} | |||
static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_s16_12x8x1_pack_A_n( | |||
int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int16_t zerobuff[4]; | |||
std::memset(zerobuff, 0, sizeof(int16_t) * 4); | |||
@@ -925,15 +918,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
int K = kmax - k0; | |||
for (; K > 3; K -= 4) { | |||
interleave_12x1_4_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, inptr8, inptr9, inptr10, | |||
inptr11, outptr); | |||
interleave_12x1_4_h( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
inptr8, inptr9, inptr10, inptr11, outptr); | |||
} | |||
if (K > 0) { | |||
interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, | |||
outptr, 1, K); | |||
interleave_12( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
inptr8, inptr9, inptr10, inptr11, outptr, 1, K); | |||
} | |||
} | |||
@@ -949,13 +942,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
int K = kmax - k0; | |||
for (; K > 7; K -= 8) { | |||
interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
interleave_8x1_8_h( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
if (K > 0) { | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 1, K); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 1, K); | |||
} | |||
} | |||
@@ -975,9 +970,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
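These zerobuff ladders are the pack routines' edge padding: when fewer than four source rows remain, each missing row pointer is redirected to a shared zero buffer, and the deliberate fall-through makes every case also zero the rows below it. MEGDNN_FALLTHRU presumably expands to [[fallthrough]] or the equivalent GNU attribute; a stand-alone sketch:

```cpp
#include <cstdint>

// y is the first row of a 4-row group; ymax is one past the last valid row.
void pad_missing_rows(const int16_t*& inptr1, const int16_t*& inptr2,
                      const int16_t*& inptr3, int y, int ymax) {
    static const int16_t zerobuff[4] = {};
    if (y + 3 >= ymax) {
        switch (y + 3 - ymax) {
            case 2: inptr1 = zerobuff; [[fallthrough]];  // rows 1..3 missing
            case 1: inptr2 = zerobuff; [[fallthrough]];  // rows 2..3 missing
            case 0: inptr3 = zerobuff; break;            // only row 3 missing
            default: break;  // the real code asserts here
        }
    }
}
```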
@@ -992,9 +989,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -1007,9 +1006,8 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
} | |||
} | |||
static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, | |||
int ldin, int x0, int xmax, | |||
int k0, int kmax) { | |||
static void gemm_s16_12x8x1_transpose_pack_A_n( | |||
int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
const int ksize = kmax - k0; | |||
const int ksize4 = ksize * 4; | |||
const int ksize8 = ksize4 * 2; | |||
@@ -1054,8 +1052,8 @@ static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, | |||
} | |||
} | |||
static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_s16_12x8x1_pack_B_n( | |||
int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
const int ksize = kmax - k0; | |||
const int ksize4 = ksize * 4; | |||
const int ksize8 = ksize4 * 2; | |||
@@ -1090,10 +1088,9 @@ static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, | |||
} | |||
} | |||
static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||
const int16_t* inptr, int ldin, | |||
int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_s16_12x8x1_transpose_pack_B_n( | |||
int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int16_t zerobuff[4]; | |||
std::memset(zerobuff, 0, sizeof(int16_t) * 4); | |||
@@ -1110,13 +1107,15 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||
int K = kmax - k0; | |||
for (; K > 7; K -= 8) { | |||
interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
interleave_8x1_8_h( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
if (K > 0) { | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 1, K); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 1, K); | |||
} | |||
} | |||
@@ -1136,9 +1135,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -1153,9 +1154,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -22,39 +22,37 @@ using namespace aarch64::matmul; | |||
///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); | |||
void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s16_12x8x1::pack_A( | |||
dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin, | |||
y0, ymax, k0, kmax); | |||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n( | |||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax, | |||
k0, kmax); | |||
matmul_12x8x1::gemm_s16_12x8x1_pack_A_n( | |||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin, | |||
int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s16_12x8x1::pack_B( | |||
dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n( | |||
out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
size_t M, size_t N, size_t K, dt_int32* C, | |||
size_t LDC, bool is_first_k, const dt_int32*, | |||
dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::Int16 && | |||
C_dtype.enumv() == DTypeEnum::Int32), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_s16_12x8x1::kern( | |||
const dt_int16* packA, const dt_int16* packB, size_t M, size_t N, size_t K, | |||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::Int16 && | |||
C_dtype.enumv() == DTypeEnum::Int32), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
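Those MEGDNN_MARK_USED_VAR lines exist because megdnn_assert can compile to nothing in release builds, which would otherwise leave A_dtype/B_dtype/C_dtype unread and trip -Wunused warnings. I would expect the macro to amount to a void cast, roughly:

```cpp
// Presumed expansion (illustrative; megdnn's actual definition may differ).
#define MEGDNN_MARK_USED_VAR_SKETCH(v) static_cast<void>(v)
```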
@@ -72,15 +70,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
size_t n = 0; | |||
const dt_int16* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
matmul_12x8x1::kern_12x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -92,15 +90,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
const dt_int16* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
matmul_12x8x1::kern_8x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -112,16 +110,17 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
const dt_int16* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(M - m, 4)); | |||
matmul_12x8x1::kern_4x8( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
matmul_12x8x1::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -16,11 +16,11 @@ namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, | |||
gemm_s16_12x8x1); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, gemm_s16_12x8x1); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false, | |||
true, gemm_nopack_s16_8x8); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
dt_int16, dt_int32, dt_int32, 8, 8, 1, false, true, gemm_nopack_s16_8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
@@ -9,8 +9,8 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/int16/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/int16/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
@@ -20,8 +20,9 @@ using namespace aarch64::matmul; | |||
namespace { | |||
void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
void kern_8x1( | |||
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
//! Each iteration loads 32 values from B while the pointer advances by only
//! 24 * 2 bytes, so we subtract 24 from LDB here.
LDB *= sizeof(dt_int16); | |||
@@ -91,9 +92,8 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
: "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||
"v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -120,8 +120,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
// | v31[0-7]| |v23[0-3]| | |||
// +---------+ +--------+ | |||
// Accumulator | |||
void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
void kern_8x4( | |||
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
//! Each iteration loads 32 values from B while the pointer advances by only
//! 24 * 2 bytes, so we subtract 24 from LDB here.
LDB = (LDB - 24) * sizeof(dt_int16); | |||
@@ -349,9 +350,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||
"memory"); | |||
} | |||
// Overview of register layout: | |||
@@ -382,8 +383,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
// | v7[0-7]| |v30[0-3]|v31[0-3]| | |||
// +--------+ +--------+--------+ | |||
// Accumulator | |||
void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
void kern_8x8( | |||
const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
dt_int32* output) { | |||
//! Each iteration loads 64 values from B while the pointer advances by only
//! 48 * 2 bytes, so we subtract 48 from LDB here.
LDB = (LDB - 48) * sizeof(dt_int16); | |||
@@ -693,20 +695,20 @@ void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[output] "+r"(output), [LDB] "+r"(LDB) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
"cc", "memory"); | |||
} | |||
} // anonymous namespace | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); | |||
void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||
size_t LDB, dt_int32* C, size_t LDC, size_t M, | |||
size_t K, size_t N, const dt_int32*, void*, | |||
bool trA, bool trB) const { | |||
void gemm_nopack_s16_8x8::kern( | |||
const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C, | |||
size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA, | |||
bool trB) const { | |||
constexpr static size_t MB = 8; | |||
constexpr static size_t KB = 8; | |||
constexpr static size_t NB = 8; | |||
@@ -36,9 +36,9 @@ namespace matmul_s4_4x4x16 { | |||
* Accumulator | |||
*/ | |||
static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void s4_kern_8x8_remain( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
K /= 8; | |||
LDC = LDC * sizeof(int16_t); | |||
const int8_t* a_ptr = packA; | |||
@@ -170,7 +170,7 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
"dup v5.8b,v20.b[5]\n" | |||
"dup v6.8b,v20.b[6]\n" | |||
"dup v7.8b,v20.b[7]\n" | |||
"ld1 {v17.8b}, [%[b_ptr]], 8\n" | |||
"dup v8.8b,v20.b[8]\n" | |||
@@ -318,16 +318,16 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
STORE_C | |||
: | |||
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
[a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
[m_remain] "+r"(m_remain), | |||
[n_remain] "+r"( | |||
n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
: | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31"); | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -335,14 +335,14 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
#undef STORE_C | |||
} | |||
static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void s4_kern_8x8( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
K /= 8; | |||
LDC = LDC * sizeof(int16_t); | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
// clang-format off | |||
// clang-format off | |||
#define LOAD_C_8 \ | |||
"ld1 {v24.8h}, [x0], #16\n" \ | |||
@@ -363,9 +363,9 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
"st1 {v28.8h}, [x4], #16\n" \ | |||
"st1 {v29.8h}, [x5], #16\n" \ | |||
"st1 {v30.8h}, [x6], #16\n" \ | |||
"st1 {v31.8h}, [x7], #16\n" \ | |||
"st1 {v31.8h}, [x7], #16\n" | |||
// clang-format on | |||
// clang-format on | |||
register int16_t* outptr asm("x0") = output; | |||
asm volatile( | |||
"add x1, x0, %x[LDC]\n" | |||
@@ -395,8 +395,8 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
"PRFM PLDL1KEEP, [%[a_ptr], #512]\n" | |||
"PRFM PLDL1KEEP, [%[b_ptr], #512]\n" | |||
"1:\n" | |||
// "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
// "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
// "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
// "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
"dup v0.8b,v20.b[0]\n" | |||
"ld1 {v22.16b}, [%[a_ptr]],#16\n" | |||
"dup v1.8b,v20.b[1]\n" | |||
@@ -409,7 +409,6 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
"dup v5.8b,v20.b[5]\n" | |||
"dup v6.8b,v20.b[6]\n" | |||
"dup v7.8b,v20.b[7]\n" | |||
"dup v8.8b,v20.b[8]\n" | |||
"smlal v24.8h, v0.8b, v16.8b\n" | |||
@@ -560,26 +559,26 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
STORE_C_8 | |||
: | |||
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||
[ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
[a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
[m_remain] "+r"(m_remain), | |||
[n_remain] "+r"( | |||
n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
: | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31"); | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
#undef STORE_C | |||
} | |||
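Worth a note before the pack routines: s4_kern_8x8 binds outptr to x0 with the GNU explicit-register-variable extension (`register int16_t* outptr asm("x0")`), so the asm body can derive the other seven row pointers as x1..x7 with plain adds. A minimal sketch of the idiom (invented helper; the store only justifies the clobbers):

```cpp
#include <cstdint>

void touch_second_row(int16_t* output, long ldc_bytes) {
    // Pin the C++ variable to x0 so the asm can name it directly.
    register int16_t* outptr asm("x0") = output;
    asm volatile(
        "add x1, x0, %x[ldc]\n"  // row 1 = row 0 + LDC
        "strh wzr, [x1]\n"       // write one int16 -> "memory" clobber
        : "+r"(outptr)
        : [ldc] "r"(ldc_bytes)
        : "x1", "memory");
}
```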
//packa | |||
static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
// packa | |||
static void gemm_s4x4x16_8x8x8_transpose_pack( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[8]; | |||
int8_t tmpbuff0[8]; | |||
int8_t tmpbuff1[8]; | |||
@@ -617,22 +616,23 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
prefetch_2x(inptr5); | |||
prefetch_2x(inptr6); | |||
prefetch_2x(inptr7); | |||
int K = (kmax - k0)/2; | |||
int K = (kmax - k0) / 2; | |||
//! read 4 * 16 in each row | |||
for (; K > 3; K -= 4) { | |||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
transpose_4x8_1_b_with_shift( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
if (K > 0) { | |||
std::memcpy(tmpbuff0,inptr0,K); | |||
std::memcpy(tmpbuff1,inptr1,K); | |||
std::memcpy(tmpbuff2,inptr2,K); | |||
std::memcpy(tmpbuff3,inptr3,K); | |||
std::memcpy(tmpbuff4,inptr4,K); | |||
std::memcpy(tmpbuff5,inptr5,K); | |||
std::memcpy(tmpbuff6,inptr6,K); | |||
std::memcpy(tmpbuff7,inptr7,K); | |||
std::memcpy(tmpbuff0, inptr0, K); | |||
std::memcpy(tmpbuff1, inptr1, K); | |||
std::memcpy(tmpbuff2, inptr2, K); | |||
std::memcpy(tmpbuff3, inptr3, K); | |||
std::memcpy(tmpbuff4, inptr4, K); | |||
std::memcpy(tmpbuff5, inptr5, K); | |||
std::memcpy(tmpbuff6, inptr6, K); | |||
std::memcpy(tmpbuff7, inptr7, K); | |||
inptr0 = tmpbuff0; | |||
inptr1 = tmpbuff1; | |||
inptr2 = tmpbuff2; | |||
@@ -641,8 +641,9 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
inptr5 = tmpbuff5; | |||
inptr6 = tmpbuff6; | |||
inptr7 = tmpbuff7; | |||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
transpose_4x8_1_b_with_shift( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
} | |||
for (; y < ymax; y += 8) { | |||
@@ -655,23 +656,29 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
const int8_t* inptr6 = inptr5 + ldin; | |||
const int8_t* inptr7 = inptr6 + ldin; | |||
int K = (kmax - k0)/2; | |||
int K = (kmax - k0) / 2; | |||
//! read 4 * 16 in each row | |||
for (; K > 3; K -= 4) { | |||
if (y + 7 >= ymax) { | |||
switch (y + 7 - ymax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -679,24 +686,31 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
transpose_4x8_1_b_with_shift( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
if (K > 0) { | |||
if (y + 7 >= ymax) { | |||
switch (y + 7 - ymax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -705,14 +719,14 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
} | |||
} | |||
std::memcpy(tmpbuff0,inptr0,K); | |||
std::memcpy(tmpbuff1,inptr1,K); | |||
std::memcpy(tmpbuff2,inptr2,K); | |||
std::memcpy(tmpbuff3,inptr3,K); | |||
std::memcpy(tmpbuff4,inptr4,K); | |||
std::memcpy(tmpbuff5,inptr5,K); | |||
std::memcpy(tmpbuff6,inptr6,K); | |||
std::memcpy(tmpbuff7,inptr7,K); | |||
std::memcpy(tmpbuff0, inptr0, K); | |||
std::memcpy(tmpbuff1, inptr1, K); | |||
std::memcpy(tmpbuff2, inptr2, K); | |||
std::memcpy(tmpbuff3, inptr3, K); | |||
std::memcpy(tmpbuff4, inptr4, K); | |||
std::memcpy(tmpbuff5, inptr5, K); | |||
std::memcpy(tmpbuff6, inptr6, K); | |||
std::memcpy(tmpbuff7, inptr7, K); | |||
inptr0 = tmpbuff0; | |||
inptr1 = tmpbuff1; | |||
inptr2 = tmpbuff2; | |||
@@ -721,14 +735,15 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
inptr5 = tmpbuff5; | |||
inptr6 = tmpbuff6; | |||
inptr7 = tmpbuff7; | |||
transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
inptr5, inptr6, inptr7, outptr); | |||
transpose_4x8_1_b_with_shift( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
} | |||
} | |||
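// A minimal scalar sketch of what the *_with_shift helpers above are assumed
// to do for the s4 storage format: two signed 4-bit values share one byte,
// which is also why K is computed as (kmax - k0) / 2 bytes. unpack_s4_pair is
// illustrative, not a megdnn API; the real helpers do the widening with NEON
// shifts while transposing:
static inline void unpack_s4_pair(int8_t byte, int8_t* lo, int8_t* hi) {
    *lo = (int8_t)(byte << 4) >> 4;  // low nibble, sign-extended to int8
    *hi = byte >> 4;                 // high nibble via arithmetic shift
}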
//packb | |||
static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
// packb | |||
static void gemm_s4x4x16_8x8x8_interleave_pack( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[8]; | |||
int8_t tmpbuff0[8]; | |||
int8_t tmpbuff1[8]; | |||
@@ -748,7 +763,7 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); | |||
std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); | |||
const int ksize = kmax - k0; | |||
const int ksize8 = round_up(ksize, 8) * 8; //pack to int8 *8 packto s4 *4 | |||
    const int ksize8 = round_up(ksize, 8) * 8; // pack to int8 *8, then pack to s4 *4 | |||
int8_t* outptr = out; | |||
int8_t* outptr_interleave = nullptr; | |||
@@ -776,21 +791,22 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
int8_t* outptr_inner = outptr; | |||
for (; x + 3 < xmax; x += 4) { | |||
outptr_interleave = outptr_inner; | |||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
interleave_8x4_1_b_with_shift( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave); | |||
outptr_inner += ksize8; | |||
} | |||
if (x < xmax) { | |||
int remainx = xmax - x; | |||
std::memcpy(tmpbuff0,inptr0,remainx); | |||
std::memcpy(tmpbuff1,inptr1,remainx); | |||
std::memcpy(tmpbuff2,inptr2,remainx); | |||
std::memcpy(tmpbuff3,inptr3,remainx); | |||
std::memcpy(tmpbuff4,inptr4,remainx); | |||
std::memcpy(tmpbuff5,inptr5,remainx); | |||
std::memcpy(tmpbuff6,inptr6,remainx); | |||
std::memcpy(tmpbuff7,inptr7,remainx); | |||
std::memcpy(tmpbuff0, inptr0, remainx); | |||
std::memcpy(tmpbuff1, inptr1, remainx); | |||
std::memcpy(tmpbuff2, inptr2, remainx); | |||
std::memcpy(tmpbuff3, inptr3, remainx); | |||
std::memcpy(tmpbuff4, inptr4, remainx); | |||
std::memcpy(tmpbuff5, inptr5, remainx); | |||
std::memcpy(tmpbuff6, inptr6, remainx); | |||
std::memcpy(tmpbuff7, inptr7, remainx); | |||
inptr0 = tmpbuff0; | |||
inptr1 = tmpbuff1; | |||
inptr2 = tmpbuff2; | |||
@@ -801,8 +817,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
inptr7 = tmpbuff7; | |||
outptr_interleave = outptr_inner; | |||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
interleave_8x4_1_b_with_shift( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave); | |||
outptr_inner += ksize8; | |||
} | |||
outptr += 64; | |||
@@ -847,8 +864,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
break; | |||
} | |||
outptr_interleave = outptr_inner; | |||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
interleave_8x4_1_b_with_shift( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave); | |||
outptr_inner += ksize8; | |||
} | |||
if (x < xmax) { | |||
@@ -880,14 +898,14 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
} | |||
int remainx = xmax - x; | |||
outptr_interleave = outptr_inner; | |||
std::memcpy(tmpbuff0,inptr0,remainx); | |||
std::memcpy(tmpbuff1,inptr1,remainx); | |||
std::memcpy(tmpbuff2,inptr2,remainx); | |||
std::memcpy(tmpbuff3,inptr3,remainx); | |||
std::memcpy(tmpbuff4,inptr4,remainx); | |||
std::memcpy(tmpbuff5,inptr5,remainx); | |||
std::memcpy(tmpbuff6,inptr6,remainx); | |||
std::memcpy(tmpbuff7,inptr7,remainx); | |||
std::memcpy(tmpbuff0, inptr0, remainx); | |||
std::memcpy(tmpbuff1, inptr1, remainx); | |||
std::memcpy(tmpbuff2, inptr2, remainx); | |||
std::memcpy(tmpbuff3, inptr3, remainx); | |||
std::memcpy(tmpbuff4, inptr4, remainx); | |||
std::memcpy(tmpbuff5, inptr5, remainx); | |||
std::memcpy(tmpbuff6, inptr6, remainx); | |||
std::memcpy(tmpbuff7, inptr7, remainx); | |||
inptr0 = tmpbuff0; | |||
inptr1 = tmpbuff1; | |||
inptr2 = tmpbuff2; | |||
@@ -898,16 +916,16 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
inptr7 = tmpbuff7; | |||
outptr_interleave = outptr_inner; | |||
interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
interleave_8x4_1_b_with_shift( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave); | |||
outptr_inner += ksize8; | |||
} | |||
} | |||
} | |||
} // namespace matmul_4x4x16 | |||
} // namespace matmul_s4_4x4x16 | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
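// The strategy in the next file wires these packers up symmetrically: pack_A
// with transpose routes to the interleave packer and pack_B with transpose to
// the transpose packer, presumably because packing a transposed A panel is the
// same operation as packing a plain B panel. A condensed sketch of that
// dispatch (pack_panel and is_A are illustrative, not megdnn APIs):
void pack_panel(dt_int8* out, const dt_int8* in, int ldin, int r0, int rmax,
                int k0, int kmax, bool transpose, bool is_A) {
    if (is_A != transpose)  // A wants row panels, B wants column panels
        matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(
                out, in, ldin, r0, rmax, k0, kmax);
    else
        matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(
                out, in, ldin, r0, rmax, k0, kmax);
}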
@@ -10,9 +10,9 @@ | |||
* implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
@@ -23,39 +23,38 @@ using namespace aarch64::matmul; | |||
// ===========================gemm_s4x4x16_s4_8x8x8================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); | |||
void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s4x4x16_s4_8x8x8::pack_A( | |||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( | |||
out, in, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( | |||
out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s4x4x16_s4_8x8x8::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( | |||
out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( | |||
out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_s4x4x16_s4_8x8x8::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
@@ -72,16 +71,17 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||
matmul_s4_4x4x16::s4_kern_8x8( | |||
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||
B_INTERLEAVE); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, A_INTERLEAVE, | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
matmul_s4_4x4x16::s4_kern_8x8_remain( | |||
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
@@ -94,10 +94,10 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, | |||
std::min<size_t>(M - m, A_INTERLEAVE), | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
matmul_s4_4x4x16::s4_kern_8x8_remain( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, A_INTERLEAVE), | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
@@ -105,5 +105,4 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
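// kern() above follows the standard packed-GEMM driver shape: whole
// A_INTERLEAVE x B_INTERLEAVE tiles first, then remainder kernels clamped
// with std::min at the edges. A condensed, self-contained sketch of that loop
// structure (kern_tile, MT, NT and panel_stride are illustrative; the real
// code uses s4_kern_8x8 / s4_kern_8x8_remain and the K8 stride):
#include <algorithm>
#include <cstdint>

void kern_tile(const int8_t*, const int8_t*, size_t, int16_t*, size_t, bool,
               size_t, size_t);  // hypothetical stand-in for the asm kernels

void gemm_driver(const int8_t* packA, const int8_t* packB, size_t M, size_t N,
                 size_t K, int16_t* C, size_t LDC, bool is_first_k,
                 size_t panel_stride) {  // = K8 in the code above
    constexpr size_t MT = 8, NT = 8;     // A_INTERLEAVE, B_INTERLEAVE
    for (size_t m = 0; m < M; m += MT) {
        int16_t* output = C + m * LDC;   // top of this band of C
        const int8_t* cur_packB = packB;
        for (size_t n = 0; n < N; n += NT) {
            kern_tile(packA, cur_packB, K, output, LDC, is_first_k,
                      std::min(M - m, MT), std::min(N - n, NT));
            output += NT;                // next tile within the band
            cur_packB += panel_stride;   // next packed B panel
        }
        packA += panel_stride;           // next packed A panel of MT rows
    }
}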
@@ -17,8 +17,8 @@ namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, | |||
gemm_s4x4x16_s4_8x8x8); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s4x4x16_s4_8x8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
@@ -51,8 +51,9 @@ namespace matmul_4x4x16 { | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
static void kern_4x4( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k) { | |||
K /= 16; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -472,9 +473,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
); | |||
} | |||
static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, | |||
int m_remain, int n_remain) { | |||
static void kern_4x4_remain( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
megdnn_assert(K > 0); | |||
K /= 16; | |||
const int8_t* a_ptr = packA; | |||
@@ -655,16 +656,14 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[output] "+r"(output), [m_remain] "+r"(m_remain), | |||
[n_remain] "+r"(n_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output), | |||
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
"x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -672,8 +671,9 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
#undef STORE_C | |||
} | |||
static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, int kmax) { | |||
static void gemm_s8_4x4_pack_A_n( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
@@ -716,9 +716,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -734,9 +736,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -749,8 +753,8 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
} | |||
} | |||
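// The switch/MEGDNN_FALLTHRU ladders above and below implement edge padding:
// when fewer than a full tile of rows remains, every out-of-range row pointer
// is redirected to a shared zero buffer, so the interleave/transpose helpers
// can always read a complete tile without touching out-of-bounds memory.
// The same idiom in loop form (select_rows is illustrative, not a megdnn API):
static void select_rows(const int8_t* base, int ldin, int y, int ymax,
                        const int8_t* rows[4], const int8_t* zerobuff) {
    for (int i = 0; i < 4; ++i)  // rows at or past ymax read zeros instead
        rows[i] = (y + i < ymax) ? base + (y + i) * ldin : zerobuff;
}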
static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_s8_4x4_pack_B_n( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
@@ -777,19 +781,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
if (remain >= 0) { | |||
switch (remain) { | |||
case 7: | |||
inptr0 = zerobuff; MEGDNN_FALLTHRU | |||
inptr0 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -798,9 +809,9 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
} | |||
transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||
inptr4, inptr5, inptr6, inptr7, | |||
outptr_inner); | |||
transpose_4x16_1_b_helper( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_inner); | |||
outptr_inner += ksize4; | |||
} | |||
@@ -808,19 +819,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
if (remain >= 0) { | |||
switch (remain) { | |||
case 7: | |||
inptr0 = zerobuff; MEGDNN_FALLTHRU | |||
inptr0 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -42,8 +42,9 @@ namespace matmul_8x8x8 { | |||
* Accumulator | |||
*/ | |||
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
static void kern_8x8( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k) { | |||
K /= 8; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -272,14 +273,13 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
"stp q18, q19, [x5]\n" | |||
"stp q20, q21, [x6]\n" | |||
"stp q22, q23, [x7]\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[output] "+r"(output) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v26", "v27", "x1", | |||
"x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v26", "v27", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
"cc", "memory"); | |||
} | |||
/** | |||
@@ -309,9 +309,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, | |||
size_t n_remain) { | |||
static void kern_8x4( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, size_t n_remain) { | |||
K /= 8; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -520,16 +520,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||
[n_remain] "+r"(n_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -559,9 +557,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, | |||
size_t m_remain) { | |||
static void kern_4x8( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, size_t m_remain) { | |||
K /= 8; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -724,14 +722,13 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||
[outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -762,9 +759,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, size_t m_remain, | |||
size_t n_remain) { | |||
static void kern_4x4( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, size_t m_remain, size_t n_remain) { | |||
K /= 8; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -922,11 +919,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||
[x1] "+r"(x1), [m_remain] "+r"(m_remain), | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[x0] "+r"(x0), [x1] "+r"(x1), [m_remain] "+r"(m_remain), | |||
[n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", | |||
@@ -938,8 +934,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
#undef STORE_C | |||
} | |||
static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax) { | |||
static void gemm_s8_8x8_pack_A_n( | |||
int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
@@ -965,13 +962,15 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
int K = kmax - k0; | |||
for (; K > 15; K -= 16) { | |||
interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
interleave_8x8_2_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
if (K > 0) { | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 8, K); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 8, K); | |||
} | |||
} | |||
@@ -991,9 +990,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -1009,9 +1010,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -1024,9 +1027,8 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
} | |||
} | |||
static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax) { | |||
static void gemm_s8_8x8_transpose_pack_A_n( | |||
int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
@@ -1063,17 +1065,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1081,8 +1089,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
transpose_8x8_1_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
outptr += ksize8; | |||
} | |||
@@ -1091,17 +1100,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1110,8 +1125,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
} | |||
} | |||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 4, 4); | |||
transpose_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 4, 4); | |||
outptr += ksize4; | |||
} | |||
@@ -1119,17 +1135,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1138,8 +1160,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
} | |||
} | |||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 4, xmax - x); | |||
transpose_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 4, xmax - x); | |||
} | |||
outptr_base += 8 * 8; | |||
@@ -1147,8 +1170,8 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
} | |||
} | |||
static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_s8_8x8_pack_B_n( | |||
int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
@@ -1186,17 +1209,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1205,8 +1234,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
} | |||
} | |||
outptr_interleave = outptr; | |||
interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
interleave_8x8_1_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave); | |||
outptr += ksize8; | |||
} | |||
@@ -1215,17 +1245,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1235,8 +1271,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
} | |||
outptr_interleave = outptr; | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr_interleave, 4, 4); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave, 4, 4); | |||
outptr += ksize4; | |||
} | |||
@@ -1244,17 +1281,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1264,8 +1307,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
} | |||
outptr_interleave = outptr; | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr_interleave, 4, xmax - x); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave, 4, xmax - x); | |||
} | |||
outptr_base += 8 * 8; | |||
@@ -1273,9 +1317,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
} | |||
} | |||
static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_s8_8x8_transpose_pack_B_n( | |||
int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
constexpr int interleave4 = 32; | |||
@@ -1303,14 +1347,16 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||
int K = kmax - k0; | |||
for (; K > 7; K -= 8) { | |||
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
transpose_8x8_1_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
outptr += interleave8; | |||
} | |||
if (K > 0) { | |||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 8, K); | |||
transpose_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 8, K); | |||
outptr += interleave8; | |||
} | |||
} | |||
@@ -1331,9 +1377,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -1350,9 +1398,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
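// Panel strides such as ksize8 = round_up(ksize, 8) * 8 and ksize4 above
// reserve a fixed-size slot per packed panel regardless of the K tail, which
// is what lets outptr advance by a constant between panels. A worked instance
// (round_up_i mirrors megdnn's round_up; the name is illustrative):
constexpr int round_up_i(int v, int align) {
    return (v + align - 1) / align * align;
}
static_assert(round_up_i(13, 8) * 8 == 128,
              "K = 13 still occupies a full 16-step panel slot");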
@@ -50,8 +50,9 @@ namespace matmul_mk4_4x4x16 { | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, bool is_first_k) { | |||
static void kern_4x4( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, | |||
bool is_first_k) { | |||
K = div_ceil(K, 16); | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -366,17 +367,18 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
"6:\n" | |||
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[k] "+r"(K), [output] "+r"(output) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
"cc", "memory"); | |||
} | |||
static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, bool is_first_k, size_t remain_n) { | |||
static void kern_4x4_remain( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, | |||
bool is_first_k, size_t remain_n) { | |||
K = div_ceil(K, 16); | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -718,26 +720,27 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
"7:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k), | |||
[k] "+r"(K), [output] "+r"(output) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [remain_n] "+r"(remain_n), | |||
[is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
"cc", "memory"); | |||
} | |||
static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_mk4_s8_4x4_pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
    //! pack from {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} | |||
int8_t zerobuff[4][64]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); | |||
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, | |||
"mk4 matmul with m is not times of 4"); | |||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, | |||
"mk4 matmul with k is not times of 4"); | |||
megdnn_assert( | |||
ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, | |||
"mk4 matmul with m is not times of 4"); | |||
megdnn_assert( | |||
kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, | |||
"mk4 matmul with k is not times of 4"); | |||
size_t roundk = round_up(kmax - k0, 16); | |||
size_t out_offset = roundk * 4; | |||
int y = y0; | |||
@@ -754,8 +757,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
prefetch_2x(inptr3); | |||
int K = kmax - k0; | |||
for (; K > 15; K -= 16) { | |||
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||
out_offset); | |||
transpose_interleave_4x4_4_b( | |||
inptr0, inptr1, inptr2, inptr3, output, out_offset); | |||
output += 64; | |||
} | |||
if (K > 0) { | |||
@@ -767,8 +770,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
inptr1 = zerobuff[1]; | |||
inptr2 = zerobuff[2]; | |||
inptr3 = zerobuff[3]; | |||
transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||
out_offset); | |||
transpose_interleave_4x4_4_b( | |||
inptr0, inptr1, inptr2, inptr3, output, out_offset); | |||
output += 64; | |||
} | |||
} | |||
@@ -790,21 +793,21 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
} | |||
} | |||
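// The layout change documented above ({oc/4, ic/4, 4(ic), 4(oc)} to
// {oc/4, ic/16, 4(oc), 16(ic)}) admits a simple scalar reference, assuming
// OC % 4 == 0 and IC % 4 == 0 as the asserts require; pack_mk4_a_ref is a
// hypothetical model of the resulting layout, not of the NEON traversal order:
#include <cstdint>
#include <cstring>

static void pack_mk4_a_ref(int8_t* dst, const int8_t* src, int OC, int IC) {
    const int ICr = (IC + 15) / 16 * 16;    // roundk: K padded up to 16
    std::memset(dst, 0, (size_t)OC * ICr);  // zero the K tail once
    for (int ob = 0; ob < OC / 4; ++ob)
        for (int i = 0; i < IC; ++i)
            for (int ol = 0; ol < 4; ++ol)
                dst[((ob * (ICr / 16) + i / 16) * 4 + ol) * 16 + i % 16] =
                        src[((ob * (IC / 4) + i / 4) * 4 + i % 4) * 4 + ol];
}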
static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_mk4_s8_4x4_pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int32_t zerobuff[4]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
const int ICB = (ksize) / 4; | |||
const int ksize4 = round_up<int>(ICB, 4) * 4; | |||
int32_t* outptr = reinterpret_cast<int32_t*>(out); | |||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, | |||
"mk4 matmul with k is not times of 4"); | |||
megdnn_assert( | |||
kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, | |||
"mk4 matmul with k is not times of 4"); | |||
int k = k0 / 4; | |||
for (; k + 3 < ICB; k += 4) { | |||
const int32_t* inptr0 = | |||
reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
const int32_t* inptr1 = | |||
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | |||
const int32_t* inptr2 = | |||
@@ -829,8 +832,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
outptr += 4 * 4; | |||
} | |||
if (k < ICB) { | |||
const int32_t* inptr0 = | |||
reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
const int32_t* inptr1 = | |||
reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | |||
const int32_t* inptr2 = | |||
@@ -844,9 +846,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
if (k + 3 >= ICB) { | |||
switch (k + 3 - ICB) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -861,9 +865,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
if (k + 3 >= ICB) { | |||
switch (k + 3 - ICB) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -882,7 +888,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
} | |||
} // namespace matmul_4x4x16 | |||
} // namespace matmul_mk4_4x4x16 | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
@@ -24,20 +24,19 @@ using namespace aarch64::matmul; | |||
///////////////////////// gemm_s8_4x4 //////////////////////////////////// | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); | |||
void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s8_4x4::pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
void gemm_s8_4x4::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
@@ -45,16 +44,16 @@ void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
} | |||
} | |||
void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_s8_4x4::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
@@ -72,16 +71,15 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, 4, | |||
std::min<size_t>(N - n, 4)); | |||
matmul_4x4x16::kern_4x4_remain( | |||
packA, cur_packB, K, output, LDC, is_first_k, 4, | |||
std::min<size_t>(N - n, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
@@ -107,33 +105,32 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | |||
void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
megdnn_assert(!transpose, | |||
"the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
void gemm_mk4_s8_4x4::pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
megdnn_assert( | |||
!transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
megdnn_assert(!transpose, | |||
"the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
void gemm_mk4_s8_4x4::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
megdnn_assert( | |||
!transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||
matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_mk4_s8_4x4::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
@@ -151,57 +148,54 @@ void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, | |||
is_first_k); | |||
matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, is_first_k); | |||
output += B_INTERLEAVE * 4; | |||
cur_packB += K4; | |||
} | |||
if (n < N) { | |||
matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, | |||
is_first_k, N - n); | |||
matmul_mk4_4x4x16::kern_4x4_remain( | |||
packA, cur_packB, K, output, is_first_k, N - n); | |||
} | |||
packA += K4; | |||
} | |||
} | |||
///////////////////////// gemm_s8_8x8 //////////////////////////////////// | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); | |||
void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s8_8x8::pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||
ymax, k0, kmax); | |||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n( | |||
outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
void gemm_s8_8x8::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||
k0, kmax); | |||
matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_s8_8x8::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
@@ -220,15 +214,15 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
matmul_8x8x8::kern_8x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -240,16 +234,17 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
matmul_8x8x8::kern_4x8( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
matmul_8x8x8::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -16,14 +16,14 @@ namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, | |||
gemm_s8_4x4); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4); | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, | |||
gemm_mk4_s8_4x4); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, gemm_mk4_s8_4x4); | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||
gemm_s8_8x8); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, gemm_s8_8x8); | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
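// The kernels in the next file carry MEGDNN_ATTRIBUTE_TARGET("dotprod") and
// divide K by 4 because each AArch64 sdot instruction folds a 4-element int8
// dot product into every 32-bit accumulator lane. A scalar model of one
// (non-indexed) sdot step, for reference only:
#include <cstdint>

static void sdot_ref(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int lane = 0; lane < 4; ++lane)
        for (int k = 0; k < 4; ++k)
            acc[lane] += (int32_t)a[4 * lane + k] * (int32_t)b[4 * lane + k];
}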
@@ -52,8 +52,9 @@ namespace matmul_8x12x4 { | |||
#if 1 | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k) { | |||
K /= 4; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -410,8 +411,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
} | |||
#else | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k) { | |||
K /= 4; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -612,18 +614,17 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
"stp q15, q23, [%[outptr7]]\n" | |||
"str q31, [%[outptr7], #32]\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), | |||
[a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), | |||
[b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), | |||
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1), | |||
[a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1), | |||
[b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk), | |||
[is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
[outptr7] "=r"(outptr7) | |||
: | |||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||
"v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "cc", "memory"); | |||
} | |||
#endif | |||
@@ -653,8 +654,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
// | |||
// Accumulator | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, int m_remain) { | |||
static void kern_4x12( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, int m_remain) { | |||
K /= 4; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -796,15 +798,15 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
"4:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | |||
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), | |||
[is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), | |||
[LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), | |||
[b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), | |||
[b2a] "=w"(b2a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||
[outptr3] "=r"(outptr3), [x0] "=r"(x0) | |||
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [a0] "=w"(a0), | |||
[a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), | |||
[b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[x0] "=r"(x0) | |||
: | |||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v19", "memory", "cc"); | |||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "memory", "cc"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
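
The m_remain/n_remain parameters exist because the edge kernels reuse the full-size inner loop and only predicate the C accesses: the LOAD_C/STORE_C macros branch on the remainders so out-of-range rows and columns are never touched (for kern_4x12 only m_remain applies, since all 12 columns are live). In scalar terms the store reduces to:

#include <cstdint>

// Scalar equivalent of a clamped STORE_C for a 4x12 tile: only the first
// m_remain rows and n_remain columns of the accumulator are written back.
static void store_clamped(const int32_t acc[4][12], int32_t* output, int LDC,
                          int m_remain, int n_remain) {
    for (int r = 0; r < m_remain; ++r)
        for (int c = 0; c < n_remain; ++c)
            output[r * LDC + c] = acc[r][c];
}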
@@ -840,8 +842,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
// | |||
// Accumulator | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||
static void kern_8x4( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
K /= 4; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -1004,12 +1007,11 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | |||
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | |||
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "=r"(x0) | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), | |||
[outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), | |||
[x0] "=r"(x0) | |||
: | |||
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", | |||
"cc"); | |||
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -1041,9 +1043,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
// | |||
// Accumulator | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_4x4( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
K /= 4; | |||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | |||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | |||
@@ -1172,10 +1174,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
"4:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | |||
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | |||
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), | |||
[a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||
[m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), | |||
[b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||
[outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) | |||
: | |||
: "v4", "v5", "v6", "v7", "memory", "cc"); | |||
@@ -1186,9 +1187,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
#undef STORE_C | |||
} | |||
static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_s8_8x12_pack_A_n( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
@@ -1215,13 +1216,15 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
int K = kmax - k0; | |||
//! read 8 * 4 in each row | |||
for (; K > 15; K -= 16) { | |||
interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
interleave_8x4_4_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
if (K > 0) { | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 4, K); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 4, K); | |||
} | |||
} | |||
for (; y < ymax; y += 4) { | |||
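
interleave_8x4_4_b consumes 16 K-values per call (hence K -= 16) and rewrites eight rows of A into contiguous 8x4 blocks, so the 8x12x4 kernel reads its A panel with purely sequential loads; interleave_8 handles the sub-16 tail with an explicit count. A scalar sketch of one 8x4 step under the presumed layout:

#include <cstdint>

// One presumed 8x4 interleave step: emit 4 consecutive K-values from each of
// 8 row pointers, row after row, then advance every row past those 4 values.
static void interleave_8x4_step(const int8_t* in[8], int8_t*& out) {
    for (int r = 0; r < 8; ++r)
        for (int k = 0; k < 4; ++k)
            *out++ = in[r][k];
    for (int r = 0; r < 8; ++r)
        in[r] += 4;
}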
@@ -1274,8 +1277,8 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
} | |||
} | |||
static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_s8_8x12_pack_A_t( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
@@ -1361,8 +1364,8 @@ static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
} | |||
static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_s8_8x12_pack_B_n( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
@@ -1448,9 +1451,9 @@ static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
} | |||
static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_s8_8x12_pack_B_t( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
@@ -1485,15 +1488,15 @@ static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, | |||
int K = kmax - k0; | |||
//! read 12 * 4 in each row | |||
for (; K > 15; K -= 16) { | |||
interleave_12x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, inptr8, inptr9, inptr10, | |||
inptr11, outptr); | |||
interleave_12x4_4_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
inptr8, inptr9, inptr10, inptr11, outptr); | |||
} | |||
if (K > 0) { | |||
interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, | |||
outptr, 4, K); | |||
interleave_12( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
inptr8, inptr9, inptr10, inptr11, outptr, 4, K); | |||
} | |||
} | |||
for (; y < ymax; y += 4) { | |||
@@ -40,8 +40,9 @@ namespace matmul_mk4_8x12x4 { | |||
// Accumulator | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
static void kern_8x12( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k) { | |||
K /= 4; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -397,8 +398,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
// Accumulator | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k) { | |||
static void kern_4x12( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k) { | |||
K /= 4; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -514,13 +516,12 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
"stp q16, q17, [%[outptr0], #128]\n" | |||
"stp q18, q19, [%[outptr0], #160]\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | |||
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), | |||
[is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a), | |||
[b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), | |||
[b1a] "=w"(b1a), [b2a] "=w"(b2a) | |||
[outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), | |||
[b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a) | |||
: | |||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
"v17", "v18", "v19", "memory", "cc"); | |||
: "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
"v19", "memory", "cc"); | |||
} | |||
// Overview of register layout: | |||
@@ -544,8 +545,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
// Accumulator | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||
static void kern_8x4( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
K /= 4; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -689,11 +691,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
[oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||
[n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | |||
[a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | |||
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | |||
[x0] "=r"(x0) | |||
[b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [x0] "=r"(x0) | |||
: | |||
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", | |||
"cc"); | |||
: "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -720,8 +720,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
// Accumulator | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||
static void kern_4x4( | |||
const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
bool is_first_k, int n_remain) { | |||
K /= 4; | |||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | |||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | |||
@@ -834,10 +835,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
"4:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | |||
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | |||
[LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k), | |||
[a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), | |||
[x0] "=r"(x0) | |||
[is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), | |||
[b0] "=w"(b0), [b0a] "=w"(b0a), [x0] "=r"(x0) | |||
: | |||
: "v4", "v5", "v6", "v7", "memory", "cc"); | |||
@@ -847,13 +847,11 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
#undef STORE_C | |||
} | |||
static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, | |||
"mk4 matmul requires m to be a multiple of 4"); | |||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, | |||
"mk4 matmul requires k to be a multiple of 4"); | |||
static void gemm_mk4_s8_8x12_pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, "mk4 matmul requires m to be a multiple of 4"); | |||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 matmul requires k to be a multiple of 4"); | |||
int y = y0; | |||
int start_y = y0 / 4; | |||
for (; y + 7 < ymax; y += 8, start_y += 2) { | |||
@@ -869,15 +867,15 @@ static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
interleave_2x4_4_b(inptr0, inptr1, outptr); | |||
} | |||
} | |||
for (; y + 3 < ymax; y += 4, start_y ++) { | |||
for (; y + 3 < ymax; y += 4, start_y++) { | |||
int K = kmax - k0; | |||
const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); | |||
std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); | |||
} | |||
} | |||
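
The MK4 A-panel pack is nearly free because the layout already interleaves four rows along K: a 4-row band over [k0, kmax) is one contiguous run of K * 4 bytes, which is why the tail loop above is a single memcpy. The index math under that assumption:

#include <cstddef>

// MK4 int8 storage (assumed): element (m, k) lives at
// base[(m / 4) * ldin + k * 4 + (m % 4)], matching the
// `inptr + start_y * ldin + (k0 << 2)` arithmetic above.
static inline size_t mk4_offset(size_t m, size_t k, size_t ldin) {
    return (m / 4) * ldin + k * 4 + (m % 4);
}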
static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_mk4_s8_8x12_pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
const int ksize = kmax - k0; | |||
const int ksize12 = ksize * 12; | |||
const int ksize4 = ksize * 4; | |||
@@ -12,10 +12,10 @@ | |||
#include "src/aarch64/matrix_mul/int8_dot/strategy.h" | |||
#if MGB_ENABLE_DOT | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | |||
#include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace aarch64; | |||
@@ -24,20 +24,19 @@ using namespace aarch64::matmul; | |||
/* ====================== gemm_s8_8x12 ===========================*/ | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); | |||
void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s8_8x12::pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
if (transpose) { | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, bool transpose) const { | |||
void gemm_s8_8x12::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
@@ -45,16 +44,16 @@ void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
} | |||
} | |||
void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t N, size_t K, dt_int32* C, size_t LDC, | |||
bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_s8_8x12::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
@@ -75,15 +74,15 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
matmul_8x12x4::kern_8x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -95,16 +94,17 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(M - m, 4)); | |||
matmul_8x12x4::kern_4x12( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
matmul_8x12x4::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -115,32 +115,32 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
/* ====================== gemm_mk4_s8_8x12 ===========================*/ | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); | |||
void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported"); | |||
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||
kmax); | |||
void gemm_mk4_s8_8x12::pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax, bool transpose) const { | |||
megdnn_assert( | |||
!transpose, "matrix mul mk4 with transposed matrix A is not supported"); | |||
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
} | |||
void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported"); | |||
void gemm_mk4_s8_8x12::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
megdnn_assert( | |||
!transpose, "matrix mul mk4 with transposed matrix B is not supported"); | |||
matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int32* C, | |||
size_t LDC, bool is_first_k, const dt_int32*, | |||
dt_int32*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_mk4_s8_8x12::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && | |||
((A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int32) || | |||
(A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
@@ -161,15 +161,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += (B_INTERLEAVE << 2); | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
matmul_mk4_8x12x4::kern_8x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 16; | |||
cur_packB += K4; | |||
} | |||
@@ -181,15 +181,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += (B_INTERLEAVE << 2); | |||
cur_packB += K12; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, std::min<size_t>(N - n, 4)); | |||
matmul_mk4_8x12x4::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 16; | |||
cur_packB += K4; | |||
} | |||
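
Note the C-pointer arithmetic in the MK4 kern: each column of C occupies four int32 (one per interleaved row of the band), so stepping past a 12-column tile advances the pointer by 12 * 4 = 48 values and past a 4-column tile by 16 values, which is exactly the `output += (B_INTERLEAVE << 2)` and `output += 16` above. A one-line helper makes the intent explicit:

#include <cstdint>

// Advance an MK4 C pointer by `cols` columns: each column holds 4 int32
// (the four interleaved rows of the band), hence the << 2 in the driver.
static inline int32_t* mk4_col_advance(int32_t* output, int cols) {
    return output + cols * 4;
}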
@@ -16,14 +16,14 @@ namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul { | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||
gemm_s8_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||
gemm_mk4_s8_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY( | |||
dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_mk4_s8_8x12); | |||
} // namespace aarch64 | |||
} // namespace matmul | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
#endif | |||
// vim: syntax=cpp.doxygen | |||
@@ -34,9 +34,9 @@ namespace matmul_4x4x16 { | |||
* | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_4x4( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
K /= 16; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -230,16 +230,14 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
// Store back into memory | |||
STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr] "+r"(outptr), [m_remain] "+r"(m_remain), | |||
[n_remain] "+r"(n_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||
: | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", | |||
"v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
"v31"); | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", "v3", | |||
"v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||
"v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -247,9 +245,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
#undef STORE_C | |||
} | |||
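
Unlike the s8 strategies above, matmul_4x4x16 accumulates in int16 (the s8x8x16 family), and K /= 16 because each asm iteration eats 16 int8 along K. A scalar reference of the tile it computes, with the packed panel layout assumed for illustration:

#include <cstdint>

// Scalar reference for one 4x4 s8x8x16 tile: int8 inputs, int16 accumulator.
// Panel layout assumed: packA stores K-major 4-row groups (a[k * 4 + r]),
// packB the mirror for columns. Overflow semantics follow the int16 C type.
static void ref_kern_4x4_s8x8x16(const int8_t* packA, const int8_t* packB,
                                 int K, int16_t* output, int LDC,
                                 bool is_first_k) {
    for (int r = 0; r < 4; ++r)
        for (int c = 0; c < 4; ++c) {
            int16_t acc = is_first_k ? int16_t(0) : output[r * LDC + c];
            for (int k = 0; k < K; ++k)
                acc = int16_t(acc + packA[k * 4 + r] * packB[k * 4 + c]);
            output[r * LDC + c] = acc;
        }
}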
static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_s8x8x16_4x4_pack_A_n( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
@@ -292,9 +290,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
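
The reformatted switch above is the pack routines' zero-padding idiom: when fewer than four rows remain, every pointer from the first out-of-range row downward cascades onto zerobuff, and MEGDNN_FALLTHRU (presumably [[fallthrough]] or the GNU attribute on older compilers) marks the cascade as intentional. Standalone, it looks like:

#include <cassert>
#include <cstdint>

// Standalone version of the cascade, assuming MEGDNN_FALLTHRU expands to
// C++17 [[fallthrough]]. Rows past ymax are redirected to the zero buffer
// so the interleave kernels can always read four live pointers.
static void redirect_tail_rows(int y, int ymax, const int8_t* zerobuff,
                               const int8_t*& inptr1, const int8_t*& inptr2,
                               const int8_t*& inptr3) {
    if (y + 3 < ymax)
        return;
    switch (y + 3 - ymax) {
        case 2:
            inptr1 = zerobuff;
            [[fallthrough]];
        case 1:
            inptr2 = zerobuff;
            [[fallthrough]];
        case 0:
            inptr3 = zerobuff;
            break;
        default:
            assert(0);
    }
}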
@@ -309,9 +309,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -324,8 +326,8 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
} | |||
} | |||
static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_s8x8x16_4x4_pack_B_n( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
@@ -362,19 +364,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
if (remain >= 0) { | |||
switch (remain) { | |||
case 7: | |||
inptr0 = zerobuff; MEGDNN_FALLTHRU | |||
inptr0 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -383,9 +392,9 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
} | |||
transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||
inptr4, inptr5, inptr6, inptr7, | |||
outptr_inner); | |||
transpose_4x16_1_b_helper( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_inner); | |||
outptr_inner += ksize4; | |||
} | |||
@@ -393,19 +402,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
if (remain >= 0) { | |||
switch (remain) { | |||
case 7: | |||
inptr0 = zerobuff; MEGDNN_FALLTHRU | |||
inptr0 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -42,8 +42,9 @@ namespace matmul_8x8x8 { | |||
* | |||
* Accumulator | |||
*/ | |||
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k) { | |||
static void kern_8x8( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k) { | |||
K /= 8; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -217,13 +218,12 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
"bne 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[outptr] "+r"(outptr) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", | |||
"x4", "x5", "x6", "x7", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", "x4", "x5", | |||
"x6", "x7", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
#undef STORE_LINE | |||
@@ -258,9 +258,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, | |||
size_t n_remain) { | |||
static void kern_8x4( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, size_t n_remain) { | |||
K /= 8; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -471,16 +471,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
[outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||
[n_remain] "+r"(n_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
[outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -514,9 +512,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
* Accumulator | |||
*/ | |||
static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, | |||
size_t m_remain) { | |||
static void kern_4x8( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, size_t m_remain) { | |||
K /= 8; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -646,11 +644,10 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
"cbnz %w[K], 2b\n" | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
[is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
[outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||
[m_remain] "+r"(m_remain) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
[outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
[x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", | |||
"memory"); | |||
@@ -686,9 +683,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
* | |||
* Accumulator | |||
*/ | |||
static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, size_t m_remain, | |||
size_t n_remain) { | |||
static void kern_4x4( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, size_t m_remain, size_t n_remain) { | |||
K /= 8; | |||
const int8_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -853,11 +850,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
"3:\n" STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), | |||
[K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), | |||
[x0] "+r"(x0), [m_remain] "+r"(m_remain), | |||
[n_remain] "+r"(n_remain) | |||
[x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", | |||
"cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", "cc", | |||
"memory"); | |||
#undef LOAD_LINE | |||
#undef LOAD_C | |||
@@ -865,9 +861,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
#undef STORE_C | |||
} | |||
static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_s8x8x16_8x8_pack_A_n( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
@@ -893,13 +889,15 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
int K = kmax - k0; | |||
for (; K > 15; K -= 16) { | |||
interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
interleave_8x8_2_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
} | |||
if (K > 0) { | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 8, K); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 8, K); | |||
} | |||
} | |||
for (; y < ymax; y += 4) { | |||
@@ -918,9 +916,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -936,9 +936,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -951,9 +953,8 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
} | |||
} | |||
static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
int ldin, int x0, int xmax, | |||
int k0, int kmax) { | |||
static void gemm_s8x8x16_8x8_transpose_pack_A_n( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
@@ -991,17 +992,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1009,8 +1016,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
megdnn_assert(0); | |||
} | |||
} | |||
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
transpose_8x8_1_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
outptr += ksize8; | |||
} | |||
@@ -1019,17 +1027,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1038,8 +1052,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
} | |||
} | |||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 4, 4); | |||
transpose_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 4, 4); | |||
outptr += ksize4; | |||
} | |||
@@ -1047,17 +1062,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1066,8 +1087,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
} | |||
} | |||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 4, xmax - x); | |||
transpose_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 4, xmax - x); | |||
} | |||
outptr_base += 8 * 8; | |||
@@ -1075,8 +1097,8 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
} | |||
} | |||
static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax) { | |||
static void gemm_s8x8x16_8x8_pack_B_n( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
const int ksize = kmax - k0; | |||
@@ -1113,17 +1135,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1132,8 +1160,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
} | |||
outptr_interleave = outptr; | |||
interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr_interleave); | |||
interleave_8x8_1_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave); | |||
outptr += ksize8; | |||
} | |||
@@ -1142,17 +1171,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1162,8 +1197,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
outptr_interleave = outptr; | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr_interleave, 4, 4); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave, 4, 4); | |||
outptr += ksize4; | |||
} | |||
@@ -1171,17 +1207,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
if (k + 7 >= kmax) { | |||
switch (k + 7 - kmax) { | |||
case 6: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
inptr3 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
inptr4 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
inptr5 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
inptr6 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr7 = zerobuff; | |||
break; | |||
@@ -1191,8 +1233,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
outptr_interleave = outptr; | |||
interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr_interleave, 4, xmax - x); | |||
interleave_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr_interleave, 4, xmax - x); | |||
} | |||
outptr_base += 8 * 8; | |||
@@ -1200,10 +1243,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
} | |||
} | |||
static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||
const dt_int8* inptr, int ldin, | |||
int y0, int ymax, int k0, | |||
int kmax) { | |||
static void gemm_s8x8x16_8x8_transpose_pack_B_n( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
int kmax) { | |||
int8_t zerobuff[16]; | |||
std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
constexpr int interleave4 = 32; | |||
@@ -1231,14 +1273,16 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||
int K = kmax - k0; | |||
for (; K > 7; K -= 8) { | |||
transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
inptr6, inptr7, outptr); | |||
transpose_8x8_1_b( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr); | |||
outptr += interleave8; | |||
} | |||
if (K > 0) { | |||
transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
inptr7, outptr, 8, K); | |||
transpose_8( | |||
inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
outptr, 8, K); | |||
outptr += interleave8; | |||
} | |||
} | |||
@@ -1259,9 +1303,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -1278,9 +1324,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||
if (y + 3 >= ymax) { | |||
switch (y + 3 - ymax) { | |||
case 2: | |||
inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
inptr1 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
inptr2 = zerobuff; | |||
MEGDNN_FALLTHRU | |||
case 0: | |||
inptr3 = zerobuff; | |||
break; | |||
@@ -40,11 +40,9 @@ namespace matmul_mk4_16x12x4_a53 { | |||
* Accumulator | |||
*/ | |||
// clang-format on | |||
static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||
const int8_t* packB, int K, | |||
int16_t* output, int LDC, | |||
bool is_first_k, | |||
int remain_n) { | |||
static __attribute__((noinline)) void kern_16x12( | |||
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int remain_n) { | |||
K /= 4; | |||
const int16_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
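
Note the asymmetric packing in this a53 strategy: packA arrives as int16 (gemm_s8x8x16_mk4_16x12_pack_A below writes dt_int16) while B stays int8, presumably so the inner loop can issue int16 multiplies without widening A on every iteration. The implied pack-time widening is just a sign extension:

#include <cstdint>

// Sketch of the A-side widening implied by the `const int16_t* packA`
// parameter: each int8 of A is sign-extended once at pack time (assumed
// rationale; the real pack also reorders A into 16-row MK4 panels).
static void widen_a_panel(const int8_t* in, int16_t* out, int n) {
    for (int i = 0; i < n; ++i)
        out[i] = static_cast<int16_t>(in[i]);
}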
@@ -521,15 +519,15 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||
"6:\n" STORE_C | |||
"101:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||
[remain_n] "+r"(remain_n) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
"x8", "x9", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||
"memory"); | |||
#undef STORE_C | |||
#undef STORE_LINE | |||
@@ -554,10 +552,9 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||
* Accumulator | |||
*/ | |||
// clang-format on | |||
static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||
const int8_t* packB, int K, | |||
int16_t* output, int LDC, | |||
bool is_first_k, int remain_n) { | |||
static __attribute__((noinline)) void kern_8x12( | |||
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int remain_n) { | |||
K /= 4; | |||
const int16_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -858,14 +855,13 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||
"6:\n" STORE_C | |||
"101:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||
[remain_n] "+r"(remain_n) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", | |||
"x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); | |||
#undef STORE_C | |||
#undef STORE_LINE | |||
@@ -890,10 +886,9 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||
* Accumulator | |||
*/ | |||
// clang-format on | |||
static __attribute__((noinline)) void kern_4x12(const int16_t* packA, | |||
const int8_t* packB, int K, | |||
int16_t* output, int LDC, | |||
bool is_first_k, int remain_n) { | |||
static __attribute__((noinline)) void kern_4x12( | |||
const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int remain_n) { | |||
K /= 4; | |||
const int16_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
@@ -1162,22 +1157,21 @@ static __attribute__((noinline)) void kern_4x12(const int16_t* packA, | |||
"6:\n" STORE_C | |||
"101:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
[is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||
[remain_n] "+r"(remain_n) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||
"memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", | |||
"x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); | |||
#undef STORE_C | |||
#undef STORE_LINE | |||
} | |||
static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, | |||
const dt_int8* inptr, int ldin, | |||
int m0, int mmax, int k0, int kmax) { | |||
static void gemm_s8x8x16_mk4_16x12_pack_A( | |||
dt_int16* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, | |||
int kmax) { | |||
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4"); | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||
constexpr int pack_m = 16; | |||
@@ -1224,9 +1218,8 @@ static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, | |||
} | |||
} | |||
static void gemm_s8x8x16_mk4_16x12_pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int n0, int nmax, int k0, | |||
int kmax) { | |||
static void gemm_s8x8x16_mk4_16x12_pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||
constexpr int pack_n = 12; | |||
@@ -43,8 +43,9 @@ namespace matmul_mk4_4x4x8_a72 { | |||
*/ | |||
// clang-format on | |||
static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool, int remain_n) { | |||
static inline void kern_4x4( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, bool, | |||
int remain_n) { | |||
K = div_ceil(K, 8); | |||
int oddk = (K & 1); | |||
K = ((K + 1) / 2) - 1; | |||
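
The three lines above set up a 2x-unrolled main loop: div_ceil(K, 8) is the number of 8-deep steps, oddk records whether that count is odd, and K becomes the number of double-steps executed before an epilogue that issues either one final double step (oddk == 0) or a single step (oddk == 1). A worked check of the arithmetic, assuming div_ceil rounds up:

#include <cstdio>

// K = 40: steps = 5, oddk = 1, main_iters = 2 -> two double steps in the
// main loop plus a single-step epilogue = 5 steps of 8, i.e. all 40 values.
int main() {
    int K = 40;
    int steps = (K + 7) / 8;            // div_ceil(K, 8)
    int oddk = steps & 1;
    int main_iters = ((steps + 1) / 2) - 1;
    std::printf("steps=%d oddk=%d main_iters=%d\n", steps, oddk, main_iters);
    return 0;
}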
@@ -261,15 +262,14 @@ static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
"7:\n" STORE_C | |||
"101:\n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
[remain_n] "+r"(remain_n) | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [oddk] "+r"(oddk), | |||
[LDC] "+r"(LDC), [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
"x8", "x9", "x10", "cc", "memory"); | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
"x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||
"memory"); | |||
#undef STORE_C | |||
#undef STORE_LINE | |||
@@ -282,26 +282,23 @@ static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) { | |||
vst1_s8(outptr + 3 * 8, in0.val[3]); | |||
} | |||
static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2, | |||
dt_int8* outptr) { | |||
static inline void interleve_8x4_b( | |||
const dt_int8* inptr, const dt_int8* inptr2, dt_int8* outptr) { | |||
int8x16_t in0 = vld1q_s8(inptr); | |||
int8x16_t in1 = vld1q_s8(inptr2); | |||
int32x4x2_t in_x2 = { | |||
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | |||
} | |||
static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { | |||
int8x16_t in0 = vld1q_s8(inptr); | |||
int8x16_t in1 = vdupq_n_s8(0); | |||
int32x4x2_t in_x2 = { | |||
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | |||
} | |||
static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||
int ldin, int m0, int mmax, int k0, | |||
int kmax) { | |||
static void gemm_s8x8x16_mk4_4x4x8_pack_A( | |||
dt_int8* out, const dt_int8* in, int ldin, int m0, int mmax, int k0, int kmax) { | |||
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4"); | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||
constexpr int pack_m = 4; | |||
@@ -330,9 +327,8 @@ static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||
} | |||
} | |||
static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int n0, int nmax, int k0, | |||
int kmax) { | |||
static void gemm_s8x8x16_mk4_4x4x8_pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||
constexpr int pack_n = 4; | |||
@@ -18,7 +18,6 @@ namespace megdnn { | |||
namespace aarch64 { | |||
namespace matmul_mk4_8x8x8 { | |||
/** | |||
* Overview of register layout: | |||
* | |||
@@ -39,18 +38,18 @@ namespace matmul_mk4_8x8x8 { | |||
* | v16 | | v28 | | |||
* | v17 | | v29 | | |||
* | v16 | | v30 | | |||
* | v17 | | v31 | | |||
* | v17 | | v31 | | |||
* +--------+ - - - - +---------------------------------+ | |||
* | |||
* Accumulator | |||
*/ | |||
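Editorial note: as a reading aid for the asm kernels in this namespace, here is a hedged scalar reference of the matmul they implement. The layouts are taken from the pack_A/pack_B comments later in this file (A as (M/4,K/4,4,4), B as (K/4,N,4), C as (M/4,N,4)); the real kernels run on the repacked 8-deep buffers, and naive_mk4_s8s8s16 is a hypothetical name, not a function in the tree:

    #include <cstdint>

    // Naive reference for the MK4 s8 x s8 -> s16 matmul (assumes M % 4 == 0
    // and K % 4 == 0, matching the megdnn_asserts in the pack routines).
    static void naive_mk4_s8s8s16(const int8_t* A, const int8_t* B, int16_t* C,
                                  int M, int N, int K) {
        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                int16_t acc = 0;
                for (int k = 0; k < K; ++k) {
                    // A: (M/4, K/4, 4k, 4m), B: (K/4, N, 4k)
                    int8_t a = A[(m / 4) * (K / 4) * 16 + (k / 4) * 16 +
                                 (k % 4) * 4 + m % 4];
                    int8_t b = B[(k / 4) * N * 4 + n * 4 + k % 4];
                    // s16 accumulation; wrapping is part of the s8x8x16 contract
                    acc = int16_t(acc + int16_t(a) * int16_t(b));
                }
                C[(m / 4) * N * 4 + n * 4 + m % 4] = acc;  // C: (M/4, N, 4m)
            }
        }
    }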
static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_8x8( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
K /= 8; | |||
LDC = LDC * sizeof(int16_t); | |||
const int8_t* a_ptr = packB;//packA; | |||
const int8_t* b_ptr = packA;//packB; | |||
const int8_t* a_ptr = packB;  // note: A and B are deliberately swapped here | |||
const int8_t* b_ptr = packA; | |||
// clang-format off | |||
#define LOAD_C_8 \ | |||
"ld1 {v0.8h}, [x0], #16\n" \ | |||
@@ -291,17 +290,17 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31"); | |||
// clang-format on | |||
// clang-format on | |||
} | |||
static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_8x8_remain( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
K /= 8; | |||
LDC = LDC * sizeof(int16_t); | |||
const int8_t* a_ptr = packB; | |||
const int8_t* b_ptr = packA; | |||
// clang-format off | |||
// clang-format off | |||
register int16_t* outptr asm("x0") = output; | |||
asm volatile( | |||
"add x1, x0, %x[LDC]\n" | |||
@@ -476,7 +475,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
"cbnz %w[K], 1b\n" | |||
"cmp %w[is_first_k], #1\n" | |||
"beq 2f\n" | |||
"beq 2f\n" | |||
"cmp %x[m_remain], #8 \n" | |||
"beq 8f \n" | |||
"cmp %x[m_remain], #4 \n" | |||
@@ -633,7 +632,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
"zip2 v15.2d, v30.2d, v31.2d \n" | |||
"add v6.8h, v6.8h, v13.8h \n" | |||
"add v7.8h, v7.8h, v15.8h \n" | |||
//save to memory | |||
// save to memory | |||
"cmp %x[m_remain], #8 \n" | |||
"beq 4f \n" | |||
"cmp %x[m_remain], #4 \n" | |||
@@ -766,31 +765,27 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
"b 1000f \n" | |||
"1000: \n" | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
[K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
[m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||
: | |||
[ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||
[ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||
[ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||
[ n_remain ] "+r"(n_remain) | |||
: | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||
"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
"v29", "v30", "v31"); | |||
// clang-format on | |||
: "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
"v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||
// clang-format on | |||
#undef LOAD_C_8 | |||
#undef STORE_C_8 | |||
} | |||
static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_4x8( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
K /= 8; | |||
LDC = LDC * sizeof(int16_t); | |||
const int8_t* a_ptr = packB;//packA; | |||
const int8_t* b_ptr = packA;//packB; | |||
const int8_t* a_ptr = packB;  // note: A and B are deliberately swapped here | |||
const int8_t* b_ptr = packA; | |||
// clang-format off | |||
#define LOAD_C_4 \ | |||
"ld1 {v0.8h}, [x0], #16\n" \ | |||
@@ -1018,14 +1013,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
#undef LOAD_C_4 | |||
#undef STORE_C_4 | |||
} | |||
static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
int n_remain) { | |||
static void kern_4x8_remain( | |||
const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
bool is_first_k, int m_remain, int n_remain) { | |||
K /= 8; | |||
LDC = LDC * sizeof(int16_t); | |||
const int8_t* a_ptr = packB;//packA; | |||
const int8_t* b_ptr = packA;//packB; | |||
// clang-format off | |||
const int8_t* a_ptr = packB;  // note: A and B are deliberately swapped here | |||
const int8_t* b_ptr = packA; | |||
// clang-format off | |||
register int16_t* outptr asm("x0") = output; | |||
asm volatile( | |||
@@ -1324,13 +1319,12 @@ static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
#undef STORE_C_4 | |||
} | |||
//! pack to ic x oc | |||
//! (M/4,K/4,4(K),4(M)) pack to (M/8,K/8,8(K_ic_0~3_ic_4~7),8(M_oc_0~3_oc_4~7)) | |||
//! if M or K is not a multiple of 8, pad with 0 instead | |||
static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||
const dt_int8* inptr, int ldin, | |||
int m0, int mmax, int k0, int kmax) { | |||
//! if M or K is not a multiple of 8, pad with 0 instead | |||
static void gemm_s8x8x16_mk4_8x8x8_pack_A( | |||
dt_int8* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, | |||
int kmax) { | |||
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4"); | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||
constexpr int pack_m = 8; | |||
@@ -1349,8 +1343,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
int k_idx = k0; | |||
for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||
interleave_8x8_mk4_b(inptr0,inptr1,outptr); | |||
for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||
interleave_8x8_mk4_b(inptr0, inptr1, outptr); | |||
} | |||
if (k_idx < kmax) { | |||
@@ -1368,9 +1362,9 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
int k_idx = k0; | |||
for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||
for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||
inptr1 = zerobuff; | |||
interleave_8x8_mk4_b(inptr0,inptr1,outptr); | |||
interleave_8x8_mk4_b(inptr0, inptr1, outptr); | |||
} | |||
if (k_idx < kmax) { | |||
@@ -1383,9 +1377,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||
} | |||
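Editorial note: an index-level sketch of the pack_A relayout described by the comment above, gathering from (M/4,K/4,4,4) and scattering to (M/8,K/8,8,8) with zero padding. It assumes M % 4 == 0 and K % 4 == 0 as asserted, and deliberately ignores the zerobuff/prefetch handling of the real packer; pack_A_mk4_to_mk8_ref is a hypothetical name:

    #include <cstdint>

    // Reference relayout: out[(m/8, k/8, k%8, m%8)] = in[(m/4, k/4, k%4, m%4)],
    // with 0 written wherever m or k falls in the 8-alignment padding.
    static void pack_A_mk4_to_mk8_ref(int8_t* out, const int8_t* in, int M, int K) {
        const int M8 = (M + 7) / 8 * 8;
        const int K8 = (K + 7) / 8 * 8;
        for (int m = 0; m < M8; ++m) {
            for (int k = 0; k < K8; ++k) {
                int8_t v = 0;
                if (m < M && k < K)
                    v = in[(m / 4) * (K / 4) * 16 + (k / 4) * 16 + (k % 4) * 4 +
                           m % 4];
                out[(m / 8) * K8 * 8 + (k / 8) * 64 + (k % 8) * 8 + m % 8] = v;
            }
        }
    }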
//! pack to n x ic | |||
//! (K/4,N,4) pack to (K/8,N,8(ic0~7)); if K is not a multiple of 8, pad with 0 instead. | |||
static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int n0, int nmax, int k0, | |||
int kmax) { | |||
static void gemm_s8x8x16_mk4_8x8x8_pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||
constexpr int pack_n = 8; | |||
@@ -1394,14 +1387,14 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
int8_t tmpbuff0[pack_n * pack_size] = {0}; | |||
int8_t tmpbuff1[pack_n * pack_size] = {0}; | |||
int8_t zerobuff[pack_n * pack_size] = {0}; | |||
const int ksize = round_up<int>((kmax - k0),8); | |||
const int ksize = round_up<int>((kmax - k0), 8); | |||
const int nsize = nmax - n0; | |||
const int n_end = nsize / pack_n * pack_n + n0; | |||
const int remain_n = nsize % pack_n; | |||
int output_stride = ksize * pack_n; | |||
int8_t* outptr_base = out; | |||
int k_idx = k0; | |||
for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||
for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||
const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
prefetch_3x(inptr0); | |||
@@ -1410,7 +1403,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
auto outptr = outptr_base; | |||
for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { | |||
transpose_8x8_mk4_b(inptr0, inptr1, outptr); | |||
outptr += output_stride; | |||
outptr += output_stride; | |||
} | |||
if (remain_n > 0) { | |||
memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size); | |||
@@ -1422,8 +1415,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
} | |||
outptr_base += pack_n * pack_k; | |||
} | |||
if(k_idx < kmax){ | |||
if (k_idx < kmax) { | |||
const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | |||
const int8_t* inptr1 = nullptr; | |||
prefetch_3x(inptr0); | |||
@@ -1444,7 +1437,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
} | |||
} | |||
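Editorial note: the matching index-level sketch for pack_B's comment, (K/4,N,4) to (K/8,N,8) with a zero-padded K tail. This restates the logical layout only; the real packer emits the buffer in pack_n = 8 column blocks with output_stride = ksize * pack_n, which the sketch flattens. pack_B_k4_to_k8_ref is a hypothetical name:

    #include <cstdint>

    // Reference relayout: out[(k/8, n, k%8)] = in[(k/4, n, k%4)], zero-padding
    // K up to a multiple of 8 (assumes K % 4 == 0, as asserted).
    static void pack_B_k4_to_k8_ref(int8_t* out, const int8_t* in, int N, int K) {
        const int K8 = (K + 7) / 8 * 8;
        for (int k = 0; k < K8; ++k) {
            for (int n = 0; n < N; ++n) {
                int8_t v = (k < K) ? in[(k / 4) * N * 4 + n * 4 + k % 4] : 0;
                out[(k / 8) * N * 8 + n * 8 + k % 8] = v;
            }
        }
    }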
} // namespace matmul_mk4_16x12x4_a53 | |||
} // namespace matmul_mk4_8x8x8 | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
@@ -10,13 +10,13 @@ | |||
* implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
#include "src/aarch64/matrix_mul/asm/common.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
@@ -28,39 +28,35 @@ using namespace aarch64::matmul; | |||
// ===========================gemm_s8x8x16_8x8================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); | |||
void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s8x8x16_8x8::pack_A( | |||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0, | |||
ymax, k0, kmax); | |||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n( | |||
out, in, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s8x8x16_8x8::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n( | |||
out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_s8x8x16_8x8::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
@@ -79,15 +75,15 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k); | |||
matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
matmul_8x8x8::kern_8x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
@@ -99,16 +95,17 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
const dt_int8* cur_packB = packB; | |||
size_t n = 0; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
matmul_8x8x8::kern_4x8( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K8; | |||
} | |||
for (; n < N; n += 4) { | |||
matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), | |||
std::min<size_t>(N - n, 4)); | |||
matmul_8x8x8::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
output += 4; | |||
cur_packB += K4; | |||
} | |||
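Editorial note: the loops above follow the usual two-level GEMM tiling: full 8-row panels dispatch kern_8x8 with a 4-column tail via kern_8x4, then a 4-row M tail dispatches kern_4x8/kern_4x4, each kernel receiving min(remaining, tile) as its remain argument. A runnable control-flow sketch with stub kernels (tile() is a stand-in, not a real function):

    #include <cstdio>

    // Stand-in for matmul_8x8x8::kern_8x8 / kern_8x4 / kern_4x8 / kern_4x4;
    // only the tiling order is modeled here, not the arithmetic.
    static void tile(const char* kern, size_t m, size_t n) {
        std::printf("%s at m=%zu n=%zu\n", kern, m, n);
    }

    static void dispatch(size_t M, size_t N) {
        size_t m = 0;
        for (; m + 8 <= M; m += 8) {  // full 8-row panels
            size_t n = 0;
            for (; n + 8 <= N; n += 8) tile("kern_8x8", m, n);
            for (; n < N; n += 4) tile("kern_8x4", m, n);  // remain_n = min(N - n, 4)
        }
        for (; m < M; m += 4) {  // 4-row M tail: m_remain = min(M - m, 4)
            size_t n = 0;
            for (; n + 8 <= N; n += 8) tile("kern_4x8", m, n);
            for (; n < N; n += 4) tile("kern_4x4", m, n);
        }
    }

    int main() { dispatch(20, 20); }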
@@ -119,39 +116,33 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
// ===========================gemm_s8x8x16_4x4================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); | |||
void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s8x8x16_4x4::pack_A( | |||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
void gemm_s8x8x16_4x4::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool transpose) const { | |||
if (transpose) { | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} else { | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
} | |||
void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
(A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
C_dtype.name()); | |||
void gemm_s8x8x16_4x4::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && | |||
C_dtype.enumv() == DTypeEnum::Int16), | |||
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
@@ -169,16 +160,17 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||
matmul_4x4x16::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||
B_INTERLEAVE); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, A_INTERLEAVE, | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
matmul_4x4x16::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
@@ -191,10 +183,10 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t n = 0; | |||
const dt_int8* cur_packB = packB; | |||
for (; n < N; n += B_INTERLEAVE) { | |||
matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, | |||
std::min<size_t>(M - m, A_INTERLEAVE), | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
matmul_4x4x16::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, | |||
std::min<size_t>(M - m, A_INTERLEAVE), | |||
std::min<size_t>(N - n, B_INTERLEAVE)); | |||
output += B_INTERLEAVE; | |||
cur_packB += K4; | |||
} | |||
@@ -205,28 +197,26 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
// ===========================gemm_s8x8x16_mk4_16x12================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); | |||
void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool) const { | |||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0, | |||
ymax, k0, kmax); | |||
void gemm_s8x8x16_mk4_16x12_a53::pack_A( | |||
dt_int16* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A( | |||
out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool) const { | |||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
void gemm_s8x8x16_mk4_16x12_a53::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B( | |||
out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
const dt_int8* packB, size_t M, size_t N, | |||
size_t K, dt_int16* C, size_t LDC, | |||
bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
C_dtype.enumv() == DTypeEnum::Int16 && | |||
A_dtype.enumv() == DTypeEnum::Int8); | |||
void gemm_s8x8x16_mk4_16x12_a53::kern( | |||
const dt_int16* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||
A_dtype.enumv() == DTypeEnum::Int8); | |||
megdnn_assert(is_first_k == true, "only is_first_k == true is implemented"); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
@@ -246,14 +236,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_n); | |||
matmul_mk4_16x12x4_a53::kern_16x12( | |||
packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, remain_n); | |||
matmul_mk4_16x12x4_a53::kern_16x12( | |||
packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
@@ -265,14 +255,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_n); | |||
matmul_mk4_16x12x4_a53::kern_8x12( | |||
packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, remain_n); | |||
matmul_mk4_16x12x4_a53::kern_8x12( | |||
packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
@@ -286,14 +276,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_n); | |||
matmul_mk4_16x12x4_a53::kern_4x12( | |||
packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||
is_first_k, remain_n); | |||
matmul_mk4_16x12x4_a53::kern_4x12( | |||
packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
@@ -303,27 +293,26 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
// ===========================gemm_s8x8x16_mk4_4x4_a72================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); | |||
void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax, | |||
k0, kmax); | |||
void gemm_s8x8x16_mk4_4x4_a72::pack_A( | |||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A( | |||
out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax, | |||
k0, kmax); | |||
void gemm_s8x8x16_mk4_4x4_a72::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B( | |||
out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, | |||
const dt_int16*, dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
C_dtype.enumv() == DTypeEnum::Int16 && | |||
A_dtype.enumv() == DTypeEnum::Int8); | |||
void gemm_s8x8x16_mk4_4x4_a72::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||
A_dtype.enumv() == DTypeEnum::Int8); | |||
megdnn_assert(is_first_k == true, "only is_first_k == true is implemented"); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
@@ -343,14 +332,14 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
const int8_t* cur_packB = packB; | |||
for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { | |||
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_n); | |||
matmul_mk4_4x4x8_a72::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += pack_n * packed_k; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||
is_first_k, remain_n); | |||
matmul_mk4_4x4x8_a72::kern_4x4( | |||
packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += pack_n * packed_k; | |||
} | |||
@@ -361,27 +350,24 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
// ===========================gemm_s8x8x16_mk4_8x8x8================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); | |||
void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, | |||
int ldin, int y0, int ymax, int k0, | |||
int kmax, bool) const { | |||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, | |||
ymax, k0, kmax); | |||
void gemm_s8x8x16_mk4_8x8x8::pack_A( | |||
dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, ymax, k0, kmax); | |||
} | |||
void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, | |||
int ldin, int x0, int xmax, int k0, | |||
int kmax, bool) const { | |||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, | |||
xmax, k0, kmax); | |||
void gemm_s8x8x16_mk4_8x8x8::pack_B( | |||
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
} | |||
void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
C_dtype.enumv() == DTypeEnum::Int16 && | |||
A_dtype.enumv() == DTypeEnum::Int8); | |||
void gemm_s8x8x16_mk4_8x8x8::kern( | |||
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
megdnn_assert( | |||
A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||
A_dtype.enumv() == DTypeEnum::Int8); | |||
megdnn_assert(is_first_k == true, "only is_first_k == true is implemented"); | |||
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
@@ -402,14 +388,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_m, pack_n); | |||
matmul_mk4_8x8x8::kern_8x8( | |||
packA, cur_packB, K, output, LDC, is_first_k, pack_m, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += KSIZE8; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_m, remain_n); | |||
matmul_mk4_8x8x8::kern_8x8_remain( | |||
packA, cur_packB, K, output, LDC, is_first_k, pack_m, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += KSIZE8; | |||
} | |||
@@ -421,14 +407,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, 4, pack_n); | |||
matmul_mk4_8x8x8::kern_4x8( | |||
packA, cur_packB, K, output, LDC, is_first_k, 4, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC, | |||
is_first_k, 4, remain_n); | |||
matmul_mk4_8x8x8::kern_4x8_remain( | |||
packA, cur_packB, K, output, LDC, is_first_k, 4, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||